You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

BeamTests.cs 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. using System.Text;
  2. using LLama.Common;
  3. using LLama.Native;
  4. using Xunit.Abstractions;
  5. namespace LLama.Unittest;
  /// <summary>
  /// Tests for the native beam search entry point (<c>NativeApi.llama_beam_search</c>).
  /// The model weights are loaded once per test class instance in the constructor
  /// and released in <see cref="Dispose"/>.
  /// </summary>
  public sealed class BeamTests
      : IDisposable
  {
      private readonly ITestOutputHelper _testOutputHelper;
      private readonly ModelParams _params;
      private readonly LLamaWeights _model;

      /// <summary>
      /// Builds the model parameters from <c>Constants.ModelPath</c> with a context size of 2048
      /// and loads the weights from that file.
      /// </summary>
      /// <param name="testOutputHelper">xUnit sink used to write diagnostic output from the test.</param>
      public BeamTests(ITestOutputHelper testOutputHelper)
      {
          _testOutputHelper = testOutputHelper;
          _params = new ModelParams(Constants.ModelPath)
          {
              ContextSize = 2048
          };
          _model = LLamaWeights.LoadFromFile(_params);
      }

      /// <summary>
      /// Releases the model weights loaded in the constructor.
      /// </summary>
      public void Dispose()
      {
          _model.Dispose();
      }

      /// <summary>
      /// Runs a small beam search over a fixed prompt, logging every beam on each callback and
      /// accumulating the tokens that all beams agree on into the final result string.
      /// </summary>
      /// <remarks>
      /// The test is skipped, and it currently throws <see cref="NotImplementedException"/> before
      /// reaching the beam search because the prompt evaluation step still has to be replaced.
      /// Everything after the <c>throw</c> is therefore unreachable until that is done.
      /// </remarks>
      /// <exception cref="NotImplementedException">Always, until the evaluation step is replaced.</exception>
      [Fact(Skip = "Very very slow in CI")]
      public void BasicBeam()
      {
          const int numBeams = 2;
          const int nPredict = 3;
          const string prompt = "The cat sat on";

          // Dispose the context when the method exits, including via the exception below;
          // previously the context was never released.
          using var context = _model.CreateContext(_params);

          var result = new StringBuilder();

          var initialTokens = context.Tokenize(prompt);
          result.Append(prompt);

          //context.Eval(initial_tokens.AsSpan(), 0);
          throw new NotImplementedException("Replace Eval");

          NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
          {
              // Log the current text and cumulative probability of every beam.
              for (var i = 0; i < state.Beams.Length; i++)
              {
                  ref var view = ref state.Beams[i];

                  var decoder = new StreamingTokenDecoder(context);
                  decoder.AddRange(view.Tokens);
                  var tokens = decoder.Read();

                  _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
              }

              // When a common prefix is reported, decode that many leading tokens of the first
              // beam and append them to the accumulated output.
              if (state.CommonPrefixLength > 0)
              {
                  var view = state.Beams[0];

                  var decoder = new StreamingTokenDecoder(context);
                  decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
                  var tokens = decoder.Read();

                  result.Append(tokens);
              }

              // NOTE(review): the last argument is presumably the thread count (half the
              // logical processors, at least one) - confirm against the native signature.
          }, IntPtr.Zero, numBeams, initialTokens.Length, nPredict, Math.Max(1, Environment.ProcessorCount / 2));

          _testOutputHelper.WriteLine($"Final: {result}");
      }
  }