diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs
index be21a5f2..83eb87d3 100644
--- a/LLama.Unittest/BeamTests.cs
+++ b/LLama.Unittest/BeamTests.cs
@@ -27,7 +27,8 @@ public sealed class BeamTests
         _model.Dispose();
     }
 
-    [Fact(Skip = "Very very slow in CI")]
+    //[Fact(Skip = "Very very slow in CI")]
+    [Fact]
     public void BasicBeam()
     {
         const int num_beams = 2;
@@ -36,15 +37,15 @@ public sealed class BeamTests
 
         var context = _model.CreateContext(_params);
 
-        var result = new StringBuilder();
-
         var initial_tokens = context.Tokenize(prompt);
-        result.Append(prompt);
-        //context.Eval(initial_tokens.AsSpan(), 0);
-        throw new NotImplementedException("Replace Eval");
+        var batch = new LLamaBatch();
+        batch.AddRange(initial_tokens, 0, LLamaSeqId.Zero, true);
+        context.Decode(batch);
+        var decoder = new StreamingTokenDecoder(context);
 
         NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
         {
+            // Show the current state of every beam.
             for (var i = 0; i < state.Beams.Length; i++)
             {
                 ref var view = ref state.Beams[i];
@@ -56,20 +57,17 @@ public sealed class BeamTests
                 _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
             }
 
+            // Once all beams agree on some tokens read them and append them to the output decoder
             if (state.CommonPrefixLength > 0)
             {
                 var view = state.Beams[0];
 
-                var decoder = new StreamingTokenDecoder(context);
                 decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
 
-                var tokens = decoder.Read();
-
-                result.Append(tokens);
             }
 
         }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
 
-        _testOutputHelper.WriteLine($"Final: {result}");
+        _testOutputHelper.WriteLine($"Final: {prompt}{decoder.Read()}");
     }
 }
\ No newline at end of file