using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.Logging;

namespace LLama
{
    /// <summary>
    /// The embedder for LLama, which supports getting embeddings from text.
    /// </summary>
    public sealed class LLamaEmbedder
        : IDisposable
    {
        /// <summary>
        /// Dimension of embedding vectors (forwarded from <see cref="LLamaContext.EmbeddingSize"/>).
        /// </summary>
        public int EmbeddingSize => Context.EmbeddingSize;

        /// <summary>
        /// The LLama context used to evaluate prompts for embedding.
        /// </summary>
        public LLamaContext Context { get; }

        /// <summary>
        /// Create a new embedder, using the given <see cref="LLamaWeights"/>.
        /// </summary>
        /// <param name="weights">Model weights used to create the context.</param>
        /// <param name="params">Context parameters; <c>Embeddings</c> must be <c>true</c>.</param>
        /// <param name="logger">Optional logger passed on to the created context.</param>
        /// <exception cref="ArgumentException">Thrown when embeddings are not enabled in <paramref name="params"/>.</exception>
        public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            if (!@params.Embeddings)
                throw new ArgumentException("EmbeddingMode must be true", nameof(@params));

            Context = weights.CreateContext(@params, logger);
        }

        /// <summary>
        /// Get the embeddings of the text. A BOS token is added in front of the text.
        /// </summary>
        /// <param name="text">Text to embed.</param>
        /// <param name="cancellationToken">Token forwarded to each decode call.</param>
        /// <returns>The L2-normalized embedding vector, or an empty array if no embedding was available.</returns>
        /// <exception cref="ArgumentException">The tokenized text is longer than the context window.</exception>
        /// <exception cref="LLamaDecodeError">Decoding a batch returned a non-zero code.</exception>
        public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
        {
            return GetEmbeddings(text, true, cancellationToken);
        }

        /// <summary>
        /// Get the embeddings of the text.
        /// </summary>
        /// <param name="text">Text to embed.</param>
        /// <param name="addBos">Add bos to the text.</param>
        /// <param name="cancellationToken">Token forwarded to each decode call.</param>
        /// <returns>The L2-normalized embedding vector, or an empty array if no embedding was available.</returns>
        /// <exception cref="ArgumentException">The tokenized text is longer than the context window.</exception>
        /// <exception cref="LLamaDecodeError">Decoding a batch returned a non-zero code.</exception>
        public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
        {
            var tokens = Context.Tokenize(text, addBos);
            if (tokens.Length > Context.ContextSize)
                throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));

            try
            {
                // Evaluate prompt in batch-size chunks, all on sequence 0
                var nPast = 0;
                var batch = new LLamaBatch();
                var batchSize = (int)Context.Params.BatchSize;
                for (var i = 0; i < tokens.Length; i += batchSize)
                {
                    var nEval = Math.Min(batchSize, tokens.Length - i);

                    batch.Clear();
                    batch.AddRange(tokens.AsSpan(i, nEval), nPast, LLamaSeqId.Zero, true);
                    nPast += nEval;

                    var returnCode = await Context.DecodeAsync(batch, cancellationToken).ConfigureAwait(false);
                    if (returnCode != 0)
                        throw new LLamaDecodeError(returnCode);
                }

                // Copy the embedding out of native memory before the cache is cleared
                var embeddings = GetEmbeddingsArray();

                // Normalize the embeddings vector
                // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
                Normalize(embeddings);

                return embeddings;
            }
            finally
            {
                // Remove everything we just evaluated from the context cache. This runs on
                // failure/cancellation too, so a later call (which starts at position 0)
                // never sees tokens left over from an aborted evaluation.
                Context.NativeHandle.KvCacheClear();
            }
        }

        /// <summary>
        /// Copy the embedding produced by the most recent decode into a managed array.
        /// Tries <c>llama_get_embeddings</c> first and falls back to
        /// <c>llama_get_embeddings_seq</c> for sequence 0.
        /// </summary>
        /// <returns>
        /// A new array of <see cref="EmbeddingSize"/> floats, or an empty array if both native calls return null.
        /// </returns>
        private float[] GetEmbeddingsArray()
        {
            unsafe
            {
                // NOTE(review): when the first call succeeds this reads the first EmbeddingSize
                // floats of that buffer - confirm which token/pooling mode that corresponds to.
                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);

                if (embeddings == null)
                    embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);

                if (embeddings == null)
                    return Array.Empty<float>();

                return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
            }
        }

        /// <summary>
        /// Scale the vector in place to unit L2 length. Vectors whose length is at most
        /// <see cref="float.Epsilon"/> are left unchanged to avoid dividing by zero.
        /// </summary>
        /// <param name="embeddings">Vector to normalize in place.</param>
        private static void Normalize(Span<float> embeddings)
        {
            // Calculate length
            var lengthSqr = 0.0;
            foreach (var value in embeddings)
                lengthSqr += value * value;
            var length = (float)Math.Sqrt(lengthSqr);

            // Do not divide by length if it is zero
            if (length <= float.Epsilon)
                return;

            // Normalize
            for (var i = 0; i < embeddings.Length; i++)
                embeddings[i] /= length;
        }

        /// <inheritdoc />
        public void Dispose()
        {
            Context.Dispose();
        }
    }
}