Browse Source

- Added tests for generating embeddings with generative model and embedding model

- Rewritten native API methods for embeddings to return pointers - null is a valid value for these methods to return so `Span` is not appropriate
pull/678/head
Martin Evans 1 year ago
parent
commit
3c76440957
3 changed files with 44 additions and 73 deletions
  1. +26
    -23
      LLama.Unittest/LLamaEmbedderTests.cs
  2. +10
    -6
      LLama/LLamaEmbedder.cs
  3. +8
    -44
      LLama/Native/NativeApi.cs

+ 26
- 23
LLama.Unittest/LLamaEmbedderTests.cs View File

@@ -1,32 +1,16 @@
using LLama.Common; using LLama.Common;
using LLama.Native;
using Xunit.Abstractions; using Xunit.Abstractions;
using Xunit.Sdk;


namespace LLama.Unittest; namespace LLama.Unittest;


public sealed class LLamaEmbedderTests public sealed class LLamaEmbedderTests
: IDisposable
{ {
private readonly ITestOutputHelper _testOutputHelper; private readonly ITestOutputHelper _testOutputHelper;
private readonly LLamaEmbedder _embedder;


public LLamaEmbedderTests(ITestOutputHelper testOutputHelper) public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
{ {
_testOutputHelper = testOutputHelper; _testOutputHelper = testOutputHelper;
var @params = new ModelParams(Constants.EmbeddingModelPath)
{
ContextSize = 4096,
Threads = 5,
Embeddings = true,
};
using var weights = LLamaWeights.LoadFromFile(@params);
_embedder = new(weights, @params);
}

public void Dispose()
{
_embedder.Dispose();
} }


private static float Dot(float[] a, float[] b) private static float Dot(float[] a, float[] b)
@@ -35,17 +19,24 @@ public sealed class LLamaEmbedderTests
return a.Zip(b, (x, y) => x * y).Sum(); return a.Zip(b, (x, y) => x * y).Sum();
} }



[Fact]
public async Task EmbedCompare()
private async Task CompareEmbeddings(string modelPath)
{ {
var cat = await _embedder.GetEmbeddings("The cat is cute");
var @params = new ModelParams(modelPath)
{
ContextSize = 8,
Threads = 4,
Embeddings = true,
};
using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);

var cat = await embedder.GetEmbeddings("The cat is cute");
Assert.DoesNotContain(float.NaN, cat); Assert.DoesNotContain(float.NaN, cat);


var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
Assert.DoesNotContain(float.NaN, kitten); Assert.DoesNotContain(float.NaN, kitten);


var spoon = await _embedder.GetEmbeddings("The spoon is not real");
var spoon = await embedder.GetEmbeddings("The spoon is not real");
Assert.DoesNotContain(float.NaN, spoon); Assert.DoesNotContain(float.NaN, spoon);


_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -61,4 +52,16 @@ public sealed class LLamaEmbedderTests


Assert.True(close < far); Assert.True(close < far);
} }

[Fact]
public async Task EmbedCompareEmbeddingModel()
{
await CompareEmbeddings(Constants.EmbeddingModelPath);
}

[Fact]
public async Task EmbedCompareGenerateModel()
{
await CompareEmbeddings(Constants.GenerativeModelPath);
}
} }

+ 10
- 6
LLama/LLamaEmbedder.cs View File

@@ -97,15 +97,18 @@ namespace LLama


private float[] GetEmbeddingsArray() private float[] GetEmbeddingsArray()
{ {
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null || embeddings.Length == 0)
unsafe
{ {
embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
if (embeddings == null || embeddings.Length == 0)
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);

if (embeddings == null)
embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);

if (embeddings == null)
return Array.Empty<float>(); return Array.Empty<float>();
}


return embeddings.ToArray();
return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
}
} }


private static void Normalize(Span<float> embeddings) private static void Normalize(Span<float> embeddings)
@@ -116,6 +119,7 @@ namespace LLama
lengthSqr += value * value; lengthSqr += value * value;
var length = (float)Math.Sqrt(lengthSqr); var length = (float)Math.Sqrt(lengthSqr);


// Do not divide by length if it is zero
if (length <= float.Epsilon) if (length <= float.Epsilon)
return; return;




+ 8
- 44
LLama/Native/NativeApi.cs View File

@@ -137,41 +137,17 @@ namespace LLama.Native
/// Get the embeddings for the a specific sequence. /// Get the embeddings for the a specific sequence.
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
/// </summary> /// </summary>
/// <returns></returns>
public static Span<float> llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id)
{
unsafe
{
var ptr = llama_get_embeddings_seq_native(ctx, id);
if (ptr == null)
return Array.Empty<float>();

return new Span<float>(ptr, ctx.EmbeddingSize);
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_seq")]
static extern unsafe float* llama_get_embeddings_seq_native(SafeLLamaContextHandle ctx, LLamaSeqId id);
}
/// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);


/// <summary> /// <summary>
/// Get the embeddings for the ith sequence. /// Get the embeddings for the ith sequence.
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
/// </summary> /// </summary>
/// <returns></returns>
public static Span<float> llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i)
{
unsafe
{
var ptr = llama_get_embeddings_ith_native(ctx, i);
if (ptr == null)
return Array.Empty<float>();

return new Span<float>(ptr, ctx.EmbeddingSize);
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_ith")]
static extern unsafe float* llama_get_embeddings_ith_native(SafeLLamaContextHandle ctx, int i);
}
/// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);


/// <summary> /// <summary>
/// Get all output token embeddings. /// Get all output token embeddings.
@@ -182,20 +158,8 @@ namespace LLama.Native
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <returns></returns> /// <returns></returns>
public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
{
unsafe
{
var ptr = llama_get_embeddings_native(ctx);
if (ptr == null)
return Array.Empty<float>();

return new Span<float>(ptr, ctx.EmbeddingSize);
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
}
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe float* llama_get_embeddings(SafeLLamaContextHandle ctx);


/// <summary> /// <summary>
/// Apply chat template. Inspired by hf apply_chat_template() on python. /// Apply chat template. Inspired by hf apply_chat_template() on python.


Loading…
Cancel
Save