You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

StatelessExecutorTest.cs 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. using System.Diagnostics;
  2. using LLama.Common;
  3. using LLama.Sampling;
  4. using Xunit.Abstractions;
  namespace LLama.Unittest
  {
      /// <summary>
      /// Tests that run the same prompt twice through a <see cref="StatelessExecutor"/>
      /// and assert that both runs produce identical text.
      /// </summary>
      public class StatelessExecutorTest
          : IDisposable
      {
          private readonly ITestOutputHelper _testOutputHelper;
          private readonly LLamaWeights _weights;
          private readonly ModelParams _params;

          /// <summary>
          /// Builds the model parameters and loads the weights from <c>Constants.ModelPath</c>.
          /// </summary>
          /// <param name="testOutputHelper">xUnit sink used to write diagnostic output.</param>
          public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
          {
              _testOutputHelper = testOutputHelper;

              // A fixed seed is configured here; both tests below compare two runs for equality.
              _params = new ModelParams(Constants.ModelPath)
              {
                  ContextSize = 60,
                  Seed = 1754,
                  BatchSize = 2,
              };
              _weights = LLamaWeights.LoadFromFile(_params);
          }

          /// <summary>
          /// Releases the model weights loaded by the constructor.
          /// </summary>
          public void Dispose()
          {
              _weights.Dispose();
          }

          /// <summary>
          /// Runs the same short prompt twice and checks that both runs yield the same text.
          /// </summary>
          [Fact]
          public async Task Stateless()
          {
              // Explicitly pass an unmodified instance of the default sampling pipeline
              var pipeline = new DefaultSamplingPipeline();

              var executor = new StatelessExecutor(_weights, _params);

              const string question = "Question. what is a cat?\nAnswer:";
              var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

              // Time both inference runs together
              var timer = Stopwatch.StartNew();
              var result1 = await InferTextAsync(executor, question, @params);
              var result2 = await InferTextAsync(executor, question, @params);
              timer.Stop();

              _testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");
              _testOutputHelper.WriteLine(result1);
              _testOutputHelper.WriteLine(result2);

              // Check that it produced the exact same result both times
              Assert.Equal(result1, result2);
          }

          /// <summary>
          /// Requests more tokens than the configured context size and checks that two
          /// runs over the same prompt still yield the same text.
          /// </summary>
          [Fact(Skip = "Very very slow in CI")]
          public async Task OutOfContext()
          {
              var executor = new StatelessExecutor(_weights, _params);

              const string question = " Question. cats or dogs?\nAnswer:";

              // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
              // with a modified context
              var @params = new InferenceParams()
              {
                  MaxTokens = 65,
                  // NOTE(review): question.Length is a character count; TokensKeep presumably
                  // expects a number of tokens - confirm this is the intended value.
                  TokensKeep = question.Length,
              };

              var result1 = await InferTextAsync(executor, question, @params);
              var result2 = await InferTextAsync(executor, question, @params);

              // Log both outputs (as the Stateless test does) so a mismatch can be inspected
              _testOutputHelper.WriteLine(result1);
              _testOutputHelper.WriteLine(result2);

              // Check that it produced the exact same result both times
              Assert.Equal(result1, result2);
          }

          /// <summary>
          /// Runs a single inference and concatenates every streamed piece into one string.
          /// </summary>
          /// <param name="executor">Executor to run the prompt on.</param>
          /// <param name="prompt">Prompt text passed to the executor.</param>
          /// <param name="inferenceParams">Inference settings for this run.</param>
          /// <returns>The complete generated text.</returns>
          private static async Task<string> InferTextAsync(StatelessExecutor executor, string prompt, InferenceParams inferenceParams)
          {
              var pieces = await executor.InferAsync(prompt, inferenceParams).ToListAsync();
              return string.Join("", pieces);
          }
      }
  }