- Rewritten native API methods for embeddings to return pointers
- null is a valid value for these methods to return, so `Span` is not appropriate
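Context for the change: `Span<T>` has no null state, so a wrapper that maps a null native pointer to an empty span makes "no embeddings available" indistinguishable from a zero-length result. A minimal standalone sketch of that language-level behaviour (not LLamaSharp code; compile with unsafe blocks enabled):

using System;

internal static class PointerVsSpanDemo
{
    // A Span-returning wrapper collapses "no result" (null pointer) and
    // "zero-length result" into the same empty span.
    private static unsafe Span<float> AsSpan(float* ptr, int length)
    {
        return ptr == null ? Span<float>.Empty : new Span<float>(ptr, length);
    }

    // Returning the pointer keeps the distinction: the caller checks for null
    // before deciding what to do with the data.
    private static unsafe float[] ToArrayOrEmpty(float* ptr, int length)
    {
        if (ptr == null)
            return Array.Empty<float>();
        return new Span<float>(ptr, length).ToArray();
    }

    private static unsafe void Main()
    {
        float* values = stackalloc float[4] { 1f, 2f, 3f, 4f };

        Console.WriteLine(ToArrayOrEmpty(values, 4).Length);  // 4
        Console.WriteLine(ToArrayOrEmpty(null, 0).Length);    // 0, reached via the explicit null check

        // With the span wrapper, both cases look identical to the caller:
        Console.WriteLine(AsSpan(null, 0).IsEmpty && AsSpan(values, 0).IsEmpty); // True
    }
}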
@@ -1,32 +1,16 @@
 using LLama.Common;
 using LLama.Native;
 using Xunit.Abstractions;
-using Xunit.Sdk;
 namespace LLama.Unittest;
 public sealed class LLamaEmbedderTests
-    : IDisposable
 {
     private readonly ITestOutputHelper _testOutputHelper;
-    private readonly LLamaEmbedder _embedder;
     public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
     {
         _testOutputHelper = testOutputHelper;
-        var @params = new ModelParams(Constants.EmbeddingModelPath)
-        {
-            ContextSize = 4096,
-            Threads = 5,
-            Embeddings = true,
-        };
-        using var weights = LLamaWeights.LoadFromFile(@params);
-        _embedder = new(weights, @params);
     }
-    public void Dispose()
-    {
-        _embedder.Dispose();
-    }
     private static float Dot(float[] a, float[] b)
@@ -35,17 +19,24 @@ public sealed class LLamaEmbedderTests
         return a.Zip(b, (x, y) => x * y).Sum();
     }
-    [Fact]
-    public async Task EmbedCompare()
+    private async Task CompareEmbeddings(string modelPath)
     {
-        var cat = await _embedder.GetEmbeddings("The cat is cute");
+        var @params = new ModelParams(modelPath)
+        {
+            ContextSize = 8,
+            Threads = 4,
+            Embeddings = true,
+        };
+        using var weights = LLamaWeights.LoadFromFile(@params);
+        using var embedder = new LLamaEmbedder(weights, @params);
+        var cat = await embedder.GetEmbeddings("The cat is cute");
         Assert.DoesNotContain(float.NaN, cat);
-        var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
+        var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
         Assert.DoesNotContain(float.NaN, kitten);
-        var spoon = await _embedder.GetEmbeddings("The spoon is not real");
+        var spoon = await embedder.GetEmbeddings("The spoon is not real");
         Assert.DoesNotContain(float.NaN, spoon);
         _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -61,4 +52,16 @@ public sealed class LLamaEmbedderTests
         Assert.True(close < far);
     }
+    [Fact]
+    public async Task EmbedCompareEmbeddingModel()
+    {
+        await CompareEmbeddings(Constants.EmbeddingModelPath);
+    }
+    [Fact]
+    public async Task EmbedCompareGenerateModel()
+    {
+        await CompareEmbeddings(Constants.GenerativeModelPath);
+    }
 }
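The hunks above never show how `close` and `far` are produced from the `Dot` helper; presumably they are cosine-style distances over the normalized embeddings. A self-contained sketch of that comparison pattern, with made-up vectors standing in for model output (an illustration, not the test's actual code):

using System;
using System.Linq;

internal static class EmbeddingComparisonSketch
{
    private static float Dot(float[] a, float[] b) =>
        a.Zip(b, (x, y) => x * y).Sum();

    // Cosine distance: 0 for identical directions, larger for unrelated ones.
    private static float CosineDistance(float[] a, float[] b)
    {
        var norm = MathF.Sqrt(Dot(a, a)) * MathF.Sqrt(Dot(b, b));
        return 1f - Dot(a, b) / norm;
    }

    private static void Main()
    {
        // Stand-in vectors; the real test compares model-generated embeddings.
        var cat    = new[] { 0.9f, 0.1f, 0.2f };
        var kitten = new[] { 0.8f, 0.2f, 0.1f };
        var spoon  = new[] { 0.1f, 0.9f, 0.4f };

        var close = CosineDistance(cat, kitten);
        var far   = CosineDistance(cat, spoon);
        Console.WriteLine(close < far);   // True, mirroring Assert.True(close < far)
    }
}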
@@ -97,15 +97,18 @@ namespace LLama
     private float[] GetEmbeddingsArray()
     {
-        var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
-        if (embeddings == null || embeddings.Length == 0)
+        unsafe
         {
-            embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
-            if (embeddings == null || embeddings.Length == 0)
+            var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
+            if (embeddings == null)
+                embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
+            if (embeddings == null)
                 return Array.Empty<float>();
-        }
-        return embeddings.ToArray();
+            return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
+        }
     }
     private static void Normalize(Span<float> embeddings)
@@ -116,6 +119,7 @@
             lengthSqr += value * value;
         var length = (float)Math.Sqrt(lengthSqr);
+        // Do not divide by length if it is zero
         if (length <= float.Epsilon)
             return;
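The added comment documents why the guard matters: a zero vector divided by its (zero) Euclidean length would become all NaN, which is exactly what the tests reject with `Assert.DoesNotContain(float.NaN, ...)`. The hunk shows only the tail of `Normalize`; a sketch of the overall shape of such a routine, assuming the elided lines divide each element by `length` (not the library's exact implementation):

using System;

internal static class NormalizeSketch
{
    // Euclidean (L2) normalization with a zero-length guard, mirroring the
    // pattern in the hunk above.
    private static void Normalize(Span<float> embeddings)
    {
        var lengthSqr = 0.0f;
        foreach (var value in embeddings)
            lengthSqr += value * value;
        var length = (float)Math.Sqrt(lengthSqr);

        // Do not divide by length if it is zero - 0/0 would turn every element into NaN.
        if (length <= float.Epsilon)
            return;

        for (var i = 0; i < embeddings.Length; i++)
            embeddings[i] /= length;
    }

    private static void Main()
    {
        var v = new float[] { 3f, 4f };
        Normalize(v);
        Console.WriteLine(string.Join(",", v));    // 0.6,0.8

        var zero = new float[] { 0f, 0f };
        Normalize(zero);                           // left untouched instead of becoming NaN
        Console.WriteLine(string.Join(",", zero)); // 0,0
    }
}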
@@ -137,41 +137,17 @@ namespace LLama.Native
     /// Get the embeddings for the a specific sequence.
     /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
     /// </summary>
-    /// <returns></returns>
-    public static Span<float> llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id)
-    {
-        unsafe
-        {
-            var ptr = llama_get_embeddings_seq_native(ctx, id);
-            if (ptr == null)
-                return Array.Empty<float>();
-            return new Span<float>(ptr, ctx.EmbeddingSize);
-        }
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_seq")]
-        static extern unsafe float* llama_get_embeddings_seq_native(SafeLLamaContextHandle ctx, LLamaSeqId id);
-    }
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);
     /// <summary>
     /// Get the embeddings for the ith sequence.
     /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
     /// </summary>
-    /// <returns></returns>
-    public static Span<float> llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i)
-    {
-        unsafe
-        {
-            var ptr = llama_get_embeddings_ith_native(ctx, i);
-            if (ptr == null)
-                return Array.Empty<float>();
-            return new Span<float>(ptr, ctx.EmbeddingSize);
-        }
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_ith")]
-        static extern unsafe float* llama_get_embeddings_ith_native(SafeLLamaContextHandle ctx, int i);
-    }
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
     /// <summary>
     /// Get all output token embeddings.
@@ -182,20 +158,8 @@
     /// </summary>
     /// <param name="ctx"></param>
     /// <returns></returns>
-    public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
-    {
-        unsafe
-        {
-            var ptr = llama_get_embeddings_native(ctx);
-            if (ptr == null)
-                return Array.Empty<float>();
-            return new Span<float>(ptr, ctx.EmbeddingSize);
-        }
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
-        static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
-    }
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern unsafe float* llama_get_embeddings(SafeLLamaContextHandle ctx);
     /// <summary>
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
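Because these entry points now surface the raw `float*`, managed callers need the null-check-then-copy pattern that `GetEmbeddingsArray` uses above. A hedged sketch of that calling convention (the helper name is invented for illustration; it assumes a context that has already decoded a batch with embeddings enabled, and `ctx.EmbeddingSize` supplies the length the pointer does not carry):

using System;
using LLama.Native;

internal static class EmbeddingPointerUsage
{
    // Illustrative helper, not part of LLamaSharp: wraps one of the
    // pointer-returning natives declared above.
    public static float[] GetSequenceEmbeddingOrEmpty(SafeLLamaContextHandle ctx, LLamaSeqId id)
    {
        unsafe
        {
            var ptr = NativeApi.llama_get_embeddings_seq(ctx, id);
            if (ptr == null)
                return Array.Empty<float>();   // null is a legitimate "nothing available" answer

            // Copy out immediately; the buffer is owned by the native context
            // and may be overwritten by later decode calls.
            return new Span<float>(ptr, ctx.EmbeddingSize).ToArray();
        }
    }
}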