You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

StatelessExecutorTest.cs 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. using System.Diagnostics;
  2. using LLama.Common;
  3. using LLama.Sampling;
  4. using Xunit.Abstractions;
  5. namespace LLama.Unittest
  6. {
  7. public class StatelessExecutorTest
  8. : IDisposable
  9. {
  10. private readonly ITestOutputHelper _testOutputHelper;
  11. private readonly LLamaWeights _weights;
  12. private readonly ModelParams _params;
  13. public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
  14. {
  15. _testOutputHelper = testOutputHelper;
  16. _params = new ModelParams(Constants.ModelPath)
  17. {
  18. ContextSize = 60,
  19. Seed = 1754,
  20. };
  21. _weights = LLamaWeights.LoadFromFile(_params);
  22. }
  23. public void Dispose()
  24. {
  25. _weights.Dispose();
  26. }
  27. [Fact]
  28. public async Task Stateless()
  29. {
  30. // Create a custom pipeline that mimics the default pipeline
  31. var pipeline = new DefaultSamplingPipeline();
  32. var executor = new StatelessExecutor(_weights, _params);
  33. const string question = "Question. what is a cat?\nAnswer: ";
  34. var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
  35. var timer = new Stopwatch();
  36. timer.Start();
  37. var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
  38. var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
  39. timer.Stop();
  40. _testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");
  41. _testOutputHelper.WriteLine(result1);
  42. _testOutputHelper.WriteLine(result2);
  43. // Check that it produced the exact same result both times
  44. Assert.Equal(result1, result2);
  45. }
  46. [Fact(Skip = "Very very slow in CI")]
  47. public async Task OutOfContext()
  48. {
  49. var executor = new StatelessExecutor(_weights, _params);
  50. const string question = " Question. cats or dogs?\nAnswer: ";
  51. // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
  52. // with a modified context
  53. var @params = new InferenceParams()
  54. {
  55. MaxTokens = 65,
  56. TokensKeep = question.Length,
  57. };
  58. var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
  59. var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
  60. _testOutputHelper.WriteLine(result1);
  61. // Check that it produced the exact same result both times
  62. Assert.Equal(result1, result2);
  63. }
  64. }
  65. }