Browse Source

- Fixed beam search test to use decode

- Enabled it in CI
pull/665/head
Martin Evans 1 year ago
parent
commit
c760cb5f16
1 changed files with 9 additions and 11 deletions
  1. +9
    -11
      LLama.Unittest/BeamTests.cs

+ 9
- 11
LLama.Unittest/BeamTests.cs View File

@@ -27,7 +27,8 @@ public sealed class BeamTests
_model.Dispose(); _model.Dispose();
} }


[Fact(Skip = "Very very slow in CI")]
//[Fact(Skip = "Very very slow in CI")]
[Fact]
public void BasicBeam() public void BasicBeam()
{ {
const int num_beams = 2; const int num_beams = 2;
@@ -36,15 +37,15 @@ public sealed class BeamTests


var context = _model.CreateContext(_params); var context = _model.CreateContext(_params);


var result = new StringBuilder();

var initial_tokens = context.Tokenize(prompt); var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
//context.Eval(initial_tokens.AsSpan(), 0);
throw new NotImplementedException("Replace Eval");
var batch = new LLamaBatch();
batch.AddRange(initial_tokens, 0, LLamaSeqId.Zero, true);
context.Decode(batch);


var decoder = new StreamingTokenDecoder(context);
NativeApi.llama_beam_search(context.NativeHandle, (data, state) => NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{ {
// Show the current state of every beam.
for (var i = 0; i < state.Beams.Length; i++) for (var i = 0; i < state.Beams.Length; i++)
{ {
ref var view = ref state.Beams[i]; ref var view = ref state.Beams[i];
@@ -56,20 +57,17 @@ public sealed class BeamTests
_testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
} }


// Once all beams agree on some tokens read them and append them to the output decoder
if (state.CommonPrefixLength > 0) if (state.CommonPrefixLength > 0)
{ {
var view = state.Beams[0]; var view = state.Beams[0];


var decoder = new StreamingTokenDecoder(context);
decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength)); decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
var tokens = decoder.Read();

result.Append(tokens);
} }


}, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));


_testOutputHelper.WriteLine($"Final: {result}");
_testOutputHelper.WriteLine($"Final: {prompt}{decoder.Read()}");
} }
} }

Loading…
Cancel
Save