| @@ -0,0 +1,22 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace LLama.Examples | |||||
| { | |||||
| /// <summary> | |||||
| /// Example that builds an <see cref="LLamaEmbedder"/> for a model file and prints the embedding of a text to the console. | |||||
| /// </summary> | |||||
| public class GetEmbeddings | |||||
| { | |||||
| private readonly LLamaEmbedder _embedder; | |||||
| /// <summary> | |||||
| /// Creates the embedder from parameters that point at <paramref name="modelPath"/>. | |||||
| /// </summary> | |||||
| /// <param name="modelPath">Value passed as the <c>model</c> argument of <see cref="LLamaParams"/>.</param> | |||||
| public GetEmbeddings(string modelPath) | |||||
| { | |||||
| var embedderParams = new LLamaParams(model: modelPath); | |||||
| _embedder = new LLamaEmbedder(embedderParams); | |||||
| } | |||||
| /// <summary> | |||||
| /// Writes the embedding of <paramref name="text"/> to the console as a single line of values separated by ", ". | |||||
| /// </summary> | |||||
| /// <param name="text">The text handed to <see cref="LLamaEmbedder.GetEmbeddings"/>.</param> | |||||
| public void Run(string text) | |||||
| { | |||||
| float[] values = _embedder.GetEmbeddings(text); | |||||
| string line = string.Join(", ", values); | |||||
| Console.WriteLine(line); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,11 +2,11 @@ | |||||
| using LLama.Examples; | using LLama.Examples; | ||||
| using LLama.Types; | using LLama.Types; | ||||
| int choice = 3; | |||||
| int choice = 0; | |||||
| if(choice == 0) | if(choice == 0) | ||||
| { | { | ||||
| ChatSession chat = new(@"C:\Users\haipi\Source\repos\ggml-model-q4_0.bin", @"C:\Users\haipi\Source\repos\SciSharp\LLamaSharp\LLama.Examples\Assets\chat-with-bob.txt", new string[] { "User:" }); | |||||
| ChatSession chat = new(@"<Your model file path>", @"<Your prompt file path>", new string[] { "User:" }); | |||||
| chat.Run(); | chat.Run(); | ||||
| } | } | ||||
| else if(choice == 1) | else if(choice == 1) | ||||
| @@ -22,6 +22,11 @@ else if(choice == 2) | |||||
| else if (choice == 3) // quantization | else if (choice == 3) // quantization | ||||
| { | { | ||||
| Quantize q = new Quantize(); | Quantize q = new Quantize(); | ||||
| q.Run(@"D:\development\llama\weights\LLaMA\7B\ggml-model-f16.bin", | |||||
| @"D:\development\llama\weights\LLaMA\7B\ggml-model-q4_1.bin", "q4_1"); | |||||
| q.Run(@"<Your src model file path>", | |||||
| @"<Your dst model file path>", "q4_1"); | |||||
| } | |||||
| else if (choice == 4) // get embeddings | |||||
| { | |||||
| GetEmbeddings em = new GetEmbeddings(@"<Your model file path>"); | |||||
| em.Run("Hello, what is python?"); | |||||
| } | } | ||||
| @@ -42,6 +42,7 @@ namespace LLama | |||||
| bool _first_time_chat = true; | bool _first_time_chat = true; | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public SafeLLamaContextHandle NativeHandle => _ctx; | |||||
| public LLamaModel(string model_path, string model_name, bool echo_input = false, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1, | public LLamaModel(string model_path, string model_name, bool echo_input = false, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1, | ||||
| int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0, | int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0, | ||||
| @@ -70,7 +71,7 @@ namespace LLama | |||||
| _ctx = Utils.llama_init_from_gpt_params(ref _params); | _ctx = Utils.llama_init_from_gpt_params(ref _params); | ||||
| // Add a space in front of the first character to match OG llama tokenizer behavior | // Add a space in front of the first character to match OG llama tokenizer behavior | ||||
| _params.prompt.Insert(0, " "); | |||||
| _params.prompt = _params.prompt.Insert(0, " "); | |||||
| _session_tokens = new List<llama_token>(); | _session_tokens = new List<llama_token>(); | ||||
| _path_session = @params.path_session; | _path_session = @params.path_session; | ||||
| @@ -223,7 +224,7 @@ namespace LLama | |||||
| _params.prompt = prompt; | _params.prompt = prompt; | ||||
| if (!_params.prompt.EndsWith(" ")) | if (!_params.prompt.EndsWith(" ")) | ||||
| { | { | ||||
| _params.prompt.Insert(0, " "); | |||||
| _params.prompt = _params.prompt.Insert(0, " "); | |||||
| } | } | ||||
| _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true); | _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true); | ||||
| if (_embed_inp.Count > _n_ctx - 4) | if (_embed_inp.Count > _n_ctx - 4) | ||||
| @@ -0,0 +1,64 @@ | |||||
| using LLama.Native; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | |||||
| namespace LLama | |||||
| { | |||||
| /// <summary> | |||||
| /// Produces an embedding vector for a piece of text by tokenizing it, evaluating the tokens on a native llama context and copying out the context's embedding buffer. | |||||
| /// </summary> | |||||
| public class LLamaEmbedder | |||||
| { | |||||
| private readonly SafeLLamaContextHandle _ctx; | |||||
| /// <summary> | |||||
| /// Wraps an existing context handle. | |||||
| /// Warning: must ensure the original model has params.embedding = true; | |||||
| /// </summary> | |||||
| /// <param name="ctx">The context handle used for tokenization, evaluation and reading embeddings.</param> | |||||
| internal LLamaEmbedder(SafeLLamaContextHandle ctx) | |||||
| { | |||||
| _ctx = ctx; | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a new context from <paramref name="params"/>, with the embedding flag forced on. | |||||
| /// </summary> | |||||
| /// <param name="params">Parameters used to initialize the context; its <c>embedding</c> field is overwritten with <c>true</c>.</param> | |||||
| public LLamaEmbedder(LLamaParams @params) | |||||
| { | |||||
| // Set unconditionally so the caller does not have to remember to enable it. | |||||
| @params.embedding = true; | |||||
| _ctx = Utils.llama_init_from_gpt_params(ref @params); | |||||
| } | |||||
| /// <summary> | |||||
| /// Evaluates <paramref name="text"/> on the context and returns a copy of the resulting embedding values. | |||||
| /// </summary> | |||||
| /// <param name="text">The text to embed. Must not be null.</param> | |||||
| /// <param name="n_thread">Thread count passed to the native eval call; -1 selects half of <see cref="Environment.ProcessorCount"/>, with a minimum of 1.</param> | |||||
| /// <param name="add_bos">Forwarded to the tokenizer; when true a single space is also inserted in front of the text before tokenizing.</param> | |||||
| /// <returns>A new array of <c>llama_n_embd</c> values copied from the native buffer, or an empty array when the native call returns a null pointer.</returns> | |||||
| /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception> | |||||
| /// <exception cref="RuntimeError">The native eval call returned a non-zero status.</exception> | |||||
| public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true) | |||||
| { | |||||
| // Fail with a descriptive exception instead of a NullReferenceException or a null reaching the tokenizer. | |||||
| if (text == null) | |||||
| { | |||||
| throw new ArgumentNullException(nameof(text)); | |||||
| } | |||||
| if (n_thread == -1) | |||||
| { | |||||
| n_thread = Math.Max(Environment.ProcessorCount / 2, 1); | |||||
| } | |||||
| // Evaluation always starts from an empty context position. | |||||
| const int n_past = 0; | |||||
| if (add_bos) | |||||
| { | |||||
| text = text.Insert(0, " "); | |||||
| } | |||||
| var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos); | |||||
| // TODO(Rinne): deal with log of prompt | |||||
| if (embed_inp.Count > 0) | |||||
| { | |||||
| var embed_inp_array = embed_inp.ToArray(); | |||||
| if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, n_thread) != 0) | |||||
| { | |||||
| throw new RuntimeError("Failed to eval."); | |||||
| } | |||||
| } | |||||
| int n_embed = NativeApi.llama_n_embd(_ctx); | |||||
| var embeddings = NativeApi.llama_get_embeddings(_ctx); | |||||
| if (embeddings == null) | |||||
| { | |||||
| // Shared empty instance; avoids allocating a fresh zero-length array per call. | |||||
| return Array.Empty<float>(); | |||||
| } | |||||
| // Copy out of native memory so the result does not alias the context's buffer. | |||||
| return new Span<float>(embeddings, n_embed).ToArray(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -5,7 +5,7 @@ using System.Text; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| internal class SafeLLamaContextHandle: SafeLLamaHandleBase | |||||
| public class SafeLLamaContextHandle: SafeLLamaHandleBase | |||||
| { | { | ||||
| protected SafeLLamaContextHandle() | protected SafeLLamaContextHandle() | ||||
| { | { | ||||
| @@ -5,7 +5,7 @@ using System.Text; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| internal abstract class SafeLLamaHandleBase: SafeHandle | |||||
| public abstract class SafeLLamaHandleBase: SafeHandle | |||||
| { | { | ||||
| private protected SafeLLamaHandleBase() | private protected SafeLLamaHandleBase() | ||||
| : base(IntPtr.Zero, ownsHandle: true) | : base(IntPtr.Zero, ownsHandle: true) | ||||