using LLama.Native;
using System;
using LLama.Exceptions;
using LLama.Abstractions;
using Microsoft.Extensions.Logging;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
    /// <summary>
    /// The embedder for LLama, which supports getting embeddings from text.
    /// </summary>
    public sealed class LLamaEmbedder
        : IDisposable
    {
        /// <summary>
        /// Dimension of embedding vectors (forwarded from the underlying <see cref="Context"/>).
        /// </summary>
        public int EmbeddingSize => Context.EmbeddingSize;

        /// <summary>
        /// LLama Context owned by this embedder; it is disposed together with the embedder.
        /// </summary>
        public LLamaContext Context { get; }

        /// <summary>
        /// Create a new embedder, using the given LLamaWeights
        /// </summary>
        /// <param name="weights">Model weights used to create the embedding context.</param>
        /// <param name="params">Context parameters; <c>Embeddings</c> must be <c>true</c>.</param>
        /// <param name="logger">Optional logger forwarded to the created context.</param>
        /// <exception cref="ArgumentNullException"><paramref name="weights"/> or <paramref name="params"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="params"/> does not have embeddings enabled.</exception>
        public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            if (weights is null)
                throw new ArgumentNullException(nameof(weights));
            if (@params is null)
                throw new ArgumentNullException(nameof(@params));
            if (!@params.Embeddings)
                throw new ArgumentException("EmbeddingMode must be true", nameof(@params));

            Context = weights.CreateContext(@params, logger);
        }

        /// <summary>
        /// Get the embeddings of the text. A BOS token is added to the text.
        /// </summary>
        /// <param name="text">Text to embed.</param>
        /// <param name="cancellationToken">Token forwarded to the decode calls.</param>
        /// <returns>
        /// The L2-normalized embedding vector (left unscaled if its length is zero),
        /// or an empty array if no embeddings are available.
        /// </returns>
        /// <exception cref="ArgumentException">The tokenized text is longer than the context window.</exception>
        /// <exception cref="LLamaDecodeError">Decoding a batch returned a non-zero code.</exception>
        public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
        {
            return GetEmbeddings(text, true, cancellationToken);
        }

        /// <summary>
        /// Get the embeddings of the text.
        /// </summary>
        /// <param name="text">Text to embed.</param>
        /// <param name="addBos">Add bos to the text.</param>
        /// <param name="cancellationToken">Token forwarded to the decode calls.</param>
        /// <returns>
        /// The L2-normalized embedding vector (left unscaled if its length is zero),
        /// or an empty array if the text produced no tokens or no embeddings are available.
        /// </returns>
        /// <exception cref="ArgumentException">The tokenized text is longer than the context window.</exception>
        /// <exception cref="InvalidOperationException">The context batch size is not a positive number.</exception>
        /// <exception cref="LLamaDecodeError">Decoding a batch returned a non-zero code.</exception>
        public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
        {
            var tokens = Context.Tokenize(text, addBos);
            if (tokens.Length > Context.ContextSize)
                throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));

            // With no tokens nothing is decoded. Reading the native embedding buffer
            // would then not reflect this text at all.
            if (tokens.Length == 0)
                return Array.Empty<float>();

            var batchSize = (int)Context.Params.BatchSize;

            // A non-positive step would make the chunking loop below spin forever.
            if (batchSize <= 0)
                throw new InvalidOperationException("Context batch size must be greater than zero");

            var batch = new LLamaBatch();
            try
            {
                // Evaluate prompt in batch-size chunks. Every token goes to sequence 0,
                // and the chunk start index doubles as the position of its first token.
                // NOTE(review): when the prompt spans several batches, confirm that the
                // native embedding/pooling covers all chunks and not only the last one.
                for (var start = 0; start < tokens.Length; start += batchSize)
                {
                    var count = Math.Min(batchSize, tokens.Length - start);

                    batch.Clear();
                    batch.AddRange(tokens.AsSpan(start, count), start, LLamaSeqId.Zero, true);

                    var returnCode = await Context.DecodeAsync(batch, cancellationToken).ConfigureAwait(false);
                    if (returnCode != 0)
                        throw new LLamaDecodeError(returnCode);
                }

                var embeddings = GetEmbeddingsArray();

                // Normalize the embeddings vector
                // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
                Normalize(embeddings);
                return embeddings;
            }
            finally
            {
                // Remove everything evaluated from the context cache, also when decoding
                // fails or is cancelled. Otherwise the next call would start at position 0
                // on top of stale cache entries.
                Context.NativeHandle.KvCacheClear();
            }
        }

        /// <summary>
        /// Copies the embeddings produced by the most recent decode into a managed array.
        /// </summary>
        /// <returns>
        /// A copy of <see cref="EmbeddingSize"/> floats, or an empty array if the native
        /// side returned no embeddings.
        /// </returns>
        private float[] GetEmbeddingsArray()
        {
            unsafe
            {
                // Try the context-wide embeddings first, then the embeddings reported for
                // sequence 0 (the sequence every token was added to). Presumably the latter
                // is the pooled result when pooling is enabled; TODO confirm against llama.h.
                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
                if (embeddings == null)
                    embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
                if (embeddings == null)
                    return Array.Empty<float>();

                // Copy out of native memory so the result survives later decodes and cache clears.
                return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
            }
        }

        /// <summary>
        /// Scales <paramref name="embeddings"/> in place to unit Euclidean (L2) length.
        /// A vector whose length is at most <see cref="float.Epsilon"/> is left unchanged.
        /// </summary>
        /// <param name="embeddings">Vector to normalize in place.</param>
        private static void Normalize(Span<float> embeddings)
        {
            // Accumulate the squared length in double to limit rounding error.
            var lengthSqr = 0.0;
            foreach (var value in embeddings)
                lengthSqr += value * value;
            var length = (float)Math.Sqrt(lengthSqr);

            // Do not divide by length if it is zero
            if (length <= float.Epsilon)
                return;

            for (var i = 0; i < embeddings.Length; i++)
                embeddings[i] /= length;
        }

        /// <inheritdoc />
        public void Dispose()
        {
            Context.Dispose();
        }
    }
}