You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

StatelessExecutorTest.cs 2.4 kB

(line numbers 1–70)
  1. using LLama.Common;
  2. using Xunit.Abstractions;
  3. namespace LLama.Unittest
  4. {
  /// <summary>
  /// Tests for <see cref="StatelessExecutor"/>. Each test runs the same prompt twice through a
  /// stateless executor and asserts that both runs produce exactly the same text.
  /// </summary>
  public sealed class StatelessExecutorTest
      : IDisposable
  {
      private readonly ITestOutputHelper _testOutputHelper;
      private readonly LLamaWeights _weights;
      private readonly ModelParams _params;

      /// <summary>
      /// Test setup: loads the model at <see cref="Constants.ModelPath"/> with a small context
      /// size and a fixed seed.
      /// </summary>
      /// <param name="testOutputHelper">xUnit output sink used to log the generated text.</param>
      public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
      {
          _testOutputHelper = testOutputHelper;
          _params = new ModelParams(Constants.ModelPath)
          {
              // Deliberately small: OutOfContext generates more tokens than this.
              ContextSize = 60,
              // Fixed seed, presumably so repeated runs of the same prompt are reproducible.
              Seed = 1754,
          };
          _weights = LLamaWeights.LoadFromFile(_params);
      }

      /// <summary>Test teardown: releases the model weights loaded in the constructor.</summary>
      public void Dispose()
      {
          _weights.Dispose();
      }

      [Fact]
      public async Task Stateless()
      {
          const string question = "Question. what is a cat?\nAnswer: ";
          // "." is registered as an anti-prompt so the answer is cut short.
          var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

          await AssertSameOutputTwiceAsync(question, @params);
      }

      [Fact(Skip = "Very very slow in CI")]
      public async Task OutOfContext()
      {
          const string question = " Question. cats or dogs?\nAnswer: ";

          // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
          // with a modified context
          var @params = new InferenceParams()
          {
              MaxTokens = 65,
              // NOTE(review): question.Length is a character count, while TokensKeep appears to be a
              // token count, so this likely keeps more than just the prompt. TODO confirm and consider
              // tokenizing the question with the loaded model to get the real token count.
              TokensKeep = question.Length,
          };

          await AssertSameOutputTwiceAsync(question, @params);
      }

      /// <summary>
      /// Runs <paramref name="prompt"/> twice through one new <see cref="StatelessExecutor"/>,
      /// writes the first result to the test output, and asserts that both results are identical.
      /// </summary>
      /// <param name="prompt">Prompt text sent to the executor on both runs.</param>
      /// <param name="inferenceParams">Inference settings shared by both runs.</param>
      private async Task AssertSameOutputTwiceAsync(string prompt, InferenceParams inferenceParams)
      {
          var executor = new StatelessExecutor(_weights, _params);

          var first = await InferToStringAsync(executor, prompt, inferenceParams);
          var second = await InferToStringAsync(executor, prompt, inferenceParams);
          _testOutputHelper.WriteLine(first);

          // Check that it produced the exact same result both times
          Assert.Equal(first, second);
      }

      /// <summary>
      /// Collects every piece of text streamed by a single inference and concatenates them.
      /// </summary>
      /// <returns>The full generated text of one inference run.</returns>
      private static async Task<string> InferToStringAsync(StatelessExecutor executor, string prompt, InferenceParams inferenceParams)
      {
          var pieces = await executor.InferAsync(prompt, inferenceParams).ToListAsync();
          return string.Concat(pieces);
      }
  }
  56. }