
BeamTests.cs 2.2 kB

April 2024 Binary Update (#662)

* Updated binaries, using [this build](https://github.com/SciSharp/LLamaSharp/actions/runs/8654672719/job/23733195669) for llama.cpp commit `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
  - Added all new functions.
  - Moved some functions (e.g. `SafeLlamaModelHandle`-specific functions) into `SafeLlamaModelHandle.cs`.
  - Exposed tokens on `SafeLlamaModelHandle` and `LLamaWeights` through a `Tokens` property. As new special tokens are added in the future they can be added here.
  - Changed all token properties to return nullable tokens, to handle models that lack some tokens.
  - Fixed `DefaultSamplingPipeline` to handle models with no newline token.
* Moved native methods to more specific locations.
  - Context-specific things have been moved into `SafeLLamaContextHandle.cs` and made private; they are already exposed through C# properties and methods.
  - Added a check that the GPU layer count is zero if GPU offload is not supported.
  - Moved methods for creating default structs (`llama_model_quantize_default_params` and `llama_context_default_params`) into the relevant structs.
* Removed the exception thrown when `GpuLayerCount > 0` and GPU offload is not supported.
* Per-sequence state save/load:
  - Added low-level wrapper methods for the new per-sequence state load/save in `SafeLLamaContextHandle`.
  - Added high-level wrapper methods (save/load with a `State` object or a memory-mapped file) in `LLamaContext`.
  - Moved native methods for per-sequence state load/save into `SafeLLamaContextHandle`.
* Added update and defrag methods for the KV cache in `SafeLLamaContextHandle`.
* Updated the submodule to `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
* Now passing the sequence ID when saving a single-sequence state.
1 year ago
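The API-facing pieces of this update are the nullable token properties and the per-sequence state save/load. The following is a minimal sketch of how they might be used; the `Tokens.Newline` property and the `SaveState`/`LoadState` overloads taking a `LLamaSeqId` are inferred from the commit description above, so treat the exact names and signatures as assumptions, and "model.gguf" as a placeholder path.

using System;
using LLama;
using LLama.Common;
using LLama.Native;

var parameters = new ModelParams("model.gguf"); // placeholder model path
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

// Token properties are now nullable, since some models lack certain special tokens.
LLamaToken? newline = model.Tokens.Newline; // assumed property name, per the commit notes
if (newline is null)
    Console.WriteLine("This model has no newline token.");

// Save and restore the state of a single sequence. These LLamaSeqId overloads are
// assumed from the commit description; the parameterless SaveState/LoadState
// continue to operate on the whole context.
context.SaveState("seq0.state", LLamaSeqId.Zero);
context.LoadState("seq0.state", LLamaSeqId.Zero);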
using System;
using System.Text;
using LLama.Common;
using LLama.Native;
using Xunit;
using Xunit.Abstractions;

namespace LLama.Unittest;

public sealed class BeamTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;
    private readonly LLamaWeights _model;

    public BeamTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;
        _params = new ModelParams(Constants.GenerativeModelPath)
        {
            ContextSize = 2048,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };
        _model = LLamaWeights.LoadFromFile(_params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void BasicBeam()
    {
        const int num_beams = 2;
        const int n_predict = 3;
        const string prompt = "The cat sat on";

        using var context = _model.CreateContext(_params);

        // Evaluate the prompt so beam search has a starting state.
        var initial_tokens = context.Tokenize(prompt);
        var batch = new LLamaBatch();
        batch.AddRange(initial_tokens, 0, LLamaSeqId.Zero, true);
        context.Decode(batch);

        // Accumulates the tokens that all beams have agreed on.
        var decoder = new StreamingTokenDecoder(context);
        NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
        {
            // Show the current state of every beam.
            for (var i = 0; i < state.Beams.Length; i++)
            {
                ref var view = ref state.Beams[i];

                // Use a fresh decoder per beam so the display doesn't disturb the output decoder.
                var beamDecoder = new StreamingTokenDecoder(context);
                beamDecoder.AddRange(view.Tokens);
                var tokens = beamDecoder.Read();
                _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
            }

            // Once all beams agree on some tokens, read them and append them to the output decoder.
            if (state.CommonPrefixLength > 0)
            {
                var view = state.Beams[0];
                decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
            }
        }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));

        _testOutputHelper.WriteLine($"Final: {prompt}{decoder.Read()}");
    }
}