您最多选择25个标签 标签必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

BeamTests.cs 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. using System.Text;
  2. using LLama.Common;
  3. using LLama.Native;
  4. using Xunit.Abstractions;
  5. namespace LLama.Unittest;
  /// <summary>
  /// Tests that drive <c>NativeApi.llama_beam_search</c> using the model found at
  /// <c>Constants.ModelPath</c>.
  /// </summary>
  public sealed class BeamTests
      : IDisposable
  {
      private readonly ITestOutputHelper _testOutputHelper;
      private readonly ModelParams _params;
      private readonly LLamaWeights _model;

      /// <summary>
      /// Builds the model parameters (context size 2048) and loads the weights once for this test class instance.
      /// </summary>
      /// <param name="testOutputHelper">xUnit sink that receives the per-beam and final text.</param>
      public BeamTests(ITestOutputHelper testOutputHelper)
      {
          _testOutputHelper = testOutputHelper;
          _params = new ModelParams(Constants.ModelPath)
          {
              ContextSize = 2048
          };
          _model = LLamaWeights.LoadFromFile(_params);
      }

      /// <summary>
      /// Disposes the weights loaded by the constructor.
      /// </summary>
      public void Dispose()
      {
          _model.Dispose();
      }

      /// <summary>
      /// Evaluates a short prompt, runs a beam search over it and writes every beam plus the accumulated
      /// text to the test output. The test makes no assertions; it only checks that the call completes.
      /// </summary>
      [Fact(Skip = "Very very slow in CI")]
      public void BasicBeam()
      {
          const int beamCount = 2;
          const int predictCount = 3;
          const string prompt = "The cat sat on";

          // Dispose the context when the test ends; previously it was created and never released.
          // The using scope spans the whole method, so the context outlives the callback that captures it.
          using var context = _model.CreateContext(_params);

          var result = new StringBuilder();

          var promptTokens = context.Tokenize(prompt);
          result.Append(prompt);
          context.Eval(promptTokens.AsSpan(), 0);

          // Use half of the logical processors, but never fewer than one.
          var threadCount = Math.Max(1, Environment.ProcessorCount / 2);

          NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
          {
              // Log the decoded text and cumulative probability of every beam in this callback.
              for (var i = 0; i < state.Beams.Length; i++)
              {
                  ref var view = ref state.Beams[i];

                  var decoder = new StreamingTokenDecoder(context);
                  decoder.AddRange(view.Tokens);
                  var tokens = decoder.Read();

                  _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
              }

              // When a common prefix is reported, append that many leading tokens of beam 0 to the result.
              // NOTE(review): presumably this is the prefix all beams agree on - confirm against the native API.
              if (state.CommonPrefixLength > 0)
              {
                  var view = state.Beams[0];

                  var decoder = new StreamingTokenDecoder(context);
                  decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
                  var tokens = decoder.Read();

                  result.Append(tokens);
              }

          }, IntPtr.Zero, beamCount, promptTokens.Length, predictCount, threadCount);

          _testOutputHelper.WriteLine($"Final: {result}");
      }
  }