You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedderTests.cs 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. using LLama.Common;
  2. namespace LLama.Unittest;
  /// <summary>
  /// Tests for <see cref="LLamaEmbedder"/>. A single embedder is created per test
  /// class instance in the constructor and released in <see cref="Dispose"/>.
  /// </summary>
  public class LLamaEmbedderTests
      : IDisposable
  {
      private readonly LLamaEmbedder _embedder;

      /// <summary>
      /// Loads the model weights from <c>Constants.ModelPath</c> and builds the embedder under test.
      /// </summary>
      public LLamaEmbedderTests()
      {
          var @params = new ModelParams(Constants.ModelPath);

          // The weights are disposed as soon as the constructor returns.
          // NOTE(review): this presumes the embedder keeps whatever it needs from the
          // weights alive on its own - confirm against LLamaEmbedder's implementation.
          using var weights = LLamaWeights.LoadFromFile(@params);
          _embedder = new(weights, @params);
      }

      /// <summary>
      /// Releases the embedder created in the constructor.
      /// </summary>
      public void Dispose()
      {
          _embedder.Dispose();
      }

      /// <summary>
      /// Computes the Euclidean length of a vector: the square root of the sum of its squared elements.
      /// </summary>
      /// <param name="a">The vector to measure.</param>
      /// <returns>The magnitude of <paramref name="a"/>.</returns>
      private static float Magnitude(float[] a)
      {
          return MathF.Sqrt(a.Sum(x => x * x));
      }

      /// <summary>
      /// Scales a vector in place so that its magnitude becomes 1.
      /// </summary>
      /// <param name="a">The vector to normalize; its elements are overwritten.</param>
      private static void Normalize(float[] a)
      {
          var mag = Magnitude(a);

          // Dividing by a zero (or NaN) magnitude would fill the vector with NaN/Infinity and
          // make every later comparison fail for a non-obvious reason, so fail loudly here instead.
          Assert.True(mag > 0f, "Cannot normalize a vector whose magnitude is not positive.");

          for (var i = 0; i < a.Length; i++)
          {
              a[i] /= mag;
          }
      }

      /// <summary>
      /// Computes the dot product of two vectors, asserting first that they have the same length.
      /// </summary>
      /// <param name="a">The first vector.</param>
      /// <param name="b">The second vector.</param>
      /// <returns>The sum of the element-wise products of <paramref name="a"/> and <paramref name="b"/>.</returns>
      private static float Dot(float[] a, float[] b)
      {
          Assert.Equal(a.Length, b.Length);
          return a.Zip(b, (x, y) => x * y).Sum();
      }

      /// <summary>
      /// Asserts that the leading elements of <paramref name="actual"/> match
      /// <paramref name="expected"/>, comparing element by element with
      /// <paramref name="epsilon"/> passed to <c>Assert.Equal</c> as the allowed difference.
      /// Elements of <paramref name="actual"/> beyond <c>expected.Length</c> are ignored.
      /// </summary>
      /// <param name="expected">The expected prefix.</param>
      /// <param name="actual">The vector whose prefix is checked.</param>
      /// <param name="epsilon">The value handed to <c>Assert.Equal</c> for each element comparison.</param>
      private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f)
      {
          // Report a too-short vector as an assertion failure rather than an IndexOutOfRangeException.
          Assert.True(
              actual.Length >= expected.Length,
              $"Expected at least {expected.Length} elements but the actual vector has {actual.Length}.");

          for (int i = 0; i < expected.Length; i++)
          {
              Assert.Equal(expected[i], actual[i], epsilon);
          }
      }

      // todo: enable this once llama2 7B gguf is available
      //[Fact]
      //public void EmbedBasic()
      //{
      //    var cat = _embedder.GetEmbeddings("cat");
      //    Assert.NotNull(cat);
      //    Assert.NotEmpty(cat);
      //    // Expected value generated with llama.cpp embedding.exe
      //    var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
      //    AssertApproxStartsWith(expected, cat);
      //}

      /// <summary>
      /// Checks that the embedding of "cat" is more similar to the embedding of "kitten"
      /// than to the embedding of "spoon".
      /// </summary>
      [Fact]
      public void EmbedCompare()
      {
          var cat = _embedder.GetEmbeddings("cat");
          var kitten = _embedder.GetEmbeddings("kitten");
          var spoon = _embedder.GetEmbeddings("spoon");

          Normalize(cat);
          Normalize(kitten);
          Normalize(spoon);

          var close = Dot(cat, kitten);
          var far = Dot(cat, spoon);

          // For unit-length vectors the dot product is the cosine of the angle between them:
          // 1.0 means the same direction, 0.0 means orthogonal (unrelated) and -1.0 means
          // opposite. A larger value therefore means *more* similar, so the related pair
          // must score higher than the unrelated pair.
          Assert.True(close > far, $"Expected cat/kitten similarity ({close}) to exceed cat/spoon similarity ({far}).");
      }
  }