
StatelessExecutorTest.cs 2.9 kB

April 2024 Binary Update (#662)

* Updated binaries, using [this build](https://github.com/SciSharp/LLamaSharp/actions/runs/8654672719/job/23733195669) for llama.cpp commit `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
  - Added all new functions.
  - Moved some functions (e.g. `SafeLlamaModelHandle`-specific functions) into `SafeLlamaModelHandle.cs`.
  - Exposed tokens on `SafeLlamaModelHandle` and `LLamaWeights` through a `Tokens` property. As new special tokens are added in the future they can be added here.
  - Changed all token properties to return nullable tokens, to handle models that lack some tokens (see the sketch below).
  - Fixed `DefaultSamplingPipeline` to handle models with no newline token.
* Moved native methods to more specific locations.
  - Context-specific things have been moved into `SafeLLamaContextHandle.cs` and made private; they are already exposed through C# properties and methods.
  - Added a check that the GPU layer count is zero if GPU offload is not supported.
  - Moved methods for creating default structs (`llama_model_quantize_default_params` and `llama_context_default_params`) into the relevant structs.
* Removed the exception thrown when `GpuLayerCount > 0` and the GPU is not supported.
* Per-sequence state load/save (see the sketch below):
  - Added low-level wrapper methods for the new per-sequence state load/save in `SafeLLamaContextHandle`.
  - Added high-level wrapper methods (save/load with a `State` object or a memory-mapped file) in `LLamaContext`.
  - Moved native methods for per-sequence state load/save into `SafeLLamaContextHandle`.
* Added update and defrag methods for the KV cache in `SafeLLamaContextHandle`.
* Updated the submodule to `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
* Now passing the sequence ID when saving a single sequence state.
1 year ago
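The nullable-token change above is easiest to see in code. A minimal sketch, assuming `LLamaWeights` exposes the `Tokens` property the commit describes; the `Newline` member name and the model path are illustrative assumptions, not taken from the source:

```csharp
// Hedged sketch of the nullable token properties described in the commit.
// ASSUMPTIONS: the `Newline` member name and the model path are illustrative;
// only the `Tokens` property itself is named in the commit message.
using LLama;
using LLama.Common;
using LLama.Native;

var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);

LLamaToken? newline = weights.Tokens.Newline;
if (newline is null)
{
    // Some models define no newline token. Code that previously assumed one
    // exists (e.g. DefaultSamplingPipeline) must now handle the null case.
}
```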
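The state save/load wrappers also lend themselves to a short sketch. `GetState`/`LoadState` and the file-based `SaveState` are existing `LLamaContext` APIs; only the whole-context calls are shown, since the exact signatures of the new per-sequence overloads are not visible here:

```csharp
// Hedged sketch of the high-level state save/load wrappers on LLamaContext.
// ASSUMPTION: the model path is illustrative.
using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// Whole-context round trip via an in-memory State object.
var state = context.GetState();
context.LoadState(state);

// Whole-context round trip via a file on disk (per the commit, the
// file-based path uses a memory-mapped file).
context.SaveState("context_state.bin");
context.LoadState("context_state.bin");
```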
```csharp
using System.Diagnostics;
using LLama.Common;
using LLama.Sampling;
using Xunit;
using Xunit.Abstractions;

namespace LLama.Unittest
{
    public class StatelessExecutorTest
        : IDisposable
    {
        private readonly ITestOutputHelper _testOutputHelper;
        private readonly LLamaWeights _weights;
        private readonly ModelParams _params;

        public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
        {
            _testOutputHelper = testOutputHelper;

            // Deliberately tiny context and batch: tests run fast, and the
            // OutOfContext test below can overflow the context quickly.
            // A fixed seed keeps generation deterministic across runs.
            _params = new ModelParams(Constants.GenerativeModelPath)
            {
                ContextSize = 60,
                Seed = 1754,
                BatchSize = 2,
                GpuLayerCount = Constants.CIGpuLayerCount,
            };
            _weights = LLamaWeights.LoadFromFile(_params);
        }

        public void Dispose()
        {
            _weights.Dispose();
        }

        [Fact]
        public async Task Stateless()
        {
            // Create a custom pipeline that mimics the default pipeline
            var pipeline = new DefaultSamplingPipeline();

            var executor = new StatelessExecutor(_weights, _params);

            const string question = "Question. what is a cat?\nAnswer:";
            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

            var timer = new Stopwatch();
            timer.Start();

            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

            timer.Stop();
            _testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");

            _testOutputHelper.WriteLine(result1);
            _testOutputHelper.WriteLine(result2);

            // Check that it produced the exact same result both times
            Assert.Equal(result1, result2);
        }

        [Fact(Skip = "Very very slow in CI")]
        public async Task OutOfContext()
        {
            var executor = new StatelessExecutor(_weights, _params);

            const string question = " Question. cats or dogs?\nAnswer:";

            // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
            // with a modified context
            var @params = new InferenceParams()
            {
                MaxTokens = 65,
                TokensKeep = question.Length,
            };

            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

            _testOutputHelper.WriteLine(result1);

            // Check that it produced the exact same result both times
            Assert.Equal(result1, result2);
        }
    }
}
```
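For readers unfamiliar with the API under test, here is a minimal standalone sketch of the same calls outside the xUnit harness; the model path and parameter values are illustrative assumptions that mirror the test:

```csharp
// Standalone sketch of the API exercised by StatelessExecutorTest.
// ASSUMPTION: the model path is illustrative; parameters mirror the test.
using LLama;
using LLama.Common;

var parameters = new ModelParams("models/example.gguf")
{
    ContextSize = 1024,
    Seed = 1754, // fixed seed: identical prompt + params => identical output
};
using var weights = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(weights, parameters);

var inferenceParams = new InferenceParams
{
    MaxTokens = 32,
    AntiPrompts = new[] { "." }, // stop generating at the first '.'
};

// StatelessExecutor keeps no conversation state between calls, which is
// exactly what the Stateless test asserts by comparing two runs.
await foreach (var token in executor.InferAsync("Question. what is a cat?\nAnswer:", inferenceParams))
    Console.Write(token);
```

Because the executor rebuilds its context on every call, the determinism check in `Stateless` effectively guards against hidden state leaking between inferences.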