|
|
|
@@ -27,7 +27,8 @@ public sealed class BeamTests |
|
|
|
_model.Dispose(); |
|
|
|
} |
|
|
|
|
|
|
|
[Fact(Skip = "Very very slow in CI")] |
|
|
|
//[Fact(Skip = "Very very slow in CI")] |
|
|
|
[Fact] |
|
|
|
public void BasicBeam() |
|
|
|
{ |
|
|
|
const int num_beams = 2; |
|
|
|
@@ -36,15 +37,15 @@ public sealed class BeamTests |
|
|
|
|
|
|
|
var context = _model.CreateContext(_params); |
|
|
|
|
|
|
|
var result = new StringBuilder(); |
|
|
|
|
|
|
|
var initial_tokens = context.Tokenize(prompt); |
|
|
|
result.Append(prompt); |
|
|
|
//context.Eval(initial_tokens.AsSpan(), 0); |
|
|
|
throw new NotImplementedException("Replace Eval"); |
|
|
|
var batch = new LLamaBatch(); |
|
|
|
batch.AddRange(initial_tokens, 0, LLamaSeqId.Zero, true); |
|
|
|
context.Decode(batch); |
|
|
|
|
|
|
|
var decoder = new StreamingTokenDecoder(context); |
|
|
|
NativeApi.llama_beam_search(context.NativeHandle, (data, state) => |
|
|
|
{ |
|
|
|
// Show the current state of every beam. |
|
|
|
for (var i = 0; i < state.Beams.Length; i++) |
|
|
|
{ |
|
|
|
ref var view = ref state.Beams[i]; |
|
|
|
@@ -56,20 +57,17 @@ public sealed class BeamTests |
|
|
|
_testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); |
|
|
|
} |
|
|
|
|
|
|
|
// Once all beams agree on some tokens read them and append them to the output decoder |
|
|
|
if (state.CommonPrefixLength > 0) |
|
|
|
{ |
|
|
|
var view = state.Beams[0]; |
|
|
|
|
|
|
|
var decoder = new StreamingTokenDecoder(context); |
|
|
|
decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength)); |
|
|
|
var tokens = decoder.Read(); |
|
|
|
|
|
|
|
result.Append(tokens); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
}, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); |
|
|
|
|
|
|
|
_testOutputHelper.WriteLine($"Final: {result}"); |
|
|
|
_testOutputHelper.WriteLine($"Final: {prompt}{decoder.Read()}"); |
|
|
|
} |
|
|
|
} |