You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BeamTests.cs 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. using System.Text;
  2. using LLama.Common;
  3. using LLama.Native;
  4. using Xunit.Abstractions;
  5. namespace LLama.Unittest;
  6. public sealed class BeamTests
  7. : IDisposable
  8. {
  9. private readonly ITestOutputHelper _testOutputHelper;
  10. private readonly ModelParams _params;
  11. private readonly LLamaWeights _model;
  12. public BeamTests(ITestOutputHelper testOutputHelper)
  13. {
  14. _testOutputHelper = testOutputHelper;
  15. _params = new ModelParams(Constants.ModelPath)
  16. {
  17. ContextSize = 2048
  18. };
  19. _model = LLamaWeights.LoadFromFile(_params);
  20. }
  21. public void Dispose()
  22. {
  23. _model.Dispose();
  24. }
  25. [Fact(Skip = "Very very slow in CI")]
  26. public void BasicBeam()
  27. {
  28. const int num_beams = 2;
  29. const int n_predict = 3;
  30. var context = _model.CreateContext(_params);
  31. var result = new StringBuilder();
  32. var initial_tokens = context.Tokenize("The cat sat on");
  33. result.Append(context.DeTokenize(initial_tokens.ToArray()));
  34. context.Eval(initial_tokens, 0);
  35. NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
  36. {
  37. for (var i = 0; i < state.Beams.Length; i++)
  38. {
  39. ref var view = ref state.Beams[i];
  40. var tokens = context.DeTokenize(view.Tokens.ToArray());
  41. _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
  42. }
  43. if (state.CommonPrefixLength > 0)
  44. {
  45. var view = state.Beams[0];
  46. result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
  47. }
  48. }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
  49. _testOutputHelper.WriteLine($"Final: {result}");
  50. }
  51. }