| @@ -52,4 +52,4 @@ jobs: | |||||
| - name: Build | - name: Build | ||||
| run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore | run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore | ||||
| - name: Test | - name: Test | ||||
| run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} | |||||
| run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} -l "console;verbosity=detailed" | |||||
| @@ -0,0 +1,63 @@ | |||||
| using System.Text; | |||||
| using LLama.Common; | |||||
| using LLama.Native; | |||||
| using Xunit.Abstractions; | |||||
| namespace LLama.Unittest; | |||||
/// <summary>
/// Tests for the native beam-search API exposed through <see cref="NativeApi.llama_beam_search"/>.
/// </summary>
public sealed class BeamTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    public BeamTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;
        _params = new ModelParams(Constants.ModelPath)
        {
            ContextSize = 2048
        };
        _model = LLamaWeights.LoadFromFile(_params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact(Skip = "Very very slow in CI")]
    public void BasicBeam()
    {
        const int num_beams = 2;
        const int n_predict = 3;

        // FIX: the context was previously never disposed, leaking the native
        // llama_context for the lifetime of the test run. `using var` releases
        // it when the method exits, even on assertion failure.
        using var context = _model.CreateContext(_params);

        var result = new StringBuilder();

        // Prime the context with the prompt before starting the search.
        var initial_tokens = context.Tokenize("The cat sat on");
        result.Append(context.DeTokenize(initial_tokens.ToArray()));
        context.Eval(initial_tokens, 0);

        NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
        {
            // Log every live beam with its cumulative probability.
            for (var i = 0; i < state.Beams.Length; i++)
            {
                ref var view = ref state.Beams[i];
                var tokens = context.DeTokenize(view.Tokens.ToArray());
                _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
            }

            // Tokens shared by all beams are final: capture them now, because the
            // native side shifts the common prefix out of every beam before the
            // next callback (see llama_beams_state docs).
            if (state.CommonPrefixLength > 0)
            {
                var view = state.Beams[0];
                result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
            }

        }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));

        _testOutputHelper.WriteLine($"Final: {result}");
    }
}
| @@ -66,7 +66,7 @@ namespace LLama.Unittest | |||||
| Grammar = grammar, | Grammar = grammar, | ||||
| }; | }; | ||||
| var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); | |||||
| var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList(); | |||||
| Assert.Equal("cat", result[0]); | Assert.Equal("cat", result[0]); | ||||
| } | } | ||||
| @@ -0,0 +1,42 @@ | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native; | |||||
| using llama_token = Int32; | |||||
/// <summary>
/// Information about a single beam in a beam search
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaBeamView
{
    // NOTE(review): field order and types must exactly mirror the native
    // llama_beam_view struct — do not reorder or retype these fields, as
    // LayoutKind.Sequential maps them positionally onto the native layout.

    // Pointer into native memory holding this beam's token ids.
    private readonly unsafe llama_token* tokens;

    // Number of tokens pointed to by `tokens` (native size_t, hence nint).
    private readonly nint n_tokens;

    /// <summary>
    /// Cumulative beam probability (renormalized relative to all beams)
    /// </summary>
    public readonly float CumulativeProbability;

    /// <summary>
    /// Callback should set this to true when a beam is at end-of-beam.
    /// </summary>
    public bool EndOfBeam;

    /// <summary>
    /// Tokens in this beam
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the native token count does not fit in a 32-bit span length.
    /// </exception>
    public readonly Span<llama_token> Tokens
    {
        get
        {
            unsafe
            {
                // Span<T> length is a 32-bit int; guard against silently
                // truncating a larger native count.
                if (n_tokens > int.MaxValue)
                    throw new InvalidOperationException("More than 2147483647 tokens is not supported");
                return new Span<llama_token>(tokens, (int)n_tokens);
            }
        }
    }
}
| @@ -0,0 +1,49 @@ | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native; | |||||
/// <summary>
/// Passed to beam_search_callback function.
/// Whenever 0 &lt; common_prefix_length, this number of tokens should be copied from any of the beams
/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly struct LLamaBeamsState
{
    // NOTE(review): field order and types must exactly mirror the native
    // llama_beams_state struct — LayoutKind.Sequential maps positionally.

    /// <summary>
    /// The state of each individual beam
    /// </summary>
    private readonly unsafe LLamaBeamView* beam_views;

    /// <summary>
    /// Number of elements in beam_views
    /// </summary>
    private readonly nint n_beams;

    /// <summary>
    /// Current max length of prefix tokens shared by all beams.
    /// </summary>
    // NOTE(review): native side is size_t; `ulong` matches on 64-bit but not
    // on a 32-bit process — confirm against the native header if 32-bit
    // platforms are supported.
    public readonly ulong CommonPrefixLength;

    /// <summary>
    /// True iff this is the last callback invocation.
    /// </summary>
    // NOTE(review): a C# bool in a marshaled struct defaults to a 4-byte
    // BOOL, while the native field is a 1-byte C++ bool — verify the
    // marshaled size matches, or trailing layout will be skewed.
    public readonly bool LastCall;

    /// <summary>
    /// The current state of each beam
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the native beam count does not fit in a 32-bit span length.
    /// </exception>
    public Span<LLamaBeamView> Beams
    {
        get
        {
            unsafe
            {
                // Span<T> length is a 32-bit int; guard against truncation.
                if (n_beams > int.MaxValue)
                    throw new InvalidOperationException("More than 2147483647 beams is not supported");
                return new Span<LLamaBeamView>(beam_views, (int)n_beams);
            }
        }
    }
}
| @@ -0,0 +1,25 @@ | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace LLama.Native; | |||||
public partial class NativeApi
{
    /// <summary>
    /// Type of pointer to the beam_search_callback function.
    /// </summary>
    /// <param name="callback_data">callback_data is any custom data passed to llama_beam_search, that is subsequently passed back to beam_search_callback</param>
    /// <param name="state">The current state of all beams in the search</param>
    // FIX: without this attribute the delegate marshals with the platform
    // default calling convention (StdCall on 32-bit Windows), but the native
    // library invokes the callback cdecl — mismatched conventions corrupt the
    // stack on x86. Explicitly declare Cdecl to match the native signature.
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    public delegate void LLamaBeamSearchCallback(IntPtr callback_data, LLamaBeamsState state);

    /// <summary>Deterministically returns entire sentence constructed by a beam search.</summary>
    /// <param name="ctx">Pointer to the llama_context.</param>
    /// <param name="callback">Invoked for each iteration of the beam_search loop, passing in beams_state.</param>
    /// <param name="callback_data">A pointer that is simply passed back to callback.</param>
    /// <param name="n_beams">Number of beams to use.</param>
    /// <param name="n_past">Number of tokens already evaluated.</param>
    /// <param name="n_predict">Maximum number of tokens to predict. EOS may occur earlier.</param>
    /// <param name="n_threads">Number of threads.</param>
    [DllImport(libraryName, EntryPoint = "llama_beam_search", CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_beam_search(SafeLLamaContextHandle ctx, LLamaBeamSearchCallback callback, IntPtr callback_data, ulong n_beams, int n_past, int n_predict, int n_threads);
}