using System.Text;
using LLama.Common;
using LLama.Native;
using Xunit.Abstractions;

namespace LLama.Unittest;

public sealed class BeamTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    public BeamTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        _params = new ModelParams(Constants.ModelPath)
        {
            ContextSize = 2048
        };

        _model = LLamaWeights.LoadFromFile(_params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact(Skip = "Very very slow in CI")]
    public void BasicBeam()
    {
        const int num_beams = 2;
        const int n_predict = 3;
        const string prompt = "The cat sat on";

        var context = _model.CreateContext(_params);

        var result = new StringBuilder();

        // Tokenize and evaluate the prompt before starting the beam search
        var initial_tokens = context.Tokenize(prompt);
        result.Append(prompt);
        context.Eval(initial_tokens, 0);

        NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
        {
            // Log every candidate beam along with its cumulative probability
            for (var i = 0; i < state.Beams.Length; i++)
            {
                ref var view = ref state.Beams[i];
                var tokens = context.DeTokenize(view.Tokens.ToArray());
                _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
            }

            // Tokens shared by all beams are final - append them to the result
            if (state.CommonPrefixLength > 0)
            {
                var view = state.Beams[0];
                result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
            }

        }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));

        _testOutputHelper.WriteLine($"Final: {result}");
    }
}