You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedderTests.cs 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. using LLama.Common;
  2. using Xunit.Abstractions;
  3. namespace LLama.Unittest;
  4. public sealed class LLamaEmbedderTests
  5. : IDisposable
  6. {
  7. private readonly ITestOutputHelper _testOutputHelper;
  8. private readonly LLamaEmbedder _embedder;
  9. public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
  10. {
  11. _testOutputHelper = testOutputHelper;
  12. var @params = new ModelParams(Constants.ModelPath)
  13. {
  14. ContextSize = 4096,
  15. Threads = 5,
  16. EmbeddingMode = true,
  17. };
  18. using var weights = LLamaWeights.LoadFromFile(@params);
  19. _embedder = new(weights, @params);
  20. }
  21. public void Dispose()
  22. {
  23. _embedder.Dispose();
  24. }
  25. private static float Dot(float[] a, float[] b)
  26. {
  27. Assert.Equal(a.Length, b.Length);
  28. return a.Zip(b, (x, y) => x * y).Sum();
  29. }
  30. [Fact]
  31. public async Task EmbedCompare()
  32. {
  33. var cat = await _embedder.GetEmbeddings("The cat is cute");
  34. var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
  35. var spoon = await _embedder.GetEmbeddings("The spoon is not real");
  36. _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
  37. _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
  38. _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
  39. var close = 1 - Dot(cat, kitten);
  40. var far = 1 - Dot(cat, spoon);
  41. Assert.True(close < far);
  42. }
  43. }