| @@ -27,7 +27,8 @@ public sealed class BeamTests | |||||
| _model.Dispose(); | _model.Dispose(); | ||||
| } | } | ||||
| [Fact(Skip = "Very very slow in CI")] | |||||
| //[Fact(Skip = "Very very slow in CI")] | |||||
| [Fact] | |||||
| public void BasicBeam() | public void BasicBeam() | ||||
| { | { | ||||
| const int num_beams = 2; | const int num_beams = 2; | ||||
| @@ -36,15 +37,15 @@ public sealed class BeamTests | |||||
| var context = _model.CreateContext(_params); | var context = _model.CreateContext(_params); | ||||
| var result = new StringBuilder(); | |||||
| var initial_tokens = context.Tokenize(prompt); | var initial_tokens = context.Tokenize(prompt); | ||||
| result.Append(prompt); | |||||
| //context.Eval(initial_tokens.AsSpan(), 0); | |||||
| throw new NotImplementedException("Replace Eval"); | |||||
| var batch = new LLamaBatch(); | |||||
| batch.AddRange(initial_tokens, 0, LLamaSeqId.Zero, true); | |||||
| context.Decode(batch); | |||||
| var decoder = new StreamingTokenDecoder(context); | |||||
| NativeApi.llama_beam_search(context.NativeHandle, (data, state) => | NativeApi.llama_beam_search(context.NativeHandle, (data, state) => | ||||
| { | { | ||||
| // Show the current state of every beam. | |||||
| for (var i = 0; i < state.Beams.Length; i++) | for (var i = 0; i < state.Beams.Length; i++) | ||||
| { | { | ||||
| ref var view = ref state.Beams[i]; | ref var view = ref state.Beams[i]; | ||||
| @@ -56,20 +57,17 @@ public sealed class BeamTests | |||||
| _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); | _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); | ||||
| } | } | ||||
| // Once all beams agree on some tokens read them and append them to the output decoder | |||||
| if (state.CommonPrefixLength > 0) | if (state.CommonPrefixLength > 0) | ||||
| { | { | ||||
| var view = state.Beams[0]; | var view = state.Beams[0]; | ||||
| var decoder = new StreamingTokenDecoder(context); | |||||
| decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength)); | decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength)); | ||||
| var tokens = decoder.Read(); | |||||
| result.Append(tokens); | |||||
| } | } | ||||
| }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); | }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); | ||||
| _testOutputHelper.WriteLine($"Final: {result}"); | |||||
| _testOutputHelper.WriteLine($"Final: {prompt}{decoder.Read()}"); | |||||
| } | } | ||||
| } | } | ||||