using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using System.Linq;
using LLama.Abstractions;

namespace LLama
{
    /// <summary>
    /// The embedder for LLama, which supports getting embeddings from text.
    /// </summary>
    public class LLamaEmbedder : IDisposable
    {
        // Native context used for tokenization, evaluation and reading embeddings.
        // Assigned once in a constructor and released in Dispose.
        private readonly SafeLLamaContextHandle _ctx;

        /// <summary>
        /// Wraps an existing native context.
        /// Warning: the caller must ensure the context was created from a model with
        /// <c>params.embedding = true</c>.
        /// </summary>
        /// <param name="ctx">
        /// The native context to use. This embedder disposes it in <see cref="Dispose"/>.
        /// </param>
        internal LLamaEmbedder(SafeLLamaContextHandle ctx)
        {
            _ctx = ctx;
        }

        /// <summary>
        /// Creates a new native context from <paramref name="params"/> with embedding mode enabled.
        /// </summary>
        /// <param name="params">
        /// Model parameters. Note: this constructor sets <c>EmbeddingMode = true</c> on the
        /// passed-in object itself, so the change is visible to the caller afterwards.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="params"/> is null.</exception>
        public LLamaEmbedder(IModelParams @params)
        {
            if (@params == null)
            {
                throw new ArgumentNullException(nameof(@params));
            }

            @params.EmbeddingMode = true;
            _ctx = Utils.InitLLamaContextFromModelParams(@params);
        }

        /// <summary>
        /// Get the embeddings of the text.
        /// </summary>
        /// <param name="text">The text to embed.</param>
        /// <param name="threads">
        /// Threads used for inference. <c>-1</c> means half the logical processor count (at least 1).
        /// </param>
        /// <param name="addBos">
        /// Add BOS to the text. When true, a single space is also prepended to the text before tokenizing.
        /// </param>
        /// <param name="encoding">Name of the text encoding passed to <see cref="Encoding.GetEncoding(string)"/>.</param>
        /// <returns>
        /// A copy of the context's embedding vector (length <c>llama_n_embd</c>). Returns an empty
        /// array when the text produces no tokens, or when the native layer returns no embeddings.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
        /// <exception cref="RuntimeError">The native evaluation call reported failure.</exception>
        public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
        {
            if (text == null)
            {
                throw new ArgumentNullException(nameof(text));
            }

            if (threads == -1)
            {
                threads = Math.Max(Environment.ProcessorCount / 2, 1);
            }

            // Each call evaluates the prompt from position 0.
            int n_past = 0;

            if (addBos)
            {
                // NOTE(review): presumably mirrors llama.cpp's convention of prefixing the prompt
                // with a space before tokenizing; verify against the native example.
                text = text.Insert(0, " ");
            }

            var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();

            // TODO(Rinne): deal with log of prompt

            // With no tokens nothing is evaluated. In that case the context's embedding buffer
            // still holds whatever an earlier call left there, which is not this text's embedding.
            if (embed_inp_array.Length == 0)
            {
                return Array.Empty<float>();
            }

            if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, threads) != 0)
            {
                throw new RuntimeError("Failed to eval.");
            }

            int n_embed = NativeApi.llama_n_embd(_ctx);
            var embeddings = NativeApi.llama_get_embeddings(_ctx);
            if (embeddings == null)
            {
                return Array.Empty<float>();
            }

            // Copy out of native memory so the result stays valid after later evals or disposal.
            return new Span<float>(embeddings, n_embed).ToArray();
        }

        /// <summary>
        /// Releases the native context owned by this embedder.
        /// </summary>
        public void Dispose()
        {
            _ctx.Dispose();
        }
    }
}