using LLama.Common;
using Xunit.Abstractions;

namespace LLama.Unittest;

/// <summary>
/// Tests that embed three short sentences with <see cref="LLamaEmbedder"/> and check that
/// "cat" scores closer to "kitten" than to "spoon".
/// </summary>
public sealed class LLamaEmbedderTests
{
    private readonly ITestOutputHelper _testOutputHelper;

    /// <summary>
    /// Stores the xUnit output helper used to log embedding previews and distances.
    /// </summary>
    /// <param name="testOutputHelper">Sink for diagnostic test output.</param>
    public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
        => _testOutputHelper = testOutputHelper;

    /// <summary>
    /// Computes the dot product of two vectors, asserting first that they have the same length.
    /// </summary>
    /// <param name="a">Left-hand vector.</param>
    /// <param name="b">Right-hand vector.</param>
    /// <returns>The sum of the element-wise products of <paramref name="a"/> and <paramref name="b"/>.</returns>
    private static float Dot(float[] a, float[] b)
    {
        Assert.Equal(a.Length, b.Length);

        // Lengths are equal at this point, so indexing b by a's index covers every pair.
        var products = a.Select((value, index) => value * b[index]);
        return products.Sum();
    }

    /// <summary>
    /// Formats the first seven components of an embedding as a single log line.
    /// </summary>
    /// <param name="label">Name printed in front of the values.</param>
    /// <param name="embedding">Embedding vector; must contain at least seven elements.</param>
    /// <returns>A line of the form <c>label = [v0,v1,...]</c>.</returns>
    private static string Preview(string label, float[] embedding)
    {
        var head = embedding.AsMemory(0, 7).ToArray();
        return $"{label} = [{string.Join(",", head)}...]";
    }

    /// <summary>
    /// Loads the model at <paramref name="modelPath"/>, embeds three sentences and asserts that
    /// <c>1 - Dot(cat, kitten)</c> is smaller than <c>1 - Dot(cat, spoon)</c>.
    /// </summary>
    /// <param name="modelPath">Path of the model file to load.</param>
    private async Task CompareEmbeddings(string modelPath)
    {
        var modelParams = new ModelParams(modelPath)
        {
            ContextSize = 8,
            Threads = 4,
            Embeddings = true,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };

        using var model = LLamaWeights.LoadFromFile(modelParams);
        using var embedder = new LLamaEmbedder(model, modelParams);

        // Embeds one sentence and fails the test if any component is NaN.
        async Task<float[]> EmbedAsync(string text)
        {
            var vector = await embedder.GetEmbeddings(text);
            Assert.DoesNotContain(float.NaN, vector);
            return vector;
        }

        // All three embeddings are produced and validated before anything is logged.
        var catVector = await EmbedAsync("The cat is cute");
        var kittenVector = await EmbedAsync("The kitten is kawaii");
        var spoonVector = await EmbedAsync("The spoon is not real");

        _testOutputHelper.WriteLine(Preview("Cat", catVector));
        _testOutputHelper.WriteLine(Preview("Kitten", kittenVector));
        _testOutputHelper.WriteLine(Preview("Spoon", spoonVector));

        // NOTE(review): presumably the embeddings are unit-normalised, which would make
        // "1 - dot" a cosine distance - confirm against LLamaEmbedder.
        var nearDistance = 1 - Dot(catVector, kittenVector);
        var farDistance = 1 - Dot(catVector, spoonVector);

        _testOutputHelper.WriteLine("");
        _testOutputHelper.WriteLine($"Cat.Kitten (Close): {nearDistance:F4}");
        _testOutputHelper.WriteLine($"Cat.Spoon  (Far): {farDistance:F4}");

        Assert.True(nearDistance < farDistance);
    }

    /// <summary>
    /// Runs the comparison against the model at <c>Constants.EmbeddingModelPath</c>.
    /// </summary>
    [Fact]
    public async Task EmbedCompareEmbeddingModel()
    {
        await CompareEmbeddings(Constants.EmbeddingModelPath);
    }

    /// <summary>
    /// Runs the comparison against the model at <c>Constants.GenerativeModelPath</c>.
    /// </summary>
    [Fact]
    public async Task EmbedCompareGenerateModel()
    {
        await CompareEmbeddings(Constants.GenerativeModelPath);
    }
}