diff --git a/LLama.Examples/GetEmbeddings.cs b/LLama.Examples/GetEmbeddings.cs new file mode 100644 index 00000000..e28330f0 --- /dev/null +++ b/LLama.Examples/GetEmbeddings.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace LLama.Examples +{ + public class GetEmbeddings + { + LLamaEmbedder _embedder; + public GetEmbeddings(string modelPath) + { + _embedder = new LLamaEmbedder(new LLamaParams(model: modelPath)); + } + + public void Run(string text) + { + Console.WriteLine(string.Join(", ", _embedder.GetEmbeddings(text))); + } + } +} diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs index 7bed702c..4a7bf587 100644 --- a/LLama.Examples/Program.cs +++ b/LLama.Examples/Program.cs @@ -2,11 +2,11 @@ using LLama.Examples; using LLama.Types; -int choice = 3; +int choice = 0; if(choice == 0) { - ChatSession chat = new(@"C:\Users\haipi\Source\repos\ggml-model-q4_0.bin", @"C:\Users\haipi\Source\repos\SciSharp\LLamaSharp\LLama.Examples\Assets\chat-with-bob.txt", new string[] { "User:" }); + ChatSession chat = new(@"", @"", new string[] { "User:" }); chat.Run(); } else if(choice == 1) @@ -22,6 +22,11 @@ else if(choice == 2) else if (choice == 3) // quantization { Quantize q = new Quantize(); - q.Run(@"D:\development\llama\weights\LLaMA\7B\ggml-model-f16.bin", - @"D:\development\llama\weights\LLaMA\7B\ggml-model-q4_1.bin", "q4_1"); + q.Run(@"", + @"", "q4_1"); +} +else if (choice == 4) // get the embeddings of a text +{ + GetEmbeddings em = new GetEmbeddings(@""); + em.Run("Hello, what is python?"); } \ No newline at end of file diff --git a/LLama/GptModel.cs b/LLama/GptModel.cs index d8e330e9..86985309 100644 --- a/LLama/GptModel.cs +++ b/LLama/GptModel.cs @@ -42,6 +42,7 @@ namespace LLama bool _first_time_chat = true; public string Name { get; set; } + public SafeLLamaContextHandle NativeHandle => _ctx; public LLamaModel(string model_path, string model_name, bool echo_input 
= false, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1, int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0, @@ -70,7 +71,7 @@ namespace LLama _ctx = Utils.llama_init_from_gpt_params(ref _params); // Add a space in front of the first character to match OG llama tokenizer behavior - _params.prompt.Insert(0, " "); + _params.prompt = _params.prompt.Insert(0, " "); _session_tokens = new List<llama_token>(); _path_session = @params.path_session; @@ -223,7 +224,7 @@ namespace LLama _params.prompt = prompt; if (!_params.prompt.EndsWith(" ")) { - _params.prompt.Insert(0, " "); + _params.prompt = _params.prompt.Insert(0, " "); } _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true); if (_embed_inp.Count > _n_ctx - 4) diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs new file mode 100644 index 00000000..93139027 --- /dev/null +++ b/LLama/LLamaEmbedder.cs @@ -0,0 +1,64 @@ +using LLama.Native; +using System; +using System.Collections.Generic; +using System.Text; +using LLama.Exceptions; + +namespace LLama +{ + public class LLamaEmbedder + { + SafeLLamaContextHandle _ctx; + + /// <summary> + /// Warning: must ensure the original model has params.embedding = true; + /// </summary> + /// <param name="ctx"></param> + internal LLamaEmbedder(SafeLLamaContextHandle ctx) + { + _ctx = ctx; + } + + public LLamaEmbedder(LLamaParams @params) + { + @params.embedding = true; + _ctx = Utils.llama_init_from_gpt_params(ref @params); + } + + public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true) + { + if(n_thread == -1) + { + n_thread = Math.Max(Environment.ProcessorCount / 2, 1); + } + int n_past = 0; + if (add_bos) + { + text = text.Insert(0, " "); + } + var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos); + + // TODO(Rinne): deal with log of prompt + + if (embed_inp.Count > 0) + { + var embed_inp_array = embed_inp.ToArray(); + if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, n_thread) != 0) + { + throw new 
RuntimeError("Failed to eval."); + } + } + + int n_embed = NativeApi.llama_n_embd(_ctx); + var embeddings = NativeApi.llama_get_embeddings(_ctx); + if(embeddings == null) + { + return new float[0]; + } + var span = new Span<float>(embeddings, n_embed); + float[] res = new float[n_embed]; + span.CopyTo(res.AsSpan()); + return res; + } + } +} diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 323ecc3d..5c26cb13 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -5,7 +5,7 @@ using System.Text; namespace LLama.Native { - internal class SafeLLamaContextHandle: SafeLLamaHandleBase + public class SafeLLamaContextHandle: SafeLLamaHandleBase { protected SafeLLamaContextHandle() { diff --git a/LLama/Native/SafeLLamaHandleBase.cs b/LLama/Native/SafeLLamaHandleBase.cs index a6febf0d..023f8cdd 100644 --- a/LLama/Native/SafeLLamaHandleBase.cs +++ b/LLama/Native/SafeLLamaHandleBase.cs @@ -5,7 +5,7 @@ using System.Text; namespace LLama.Native { - internal abstract class SafeLLamaHandleBase: SafeHandle + public abstract class SafeLLamaHandleBase: SafeHandle { private protected SafeLLamaHandleBase() : base(IntPtr.Zero, ownsHandle: true)