| @@ -52,4 +52,4 @@ jobs: | |||
| - name: Build | |||
| run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore | |||
| - name: Test | |||
| run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} | |||
| run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} -l "console;verbosity=detailed" | |||
| @@ -0,0 +1,63 @@ | |||
| using System.Text; | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| using Xunit.Abstractions; | |||
| namespace LLama.Unittest; | |||
/// <summary>
/// Tests for the native beam-search API exposed through <see cref="NativeApi.llama_beam_search"/>.
/// </summary>
public sealed class BeamTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    public BeamTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;
        _params = new ModelParams(Constants.ModelPath)
        {
            ContextSize = 2048
        };
        _model = LLamaWeights.LoadFromFile(_params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact(Skip = "Very very slow in CI")]
    public void BasicBeam()
    {
        const int num_beams = 2;
        const int n_predict = 3;

        // FIX: the context was previously never disposed, leaking the native
        // llama context for the lifetime of the test run.
        using var context = _model.CreateContext(_params);

        var result = new StringBuilder();

        var initial_tokens = context.Tokenize("The cat sat on");
        result.Append(context.DeTokenize(initial_tokens.ToArray()));
        context.Eval(initial_tokens, 0);

        NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
        {
            // Log every candidate beam with its cumulative probability.
            for (var i = 0; i < state.Beams.Length; i++)
            {
                ref var view = ref state.Beams[i];
                var tokens = context.DeTokenize(view.Tokens.ToArray());
                _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
            }

            // Tokens in the common prefix are shared by all beams and will be
            // shifted out of every beam before the next callback, so append
            // them to the result exactly once (from beam 0).
            if (state.CommonPrefixLength > 0)
            {
                var view = state.Beams[0];
                result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
            }
        }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));

        _testOutputHelper.WriteLine($"Final: {result}");
    }
}
| @@ -66,7 +66,7 @@ namespace LLama.Unittest | |||
| Grammar = grammar, | |||
| }; | |||
| var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList(); | |||
| var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList(); | |||
| Assert.Equal("cat", result[0]); | |||
| } | |||
| @@ -0,0 +1,42 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
| using llama_token = Int32; | |||
/// <summary>
/// Information about a single beam in a beam search.
/// Managed mirror of the native llama_beam_view struct; instances live in
/// native memory and are only viewed from managed code via pointers.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaBeamView
{
    // Pointer to this beam's token buffer in native memory; n_tokens is its length.
    // NOTE(review): field order/types must mirror the native llama_beam_view
    // layout exactly (LayoutKind.Sequential) — confirm against llama.h before changing.
    private readonly unsafe llama_token* tokens;
    private readonly nint n_tokens;

    /// <summary>
    /// Cumulative beam probability (renormalized relative to all beams)
    /// </summary>
    public readonly float CumulativeProbability;

    /// <summary>
    /// Callback should set this to true when a beam is at end-of-beam.
    /// </summary>
    public bool EndOfBeam;

    /// <summary>
    /// Tokens in this beam
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if the native token count does not fit in a 32 bit signed integer,
    /// since <see cref="Span{T}"/> length is limited to <see cref="int.MaxValue"/>.
    /// </exception>
    public readonly Span<llama_token> Tokens
    {
        get
        {
            unsafe
            {
                // Span length is a 32 bit int, so larger native counts cannot be represented.
                if (n_tokens > int.MaxValue)
                    throw new InvalidOperationException("More than 2147483647 tokens is not supported");
                return new Span<llama_token>(tokens, (int)n_tokens);
            }
        }
    }
}
| @@ -0,0 +1,49 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
/// <summary>
/// Passed to beam_search_callback function.
/// Whenever 0 &lt; common_prefix_length, this number of tokens should be copied from any of the beams
/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly struct LLamaBeamsState
{
    /// <summary>
    /// The state of each individual beam
    /// </summary>
    // NOTE(review): field order/types must mirror the native llama_beams_state
    // layout exactly (LayoutKind.Sequential) — confirm against llama.h before changing.
    private readonly unsafe LLamaBeamView* beam_views;

    /// <summary>
    /// Number of elements in beam_views
    /// </summary>
    private readonly nint n_beams;

    /// <summary>
    /// Current max length of prefix tokens shared by all beams.
    /// </summary>
    public readonly ulong CommonPrefixLength;

    /// <summary>
    /// True iff this is the last callback invocation.
    /// </summary>
    public readonly bool LastCall;

    /// <summary>
    /// The current state of each beam
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if the native beam count does not fit in a 32 bit signed integer,
    /// since <see cref="Span{T}"/> length is limited to <see cref="int.MaxValue"/>.
    /// </exception>
    public Span<LLamaBeamView> Beams
    {
        get
        {
            unsafe
            {
                // Span length is a 32 bit int, so larger native counts cannot be represented.
                if (n_beams > int.MaxValue)
                    throw new InvalidOperationException("More than 2147483647 beams is not supported");
                return new Span<LLamaBeamView>(beam_views, (int)n_beams);
            }
        }
    }
}
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
public partial class NativeApi
{
    /// <summary>
    /// Type of pointer to the beam_search_callback function.
    /// </summary>
    /// <param name="callback_data">callback_data is any custom data passed to llama_beam_search, that is subsequently passed back to beam_search_callback</param>
    /// <param name="state">The current state of all beams; see <see cref="LLamaBeamsState"/></param>
    public delegate void LLamaBeamSearchCallback(IntPtr callback_data, LLamaBeamsState state);

    /// <summary>Deterministically returns entire sentence constructed by a beam search.</summary>
    /// <param name="ctx">Pointer to the llama_context.</param>
    /// <param name="callback">Invoked for each iteration of the beam_search loop, passing in beams_state.</param>
    /// <param name="callback_data">A pointer that is simply passed back to callback.</param>
    /// <param name="n_beams">Number of beams to use.</param>
    /// <param name="n_past">Number of tokens already evaluated.</param>
    /// <param name="n_predict">Maximum number of tokens to predict. EOS may occur earlier.</param>
    /// <param name="n_threads">Number of threads.</param>
    [DllImport(libraryName, EntryPoint = "llama_beam_search", CallingConvention = CallingConvention.Cdecl)]
    public static extern void llama_beam_search(SafeLLamaContextHandle ctx, LLamaBeamSearchCallback callback, IntPtr callback_data, ulong n_beams, int n_past, int n_predict, int n_threads);
}