您最多选择25个标签。标签必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。

LLamaEmbedderTests.cs 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. using LLama.Common;
  2. using Xunit.Abstractions;
  3. namespace LLama.Unittest;
  /// <summary>
  /// xUnit tests for <see cref="LLamaEmbedder"/>. A single embedder is built in the
  /// constructor and released in <see cref="Dispose"/>.
  /// </summary>
  public sealed class LLamaEmbedderTests
      : IDisposable
  {
      /// <summary>Number of leading embedding components written to the test log.</summary>
      private const int PreviewLength = 7;

      private readonly ITestOutputHelper _testOutputHelper;
      private readonly LLamaEmbedder _embedder;

      /// <summary>
      /// Loads the model found at <c>Constants.ModelPath</c> with
      /// <c>EmbeddingMode</c> switched on and creates the embedder under test.
      /// </summary>
      /// <param name="testOutputHelper">xUnit sink used for diagnostic output.</param>
      public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
      {
          _testOutputHelper = testOutputHelper;

          var modelParams = new ModelParams(Constants.ModelPath)
          {
              EmbeddingMode = true,
          };

          // The weights object is disposed as soon as the embedder has been constructed.
          // NOTE(review): presumably the embedder keeps whatever it needs from the
          // weights alive on its own - confirm against LLamaEmbedder.
          using (var weights = LLamaWeights.LoadFromFile(modelParams))
          {
              _embedder = new LLamaEmbedder(weights, modelParams);
          }
      }

      /// <summary>Releases the embedder created in the constructor.</summary>
      public void Dispose() => _embedder.Dispose();

      /// <summary>
      /// Computes the dot product of two vectors, asserting first that they have
      /// the same length.
      /// </summary>
      /// <param name="a">First vector.</param>
      /// <param name="b">Second vector.</param>
      /// <returns>Sum of the element-wise products of <paramref name="a"/> and <paramref name="b"/>.</returns>
      private static float Dot(float[] a, float[] b)
      {
          Assert.Equal(a.Length, b.Length);

          var products = a.Zip(b, (left, right) => left * right);
          return products.Sum();
      }

      /// <summary>
      /// Formats the first <see cref="PreviewLength"/> components of an embedding
      /// as a comma-separated string for logging.
      /// </summary>
      /// <param name="embedding">Embedding vector; slicing throws if it has fewer than <see cref="PreviewLength"/> items.</param>
      private static string Preview(float[] embedding)
      {
          var head = embedding.AsMemory().Slice(0, PreviewLength).ToArray();
          return string.Join(",", head);
      }

      /// <summary>
      /// Embeds three sentences and checks that "cat" is closer to "kitten" than to
      /// "spoon", where distance is taken as <c>1 - dot product</c>.
      /// </summary>
      /// <remarks>
      /// NOTE(review): <c>1 - dot</c> matches cosine distance only if the embedder
      /// returns unit-length vectors - confirm.
      /// </remarks>
      [Fact]
      public async Task EmbedCompare()
      {
          var cat = await _embedder.GetEmbeddings("The cat is cute");
          var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
          var spoon = await _embedder.GetEmbeddings("The spoon is not real");

          _testOutputHelper.WriteLine($"Cat = [{Preview(cat)}...]");
          _testOutputHelper.WriteLine($"Kitten = [{Preview(kitten)}...]");
          _testOutputHelper.WriteLine($"Spoon = [{Preview(spoon)}...]");

          var distanceToKitten = 1 - Dot(cat, kitten);
          var distanceToSpoon = 1 - Dot(cat, spoon);

          Assert.True(distanceToKitten < distanceToSpoon);
      }
  }