- Rewritten the native API methods for embeddings to return raw pointers: `null` is a valid return value for these methods, so `Span<float>` is not an appropriate return type.
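For context, a minimal caller-side sketch of the pattern the pointer-returning API enables, assuming the `NativeApi` signatures shown in the diffs below; the class and method names here are hypothetical, not part of the library. The point is that a null pointer signals "no embeddings were produced", which a `Span<float>` cannot represent, only conflate with an empty buffer.

```csharp
using System;
using LLama.Native;

// Hypothetical example, not part of the library: shows why a raw pointer is
// needed to distinguish "no embeddings" from "zero-length embeddings".
internal static class EmbeddingPointerExample
{
    public static unsafe bool TryGetEmbeddings(SafeLLamaContextHandle ctx, out float[] embeddings)
    {
        // A null pointer means no embeddings were produced for this context.
        // A Span<float> cannot be null, so it could not distinguish this case
        // from a genuinely empty buffer.
        var ptr = NativeApi.llama_get_embeddings(ctx);
        if (ptr == null)
        {
            embeddings = Array.Empty<float>();
            return false;
        }

        // Once the pointer is known to be non-null, it is safe to attach a
        // length and copy the native memory into a managed array.
        embeddings = new Span<float>(ptr, ctx.EmbeddingSize).ToArray();
        return true;
    }
}
```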
```diff
@@ -1,32 +1,16 @@
 using LLama.Common;
-using LLama.Native;
 using Xunit.Abstractions;
-using Xunit.Sdk;
 
 namespace LLama.Unittest;
 
 public sealed class LLamaEmbedderTests
-    : IDisposable
 {
     private readonly ITestOutputHelper _testOutputHelper;
-    private readonly LLamaEmbedder _embedder;
 
     public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
     {
         _testOutputHelper = testOutputHelper;
-
-        var @params = new ModelParams(Constants.EmbeddingModelPath)
-        {
-            ContextSize = 4096,
-            Threads = 5,
-            Embeddings = true,
-        };
-        using var weights = LLamaWeights.LoadFromFile(@params);
-        _embedder = new(weights, @params);
-    }
-
-    public void Dispose()
-    {
-        _embedder.Dispose();
     }
 
     private static float Dot(float[] a, float[] b)
@@ -35,17 +19,24 @@ public sealed class LLamaEmbedderTests
         return a.Zip(b, (x, y) => x * y).Sum();
     }
 
-    [Fact]
-    public async Task EmbedCompare()
+    private async Task CompareEmbeddings(string modelPath)
     {
-        var cat = await _embedder.GetEmbeddings("The cat is cute");
+        var @params = new ModelParams(modelPath)
+        {
+            ContextSize = 8,
+            Threads = 4,
+            Embeddings = true,
+        };
+        using var weights = LLamaWeights.LoadFromFile(@params);
+        using var embedder = new LLamaEmbedder(weights, @params);
+
+        var cat = await embedder.GetEmbeddings("The cat is cute");
         Assert.DoesNotContain(float.NaN, cat);
 
-        var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
+        var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
         Assert.DoesNotContain(float.NaN, kitten);
 
-        var spoon = await _embedder.GetEmbeddings("The spoon is not real");
+        var spoon = await embedder.GetEmbeddings("The spoon is not real");
         Assert.DoesNotContain(float.NaN, spoon);
 
         _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -61,4 +52,16 @@ public sealed class LLamaEmbedderTests
         Assert.True(close < far);
     }
+
+    [Fact]
+    public async Task EmbedCompareEmbeddingModel()
+    {
+        await CompareEmbeddings(Constants.EmbeddingModelPath);
+    }
+
+    [Fact]
+    public async Task EmbedCompareGenerateModel()
+    {
+        await CompareEmbeddings(Constants.GenerativeModelPath);
+    }
 }
```
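The refactored test loads one model per call and compares three embeddings. The `close`/`far` values are computed outside the hunks shown, so the following is only an assumed reconstruction of that step, using the `Dot` helper above on (presumably normalized) embeddings:

```csharp
// Assumed reconstruction of the close/far computation (the actual lines fall
// outside the hunks above). With unit-length embeddings, Dot(a, b) is the
// cosine similarity, so 1 - Dot(a, b) behaves as a distance.
var close = 1 - Dot(cat, kitten);  // similar sentences => smaller distance
var far = 1 - Dot(cat, spoon);     // unrelated sentences => larger distance
Assert.True(close < far);
```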
```diff
@@ -97,15 +97,18 @@ namespace LLama
         private float[] GetEmbeddingsArray()
         {
-            var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
-            if (embeddings == null || embeddings.Length == 0)
+            unsafe
             {
-                embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
-                if (embeddings == null || embeddings.Length == 0)
-                    return Array.Empty<float>();
-            }
-
-            return embeddings.ToArray();
+                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
+                if (embeddings == null)
+                    embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
+
+                if (embeddings == null)
+                    return Array.Empty<float>();
+
+                return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
+            }
         }
@@ -116,6 +119,7 @@ namespace LLama
                 lengthSqr += value * value;
             var length = (float)Math.Sqrt(lengthSqr);
 
+            // Do not divide by length if it is zero
             if (length <= float.Epsilon)
                 return;
```
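The second hunk only shows the top of `Normalize`, so the sketch below is an assumed reconstruction of the full Euclidean (L2) normalization that the new comment refers to. Without the zero-length guard, a zero vector would be divided by zero and filled with NaN values, which the tests above explicitly reject.

```csharp
// Assumed full form of the normalization; only the top of the method appears
// in the hunk above.
private static void Normalize(Span<float> embeddings)
{
    // Sum of squares, then Euclidean length.
    var lengthSqr = 0.0;
    foreach (var value in embeddings)
        lengthSqr += value * value;
    var length = (float)Math.Sqrt(lengthSqr);

    // Do not divide by length if it is zero: a zero vector divided by zero
    // would become all NaN.
    if (length <= float.Epsilon)
        return;

    for (var i = 0; i < embeddings.Length; i++)
        embeddings[i] /= length;
}
```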
```diff
@@ -137,41 +137,17 @@ namespace LLama.Native
         /// <summary>
         /// Get the embeddings for the a specific sequence.
         /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
         /// </summary>
-        /// <returns></returns>
-        public static Span<float> llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id)
-        {
-            unsafe
-            {
-                var ptr = llama_get_embeddings_seq_native(ctx, id);
-                if (ptr == null)
-                    return Array.Empty<float>();
-                return new Span<float>(ptr, ctx.EmbeddingSize);
-            }
-
-            [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_seq")]
-            static extern unsafe float* llama_get_embeddings_seq_native(SafeLLamaContextHandle ctx, LLamaSeqId id);
-        }
+        /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);
 
         /// <summary>
         /// Get the embeddings for the ith sequence.
         /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
         /// </summary>
-        /// <returns></returns>
-        public static Span<float> llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i)
-        {
-            unsafe
-            {
-                var ptr = llama_get_embeddings_ith_native(ctx, i);
-                if (ptr == null)
-                    return Array.Empty<float>();
-                return new Span<float>(ptr, ctx.EmbeddingSize);
-            }
-
-            [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_ith")]
-            static extern unsafe float* llama_get_embeddings_ith_native(SafeLLamaContextHandle ctx, int i);
-        }
+        /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
 
         /// <summary>
         /// Get all output token embeddings.
@@ -182,20 +158,8 @@ namespace LLama.Native
         /// </summary>
         /// <param name="ctx"></param>
         /// <returns></returns>
-        public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
-        {
-            unsafe
-            {
-                var ptr = llama_get_embeddings_native(ctx);
-                if (ptr == null)
-                    return Array.Empty<float>();
-                return new Span<float>(ptr, ctx.EmbeddingSize);
-            }
-
-            [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
-            static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
-        }
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern unsafe float* llama_get_embeddings(SafeLLamaContextHandle ctx);
 
         /// <summary>
         /// Apply chat template. Inspired by hf apply_chat_template() on python.
```
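Callers that still want a `Span<float>` view can wrap the new externs themselves; the sketch below mirrors the convenience code removed above, with the null case handled explicitly at the call site. The class and method names are hypothetical and not part of the library.

```csharp
using System;
using LLama.Native;

// Hypothetical caller-side wrapper around the new pointer-returning externs,
// mirroring the convenience code removed in the diff above.
internal static class NativeApiSpanExtensions
{
    /// <summary>
    /// Get the embeddings for the ith sequence as a span, or an empty span if
    /// the native call returned null.
    /// </summary>
    public static unsafe Span<float> GetEmbeddingsIth(SafeLLamaContextHandle ctx, int i)
    {
        var ptr = NativeApi.llama_get_embeddings_ith(ctx, i);
        if (ptr == null)
            return Span<float>.Empty;

        // Per the documented return value, the native buffer holds exactly
        // ctx.EmbeddingSize floats.
        return new Span<float>(ptr, ctx.EmbeddingSize);
    }
}
```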