@@ -1,16 +1,29 @@
using LLama;
using LLama.Types;
//string modelPath = @"D:\development\llama\weights\LLaMA\7B\ggml-model-q4_0.bin";
//LLamaModel model = new(modelPath, logits_all: false, verbose: false, n_ctx: 2048);
//List<ChatCompletionMessage> chats = new List<ChatCompletionMessage>();
//chats.Add(new ChatCompletionMessage("user", "Hi, Alice, I'm Rinne.", null));
//chats.Add(new ChatCompletionMessage("assistant", "Hi, Rinne, I'm Alice. What can I do for you?", null));
//Console.Write("You: ");
//var question = "This is a text classification task, below are the category list:\r\n1. Air Handler\r\n2. Tub/Shower\r\n3. Fireplace\r\n4. Bathroom\r\n5. Kitchen\r\n6. Powerwash roof eves and soffits\r\n\r\nFor example:\r\n1. \"Clear drain clog at kitchen sink\": Kitchen\r\n2. \"Change blower motor speed\": Air Handler\r\n3. \"Clear drain clog at tub/shower\": Bathroom\r\n4. \"Clear drain clog at toilet\": Bathroom\r\n\r\nPlease classify this text \"toilet clogged\" in provided list. output in json format: {\"category\": \"\", \"confidence\":0.0}";
//chats.Add(new ChatCompletionMessage("user", question, null));
//var output = model.CreateChatCompletion(chats, max_tokens: 1024);
//Console.WriteLine($"LLama AI: {output.Choices[0].Message.Content}");
string modelPath = @"D:\development\llama\weights\LLaMA\7B\ggml-model-q4_0.bin";
LLamaModel model = new(modelPath, logits_all: false, verbose: false);
List<ChatCompletionMessage> chats = new List<ChatCompletionMessage>();
chats.Add(new ChatCompletionMessage("user", "Hi, Alice, I'm Rinne.", null));
chats.Add(new ChatCompletionMessage("assistant", "Hi, Rinne, I'm Alice. What can I do for you?", null));
GptModel model = new(new GptParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: new List<string>() { "User:" },
    repeat_penalty: 1.0f));
model = model.WithPrompt("Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\r\n\r\nUser: Hello, Bob.\r\nBob: Hello. How may I help you today?\r\nUser: Please tell me the largest city in Europe.\r\nBob: Sure. The largest city in Europe is Moscow, the capital of Russia.\r\nUser:");
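// The antiprompt "User:" makes generation pause and hand control back to this
// console loop, so each model.Call(...) below streams a single assistant reply.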
while (true)
{
    Console.Write("You: ");
    Console.ForegroundColor = ConsoleColor.Green;
    var question = Console.ReadLine();
    chats.Add(new ChatCompletionMessage("user", question, null));
    var output = model.CreateChatCompletion(chats, max_tokens: 256);
    Console.WriteLine($"LLama AI: {output.Choices[0].Message.Content}");
    Console.ForegroundColor = ConsoleColor.White;
    var outputs = model.Call(question);
    foreach (var output in outputs)
    {
        Console.Write(output);
    }
}
@@ -0,0 +1,134 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace LLama
{
    using llama_token = Int32;

    public struct GptParams
    {
        public int seed; // RNG seed
        public int n_threads = Math.Max(Environment.ProcessorCount / 2, 1); // number of threads (-1 = autodetect)
        public int n_predict = -1; // new tokens to predict
        public int n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
        public int n_ctx = 512; // context size
        public int n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
        public int n_keep = 0; // number of tokens to keep from initial prompt

        // sampling parameters
        public Dictionary<llama_token, float> logit_bias; // logit bias for specific tokens
        public int top_k = 40; // <= 0 to use vocab size
        public float top_p = 0.95f; // 1.0 = disabled
        public float tfs_z = 1.00f; // 1.0 = disabled
        public float typical_p = 1.00f; // 1.0 = disabled
        public float temp = 0.80f; // 1.0 = disabled
        public float repeat_penalty = 1.10f; // 1.0 = disabled
        public int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
        public float frequency_penalty = 0.00f; // 0.0 = disabled
        public float presence_penalty = 0.00f; // 0.0 = disabled
        public int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
        public float mirostat_tau = 5.00f; // target entropy
        public float mirostat_eta = 0.10f; // learning rate

        public string model = "models/lamma-7B/ggml-model.bin"; // model path
        public string prompt = ""; // initial prompt (set to empty string for interactive mode)
        public string path_session = ""; // path to file for saving/loading model eval state
        public string input_prefix = ""; // string to prefix user inputs with
        public string input_suffix = ""; // string to suffix user inputs with
        public List<string> antiprompt; // string upon seeing which more user input is prompted

        public string lora_adapter = ""; // lora adapter path
        public string lora_base = ""; // base model path for the lora adapter

        public bool memory_f16 = true; // use f16 instead of f32 for memory kv
        public bool random_prompt = false; // randomize prompt if none provided
        public bool use_color = false; // use color to distinguish generations and inputs
        public bool interactive = false; // interactive mode
        public bool embedding = false; // get only sentence embedding
        public bool interactive_first = false; // wait for user input immediately
        public bool instruct = false; // instruction mode (used for Alpaca models)
        public bool penalize_nl = true; // consider newlines as a repeatable token
        public bool perplexity = false; // compute perplexity over the prompt
        public bool use_mmap = true; // use mmap for faster loads
        public bool use_mlock = false; // use mlock to keep model in memory
        public bool mem_test = false; // compute maximum memory usage
        public bool verbose_prompt = false; // print prompt tokens before generation

        public GptParams(int seed = 0, int n_threads = -1, int n_predict = -1,
            int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0,
            Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
            float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
            int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
            int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f,
            string model = "models/lamma-7B/ggml-model.bin", string prompt = "",
            string path_session = "", string input_prefix = "", string input_suffix = "",
            List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
            bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
            bool embedding = false, bool interactive_first = false, bool instruct = false, bool penalize_nl = true,
            bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
            bool verbose_prompt = false)
        {
            this.seed = seed;
            if (n_threads != -1)
            {
                this.n_threads = n_threads;
            }
            this.n_predict = n_predict;
            this.n_parts = n_parts;
            this.n_ctx = n_ctx;
            this.n_batch = n_batch;
            this.n_keep = n_keep;
            if (logit_bias == null)
            {
                logit_bias = new Dictionary<llama_token, float>();
            }
            this.logit_bias = logit_bias;
            this.top_k = top_k;
            this.top_p = top_p;
            this.tfs_z = tfs_z;
            this.typical_p = typical_p;
            this.temp = temp;
            this.repeat_penalty = repeat_penalty;
            this.repeat_last_n = repeat_last_n;
            this.frequency_penalty = frequency_penalty;
            this.presence_penalty = presence_penalty;
            this.mirostat = mirostat;
            this.mirostat_tau = mirostat_tau;
            this.mirostat_eta = mirostat_eta;
            this.model = model;
            this.prompt = prompt;
            this.path_session = path_session;
            this.input_prefix = input_prefix;
            this.input_suffix = input_suffix;
            if (antiprompt == null)
            {
                antiprompt = new List<string>();
            }
            this.antiprompt = antiprompt;
            this.lora_adapter = lora_adapter;
            this.lora_base = lora_base;
            this.memory_f16 = memory_f16;
            this.random_prompt = random_prompt;
            this.use_color = use_color;
            this.interactive = interactive;
            this.embedding = embedding;
            this.interactive_first = interactive_first;
            this.instruct = instruct;
            this.penalize_nl = penalize_nl;
            this.perplexity = perplexity;
            this.use_mmap = use_mmap;
            this.use_mlock = use_mlock;
            this.mem_test = mem_test;
            this.verbose_prompt = verbose_prompt;
        }
    }
}
@@ -608,7 +608,7 @@ namespace LLama
            yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
            {
                new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
            });
        }
    }
@@ -0,0 +1,547 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using LLama.Exceptions;
using System.Linq;
using System.Text.RegularExpressions;
using System.Runtime.InteropServices;
using System.Diagnostics;

namespace LLama
{
    using llama_token = Int32;

    public class GptModel
    {
        GptParams _params;
        SafeLLamaContextHandle _ctx;
        string _path_session;
        List<llama_token> _session_tokens;
        List<llama_token> _embed_inp;
        int _n_ctx;
        List<llama_token> _inp_pfx;
        List<llama_token> _inp_sfx;
        List<llama_token> _llama_token_newline;
        List<llama_token> _last_n_tokens;
        bool _is_interacting;
        bool _is_antiprompt;
        bool _input_echo;
        // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
        // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
        // initial prompt so it doesn't need to be an exact match.
        bool _need_to_save_session;
        int _n_past;
        int _n_remain;
        int _n_consumed;
        int _n_session_consumed;
        List<llama_token> _embed;
        public GptModel(string model_path = "models/lamma-7B/ggml-model.bin", int seed = 0, int n_threads = -1, int n_predict = -1,
            int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0,
            Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
            float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
            int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
            int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f, string prompt = "",
            string path_session = "", string input_prefix = "", string input_suffix = "",
            List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
            bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
            bool embedding = false, bool interactive_first = false, bool instruct = false, bool penalize_nl = true,
            bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
            bool verbose_prompt = false)
            // Forward everything to the GptParams-based constructor so this overload is not a no-op.
            : this(new GptParams(seed, n_threads, n_predict, n_parts, n_ctx, n_batch, n_keep, logit_bias,
                top_k, top_p, tfs_z, typical_p, temp, repeat_penalty, repeat_last_n, frequency_penalty,
                presence_penalty, mirostat, mirostat_tau, mirostat_eta, model_path, prompt, path_session,
                input_prefix, input_suffix, antiprompt, lora_adapter, lora_base, memory_f16, random_prompt,
                use_color, interactive, embedding, interactive_first, instruct, penalize_nl, perplexity,
                use_mmap, use_mlock, mem_test, verbose_prompt))
        {
        }
        public GptModel WithPrompt(string prompt)
        {
            _params.prompt = prompt;
            // Match the OG llama tokenizer behaviour: the prompt should start with a space.
            if (!_params.prompt.StartsWith(" "))
            {
                _params.prompt = _params.prompt.Insert(0, " ");
            }
            _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
            if (_embed_inp.Count > _n_ctx - 4)
            {
                throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
            }
            return this;
        }
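        // Example usage (hypothetical model path), mirroring the chat-style setup in Program.cs above:
        //   var model = new GptModel(new GptParams(model: "ggml-model-q4_0.bin", interactive: true,
        //       antiprompt: new List<string>() { "User:" })).WithPrompt("Transcript of a dialog...\r\nUser:");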
        public GptModel WithPromptFile(string promptFileName)
        {
            return WithPrompt(File.ReadAllText(promptFileName));
        }
        public unsafe GptModel(GptParams @params)
        {
            _params = @params;
            _ctx = Utils.llama_init_from_gpt_params(ref _params);
            // Add a space in front of the first character to match OG llama tokenizer behavior
            _params.prompt = _params.prompt.Insert(0, " ");
            _session_tokens = new List<llama_token>();
            _path_session = @params.path_session;
            if (!string.IsNullOrEmpty(_path_session))
            {
                Logger.Default.Info($"Attempting to load saved session from '{_path_session}'");
                if (!File.Exists(_path_session))
                {
                    Logger.Default.Warn("Session file does not exist, will create.");
                }
                else
                {
                    // Only attempt to restore the session when the file actually exists.
                    llama_token[] session_tokens = new llama_token[@params.n_ctx];
                    ulong n_token_count_out = 0;
                    if (!NativeApi.llama_load_session_file(_ctx, _path_session, session_tokens, (ulong)@params.n_ctx, &n_token_count_out))
                    {
                        throw new RuntimeError($"Failed to load session file {_path_session}");
                    }
                    _session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
                    Logger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
                }
            }
            _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
            _n_ctx = NativeApi.llama_n_ctx(_ctx);
            if (_embed_inp.Count > _n_ctx - 4)
            {
                throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
            }
            ulong n_matching_session_tokens = 0;
            if (_session_tokens.Count > 0)
            {
                foreach (var id in _session_tokens)
                {
                    if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
                    {
                        break;
                    }
                    n_matching_session_tokens++;
                }
                if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
                {
                    Logger.Default.Info("Session file has exact match for prompt!");
                }
                else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
                {
                    Logger.Default.Warn($"Session file has low similarity to prompt ({n_matching_session_tokens} " +
                        $"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
                }
                else
                {
                    Logger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
                        $"tokens of prompt.");
                }
            }
            // number of tokens to keep when resetting context
            if (_params.n_keep < 0 || _params.n_keep > (int)_embed_inp.Count || _params.instruct)
            {
                _params.n_keep = _embed_inp.Count;
            }
            // prefix & suffix for instruct mode
            _inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true);
            _inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false);
            // in instruct mode, we inject a prefix and a suffix to each input by the user
            if (_params.instruct)
            {
                _params.interactive_first = true;
                _params.antiprompt.Add("### Instruction:\n\n");
            }
            // enable interactive mode if reverse prompt or interactive start is specified
            if (_params.antiprompt.Count != 0 || _params.interactive_first)
            {
                _params.interactive = true;
            }
            // determine newline token
            _llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false);
            if (_params.verbose_prompt)
            {
                Logger.Default.Info("\n");
                Logger.Default.Info($"prompt: '{_params.prompt}'");
                Logger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
                for (int i = 0; i < _embed_inp.Count; i++)
                {
                    Logger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
                }
                if (_params.n_keep > 0)
                {
                    Logger.Default.Info("static prompt based on n_keep: '");
                    for (int i = 0; i < _params.n_keep; i++)
                    {
                        Logger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
                    }
                    Logger.Default.Info("\n");
                }
                Logger.Default.Info("\n");
            }
            if (_params.interactive)
            {
                Logger.Default.Info("interactive mode on.");
            }
            Logger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
                $"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
                $"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}," +
                $" top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}," +
                $" mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
            Logger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
                $"n_keep = {_params.n_keep}");
            Logger.Default.Info("\n");
            _last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();
            if (_params.interactive)
            {
                Logger.Default.Info("== Running in interactive mode. ==");
                _is_interacting = _params.interactive_first;
            }
            _is_antiprompt = false;
            _input_echo = true;
            _need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
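            // Per the HACK note above: only re-save the session if less than 75% of the
            // prompt tokens could be reused from the loaded session file.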
            _n_past = 0;
            _n_remain = _params.n_predict;
            _n_consumed = 0;
            _n_session_consumed = 0;
            _embed = new List<llama_token>();
        }

        private string ProcessTextBeforeInfer(string text)
        {
            if (!string.IsNullOrEmpty(_params.input_prefix))
            {
                text = _params.input_prefix + text;
            }
            if (!text.EndsWith("\n"))
            {
                text += "\n";
            }
            if (text.Length > 1)
            {
                // append input suffix if any
                if (!string.IsNullOrEmpty(_params.input_suffix))
                {
                    text += _params.input_suffix;
                    Console.Write(_params.input_suffix);
                }
                // instruct mode: insert instruction prefix
                if (_params.instruct && !_is_antiprompt)
                {
                    _n_consumed = _embed_inp.Count;
                    _embed_inp.AddRange(_inp_pfx);
                }
                var line_inp = Utils.llama_tokenize(_ctx, text, false);
                _embed_inp.AddRange(line_inp);
                // instruct mode: insert response suffix
                if (_params.instruct)
                {
                    _embed_inp.AddRange(_inp_sfx);
                }
                _n_remain -= line_inp.Count;
            }
            return text;
        }

        public IEnumerable<string> Call(string text)
        {
            _is_interacting = _is_antiprompt = false;
            ProcessTextBeforeInfer(text);
            while ((_n_remain != 0 || _params.interactive) && !_is_interacting)
            {
                if (_embed.Count > 0)
                {
                    // infinite text generation via context swapping
                    // if we run out of context:
                    // - take the n_keep first tokens from the original prompt (via n_past)
                    // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
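                    // For example, with n_ctx = 512 and n_keep = 48: once the context fills up,
                    // n_left = n_past - 48 and roughly n_left / 2 of the most recent tokens are
                    // re-fed below, so generation can continue indefinitely.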
                    if (_n_past + _embed.Count > _n_ctx)
                    {
                        int n_left = _n_past - _params.n_keep;
                        _n_past = _params.n_keep;
                        // insert n_left/2 tokens at the start of embed from last_n_tokens
                        _embed.InsertRange(0, _last_n_tokens.GetRange(_n_ctx - n_left / 2 - _embed.Count, n_left / 2));
                        // stop saving session if we run out of context
                        _path_session = "";
                        // Console.WriteLine("\n---\n");
                        // Console.Write("resetting: '");
                        // for (int i = 0; i < embed.Count; i++) {
                        //     Console.Write(llama_token_to_str(ctx, embed[i]));
                        // }
                        // Console.WriteLine("'\n");
                        // Console.WriteLine("\n---\n");
                    }
                    // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
                    // REVIEW
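                    // Tokens that already match the loaded session prefix are counted as evaluated
                    // (n_past advances) without calling llama_eval again; the first mismatch
                    // discards the remainder of the saved session.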
                    if (_n_session_consumed < _session_tokens.Count)
                    {
                        int i = 0;
                        for (; i < _embed.Count; i++)
                        {
                            if (!_embed[i].Equals(_session_tokens[_n_session_consumed]))
                            {
                                _session_tokens.RemoveRange(_n_session_consumed, _session_tokens.Count - _n_session_consumed);
                                break;
                            }
                            _n_past++;
                            _n_session_consumed++;
                            if (_n_session_consumed >= _session_tokens.Count)
                            {
                                i++;
                                break;
                            }
                        }
                        if (i > 0)
                        {
                            _embed.RemoveRange(0, i);
                        }
                    }
                    // evaluate tokens in batches
                    // embed is typically prepared beforehand to fit within a batch, but not always
                    for (int i = 0; i < _embed.Count; i += _params.n_batch)
                    {
                        int n_eval = _embed.Count - i;
                        if (n_eval > _params.n_batch)
                        {
                            n_eval = _params.n_batch;
                        }
                        var array = _embed.GetRange(i, n_eval).ToArray();
                        if (NativeApi.llama_eval(_ctx, array, n_eval, _n_past, _params.n_threads) != 0)
                        {
                            Logger.Default.Error("Failed to eval");
                            throw new RuntimeError("Failed to eval");
                        }
                        _n_past += n_eval;
                    }
                    if (_embed.Count > 0 && !string.IsNullOrEmpty(_path_session))
                    {
                        _session_tokens.AddRange(_embed);
                        _n_session_consumed = _session_tokens.Count;
                    }
                }
                _embed.Clear();
                if (_embed_inp.Count <= _n_consumed && !_is_interacting)
                {
                    var temp = _params.temp;
                    var top_k = _params.top_k <= 0 ? NativeApi.llama_n_vocab(_ctx) : _params.top_k;
                    var top_p = _params.top_p;
                    var tfs_z = _params.tfs_z;
                    var typical_p = _params.typical_p;
                    var repeat_last_n = _params.repeat_last_n < 0 ? _n_ctx : _params.repeat_last_n;
                    var repeat_penalty = _params.repeat_penalty;
                    var alpha_presence = _params.presence_penalty;
                    var alpha_frequency = _params.frequency_penalty;
                    var mirostat = _params.mirostat;
                    var mirostat_tau = _params.mirostat_tau;
                    var mirostat_eta = _params.mirostat_eta;
                    var penalize_nl = _params.penalize_nl;
                    // optionally save the session on first sample (for faster prompt loading next time)
                    if (!string.IsNullOrEmpty(_path_session) && _need_to_save_session)
                    {
                        _need_to_save_session = false;
                        NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count);
                    }
                    llama_token id = 0;
                    {
                        var n_vocab = NativeApi.llama_n_vocab(_ctx);
                        var logits = Utils.llama_get_logits(_ctx, n_vocab);
                        // Apply params.logit_bias map
                        foreach (KeyValuePair<int, float> it in _params.logit_bias)
                        {
                            logits[it.Key] += it.Value;
                        }
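                        // e.g. a strongly negative bias for a token id effectively bans it from
                        // being sampled, while a positive bias makes it more likely.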
                        var candidates = new List<LLamaTokenData>();
                        candidates.Capacity = n_vocab;
                        for (llama_token token_id = 0; token_id < n_vocab; token_id++)
                        {
                            candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f));
                        }
                        LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false);
                        // Apply penalties
                        float nl_logit = logits[NativeApi.llama_token_nl()];
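                        // The newline logit is saved here so it can be restored below when
                        // penalize_nl is false.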
                        var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
                        SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
                            _last_n_tokens.GetRange(_last_n_tokens.Count - last_n_repeat, last_n_repeat).ToArray(),
                            (ulong)last_n_repeat, repeat_penalty);
                        SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
                            _last_n_tokens.GetRange(_last_n_tokens.Count - last_n_repeat, last_n_repeat).ToArray(),
                            (ulong)last_n_repeat, alpha_frequency, alpha_presence);
                        if (!penalize_nl)
                        {
                            logits[NativeApi.llama_token_nl()] = nl_logit;
                        }
                        if (temp <= 0)
                        {
                            // Greedy sampling
                            id = SamplingApi.llama_sample_token_greedy(_ctx, candidates_p);
                        }
                        else
                        {
                            if (mirostat == 1)
                            {
                                float mirostat_mu = 2.0f * mirostat_tau;
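                                // mu starts at twice the target entropy (tau) and is adjusted by the
                                // sampler toward the target surprise as tokens are drawn.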
                                const int mirostat_m = 100;
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
                            }
                            else if (mirostat == 2)
                            {
                                float mirostat_mu = 2.0f * mirostat_tau;
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates_p, mirostat_tau, mirostat_eta, mirostat_mu);
                            }
                            else
                            {
                                // Temperature sampling
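                                // The filters run in sequence (top-k, tail-free, typical, top-p) and the
                                // surviving candidates are temperature-scaled before the final draw.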
                                SamplingApi.llama_sample_top_k(_ctx, candidates_p, top_k, 1);
                                SamplingApi.llama_sample_tail_free(_ctx, candidates_p, tfs_z, 1);
                                SamplingApi.llama_sample_typical(_ctx, candidates_p, typical_p, 1);
                                SamplingApi.llama_sample_top_p(_ctx, candidates_p, top_p, 1);
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token(_ctx, candidates_p);
                            }
                        }
                        _last_n_tokens.RemoveAt(0);
                        _last_n_tokens.Add(id);
                    }
                    // replace end of text token with newline token when in interactive mode
                    if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct)
                    {
                        id = _llama_token_newline[0];
                        if (_params.antiprompt.Count != 0)
                        {
                            // tokenize and inject first reverse prompt
                            var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false);
                            _embed_inp.AddRange(first_antiprompt);
                        }
                    }
                    // add it to the context
                    _embed.Add(id);
                    // echo this to console
                    _input_echo = true;
                    // decrement remaining sampling budget
                    _n_remain--;
                }
                else
                {
                    // feed any remaining prompt / user-input tokens to the model, one batch at a time
                    while (_embed_inp.Count > _n_consumed)
                    {
                        _embed.Add(_embed_inp[_n_consumed]);
                        _last_n_tokens.RemoveAt(0);
                        _last_n_tokens.Add(_embed_inp[_n_consumed]);
                        _n_consumed++;
                        if (_embed.Count >= _params.n_batch)
                        {
                            break;
                        }
                    }
                }
                if (_input_echo)
                {
                    foreach (var id in _embed)
                    {
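                        // llama_token_to_str returns a native UTF-8 string; PtrToStringUTF8 is only
                        // available on newer target frameworks, so older targets fall back to ANSI marshalling.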
#if NET6_0_OR_GREATER
                        yield return Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
#else
                        yield return Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, id));
#endif
                    }
                }
                if (_params.interactive && _embed_inp.Count <= _n_consumed)
                {
                    if (_params.antiprompt.Count > 0)
                    {
                        string last_output = "";
                        foreach (var id in _last_n_tokens)
                        {
#if NET6_0_OR_GREATER
                            last_output += Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
#else
                            last_output += Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, id));
#endif
                        }
                        _is_antiprompt = false;
                        foreach (var antiprompt in _params.antiprompt)
                        {
                            if (last_output.EndsWith(antiprompt))
                            {
                                _is_interacting = true;
                                _is_antiprompt = true;
                                break;
                            }
                        }
                    }
                    if (_n_past > 0 && _is_interacting)
                    {
                        _input_echo = false;
                        break;
                    }
                    if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos())
                    {
                        if (_params.instruct)
                        {
                            _is_interacting = true;
                        }
                        else
                        {
                            Logger.Default.Info(" [end of text]");
                        }
                    }
                    if (_params.interactive && _n_remain <= 0 && _params.n_predict != -1)
                    {
                        _n_remain = _params.n_predict;
                        _is_interacting = true;
                    }
                }
            }
        }
    }
}
@@ -6,7 +6,7 @@ using System.Text;
namespace LLama.Native
{
    using llama_token = Int32;
    internal partial class NativeApi
    internal unsafe partial class NativeApi
    {
        /// <summary>
        /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -93,7 +93,7 @@ namespace LLama.Native
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        [DllImport(libraryName)]
        public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float[] mu);
        public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float* mu);
        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -105,7 +105,7 @@ namespace LLama.Native
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        [DllImport(libraryName)]
        public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float[] mu);
        public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float* mu);
        /// <summary>
        /// Selects the token with the highest probability.
@@ -120,7 +120,7 @@ namespace LLama.Native
        /// <param name="n_token_count_out"></param>
        /// <returns></returns>
        [DllImport(libraryName)]
        public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong[] n_token_count_out);
        public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);
        /// <summary>
        /// Save session file
@@ -148,14 +148,19 @@ namespace LLama.Native
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, float[] mu)
        public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, in float mu)
        {
            var handle = candidates.data.Pin();
            var st = new LLamaTokenDataArrayNative();
            st.data = new IntPtr(handle.Pointer);
            st.size = candidates.size;
            st.sorted = candidates.sorted;
            return NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, mu);
            llama_token res;
            fixed (float* pmu = &mu)
            {
                res = NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, pmu);
            }
            return res;
        }
        /// <summary>
@@ -167,14 +172,19 @@ namespace LLama.Native
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, float[] mu)
        public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, in float mu)
        {
            var handle = candidates.data.Pin();
            var st = new LLamaTokenDataArrayNative();
            st.data = new IntPtr(handle.Pointer);
            st.size = candidates.size;
            st.sorted = candidates.sorted;
            return NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, mu);
            llama_token res;
            fixed (float* pmu = &mu)
            {
                res = NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, pmu);
            }
            return res;
        }
        /// <summary>
@@ -0,0 +1,62 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using System.Diagnostics;
using System.Linq;

namespace LLama
{
    using llama_token = Int32;

    internal static class Utils
    {
        public static SafeLLamaContextHandle llama_init_from_gpt_params(ref GptParams @params)
        {
            var lparams = NativeApi.llama_context_default_params();
            lparams.n_ctx = @params.n_ctx;
            lparams.n_parts = @params.n_parts;
            lparams.seed = @params.seed;
            lparams.f16_kv = @params.memory_f16;
            lparams.use_mmap = @params.use_mmap;
            lparams.use_mlock = @params.use_mlock;
            lparams.logits_all = @params.perplexity;
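            // Perplexity computation needs logits for every position, not just the last
            // token, which is why logits_all follows the perplexity flag.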
            lparams.embedding = @params.embedding;
            var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);
            if (ctx_ptr == IntPtr.Zero)
            {
                throw new RuntimeError($"Failed to load model {@params.model}.");
            }
            SafeLLamaContextHandle ctx = new(ctx_ptr);
            if (!string.IsNullOrEmpty(@params.lora_adapter))
            {
                int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
                    string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
                if (err != 0)
                {
                    throw new RuntimeError("Failed to apply lora adapter.");
                }
            }
            return ctx;
        }

        public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos)
        {
            llama_token[] res = new llama_token[text.Length + (add_bos ? 1 : 0)];
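            // Buffer sized on the assumption of at most one token per character, plus an
            // optional slot for the BOS token.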
            int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos);
            Debug.Assert(n >= 0);
            return res.Take(n).ToList();
        }

        public unsafe static Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
        {
            var logits = NativeApi.llama_get_logits(ctx);
            return new Span<float>(logits, length);
        }
    }
}