using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using LLama.Types;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// High-level Wrapper of a llama.cpp model for inference.
    /// </summary>
    public class LLamaModel
    {
        private string _model_path;
        LLamaContextParams _params;
        private int _n_threads;
        private int _n_batch;
        private int _last_n_tokens_size;
        private string? _lora_base;
        private string? _lora_path;
        private bool _verbose;

        // Tokens that have been evaluated so far (oldest first), bounded to the context size in Eval.
        private Queue<llama_token> _eval_tokens;
        // Logit rows captured by Eval (oldest first); SampleInternal reads the last row.
        private Queue<float[]> _eval_logits;
        private LLamaCache? _cache;
        private SafeLLamaContextHandle _ctx;

        // (sequence length, lead-byte bit pattern) pairs: 192 = 0xC0, 224 = 0xE0, 240 = 0xF0,
        // i.e. the lead-byte markers of 2-, 3- and 4-byte UTF-8 sequences.
        private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };

        /// <summary>
        /// Load a llama.cpp model from the path.
        /// </summary>
        /// <remarks>
        /// Note that the API is still unstable. The order of the parameters is likely to
        /// be changed in the future. It's recommended to specify the parameter name when
        /// building your app. We use the cpp style parameter names here because it introduces
        /// convenience for searching the docs.
        /// </remarks>
        /// <param name="model_path">Path to the model.</param>
        /// <param name="n_ctx">Maximum context size.</param>
        /// <param name="n_parts">Number of parts to split the model into. If -1, the number of parts is automatically determined.</param>
        /// <param name="seed">Random seed. 0 for random.</param>
        /// <param name="f16_kv">Use half-precision for key/value cache.</param>
        /// <param name="logits_all">Return logits for all tokens, not just the last token.</param>
        /// <param name="vocab_only">Only load the vocabulary, no weights.</param>
        /// <param name="use_mmap">Use mmap if possible. Forced to false when <paramref name="lora_path"/> is given.</param>
        /// <param name="use_mlock">Force the system to keep the model in RAM.</param>
        /// <param name="embedding">Embedding mode only.</param>
        /// <param name="n_threads">Number of threads to use. If -1, half of the processor count (at least 1) is used.</param>
        /// <param name="n_batch">Maximum number of prompt tokens to batch together when calling llama_eval.</param>
        /// <param name="last_n_tokens_size">Maximum number of tokens to keep in the last_n_tokens window used for the repetition penalty.</param>
        /// <param name="lora_base">Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
        /// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
        /// <param name="verbose">Print verbose output.</param>
        /// <exception cref="FileNotFoundException">If <paramref name="model_path"/> is neither an existing file nor directory.</exception>
        /// <exception cref="RuntimeError">If the native context could not be created or the LoRA adapter could not be applied.</exception>
        public LLamaModel(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
            bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
            bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
            int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
        {
            _verbose = verbose;
            _model_path = model_path;

            _params = NativeApi.llama_context_default_params();
            _params.n_ctx = n_ctx;
            _params.n_parts = n_parts;
            _params.seed = seed;
            _params.f16_kv = f16_kv;
            _params.logits_all = logits_all;
            _params.vocab_only = vocab_only;
            _params.use_mmap = lora_path is null ? use_mmap : false;
            _params.use_mlock = use_mlock;
            _params.embedding = embedding;

            _last_n_tokens_size = last_n_tokens_size;
            _n_batch = Math.Min(n_ctx, n_batch);

            // The constructor argument is only an initial capacity; the maximum length is enforced in Eval.
            _eval_tokens = new Queue<llama_token>(capacity: n_ctx);
            _eval_logits = new Queue<float[]>(logits_all ? n_ctx : 1);

            _cache = null;

            _n_threads = n_threads;
            if (_n_threads == -1)
            {
                _n_threads = Math.Max(Environment.ProcessorCount / 2, 1);
            }

            _lora_base = lora_base;
            _lora_path = lora_path;

            if (!File.Exists(model_path) && !Directory.Exists(model_path))
            {
                throw new FileNotFoundException($"Model path does not exist: {model_path}");
            }

            // The path is round-tripped through UTF-8 before being handed to the native loader.
            _ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(
                Encoding.UTF8.GetString(Encoding.UTF8.GetBytes(model_path)), _params));

            // A Debug.Assert disappears in release builds; fail loudly instead of crashing later in native code.
            if (_ctx.DangerousGetHandle() == IntPtr.Zero)
            {
                throw new RuntimeError($"Failed to load model from path: {model_path}");
            }

            if (_lora_path is not null)
            {
                if (NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
                {
                    throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
                }
            }

            if (_verbose)
            {
#if NET6_0_OR_GREATER
                Logger.Default.Info(Marshal.PtrToStringUTF8(NativeApi.llama_print_system_info()));
#endif
            }
        }

        /// <summary>
        /// Create a model object from another one. The native context handle is shared with
        /// <paramref name="other"/> (it is not duplicated), while the evaluated tokens and logits
        /// queues are copied. The cache of <paramref name="other"/> is not carried over.
        /// </summary>
        /// <param name="other">The model to copy the settings and evaluation state from.</param>
        public LLamaModel(LLamaModel other)
        {
            _ctx = other._ctx;
            _model_path = other._model_path;
            _params = other._params;
            _last_n_tokens_size = other._last_n_tokens_size;
            _n_threads = other._n_threads;
            _n_batch = other._n_batch;
            _verbose = other._verbose;
            _lora_base = other._lora_base;
            _lora_path = other._lora_path;
            _eval_logits = new Queue<float[]>(other._eval_logits);
            _eval_tokens = new Queue<llama_token>(other._eval_tokens);
        }

        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The string to tokenize.</param>
        /// <returns>A list of tokens.</returns>
        /// <exception cref="RuntimeError">If the tokenization failed.</exception>
        public List<llama_token> Tokenize(string text)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            var tokens = new llama_token[n_ctx];
            var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
            if (n_tokens < 0)
            {
                throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
            }
            return tokens.Take(n_tokens).ToList();
        }

        /// <summary>
        /// Detokenize a list of tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens to detokenize.</param>
        /// <returns>The detokenized string.</returns>
        public string DeTokenize(IEnumerable<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            // StringBuilder avoids the quadratic cost of repeated string concatenation.
            var output = new StringBuilder();
            foreach (var token in tokens)
            {
#if NET6_0_OR_GREATER
                output.Append(Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)));
#else
                output.Append(Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, token)));
#endif
            }
            return output.ToString();
        }

        /// <summary>
        /// Set the cache.
        /// </summary>
        /// <param name="cache">The cache to set, or null to disable caching.</param>
        public void SetCache(LLamaCache? cache)
        {
            _cache = cache;
        }

        /// <summary>
        /// Reset the model state by clearing the evaluated tokens and logits.
        /// </summary>
        public void Reset()
        {
            _eval_tokens.Clear();
            _eval_logits.Clear();
        }

        /// <summary>
        /// Evaluate a list of tokens in batches of at most <c>n_batch</c> tokens, recording the
        /// evaluated tokens and the resulting logits.
        /// </summary>
        /// <param name="tokens">The list of tokens to evaluate.</param>
        /// <exception cref="RuntimeError">If llama_eval returns a non-zero code.</exception>
        public unsafe void Eval(List<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            // Without logits_all only the latest row is ever needed (SampleInternal reads the last one).
            int maxLogitRows = _params.logits_all ? Math.Max(1, n_ctx) : 1;

            for (int i = 0; i < tokens.Count; i += _n_batch)
            {
                // Materialize the batch once instead of re-enumerating a lazy Take/Skip query.
                llama_token[] batch = tokens.GetRange(i, Math.Min(_n_batch, tokens.Count - i)).ToArray();
                llama_token n_tokens = batch.Length;
                llama_token n_past = Math.Min(n_ctx - n_tokens, _eval_tokens.Count);

                llama_token return_code = NativeApi.llama_eval(
                    ctx: _ctx,
                    tokens: batch,
                    n_tokens: n_tokens,
                    n_past: n_past,
                    n_threads: _n_threads
                );
                if (return_code != 0)
                {
                    throw new RuntimeError($"llama_eval returned {return_code}");
                }

                foreach (var b in batch)
                {
                    _eval_tokens.Enqueue(b);
                }
                // Keep at most n_ctx tokens, discarding the oldest ones.
                while (_eval_tokens.Count > n_ctx && _eval_tokens.Count > 0)
                {
                    _eval_tokens.Dequeue();
                }

                int rows = _params.logits_all ? n_tokens : 1;
                llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
                var cols = n_vocab;
                var logits_view = NativeApi.llama_get_logits(_ctx);
                for (int j = 0; j < rows; j++)
                {
                    float[] logit = new float[cols];
                    for (int k = 0; k < cols; k++)
                    {
                        logit[k] = logits_view[j * cols + k];
                    }
                    _eval_logits.Enqueue(logit);
                    // Bound the queue so it cannot grow by one vocabulary-sized row per generated token.
                    while (_eval_logits.Count > maxLogitRows)
                    {
                        _eval_logits.Dequeue();
                    }
                }
            }
        }

        /// <summary>
        /// Build the candidate array from the most recent logits, apply the repetition penalty and
        /// sample a token (greedy when <paramref name="temp"/> is 0, otherwise top-k / tail-free /
        /// typical / top-p / temperature sampling).
        /// </summary>
        /// <remarks>
        /// <paramref name="frequency_penalty"/> and <paramref name="presence_penalty"/> are currently
        /// not applied; the corresponding native call is disabled.
        /// </remarks>
        private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
            float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            Debug.Assert(_eval_logits.Count > 0);

            llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
            var logits = _eval_logits.Last();

            LLamaTokenData[] data = new LLamaTokenData[n_vocab];
            for (int i = 0; i < n_vocab; i++)
            {
                data[i] = new LLamaTokenData(i, logits[i], .0f);
            }

            ulong size = (ulong)n_vocab;
            bool sorted = false;
            LLamaTokenDataArray candidates = new(data, size, sorted);

            SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data,
                (ulong)last_n_tokens_size, repeat_penalty);
            //SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
            //    frequency_penalty, presence_penalty);

            if (temp == .0f)
            {
                return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
            }

            SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
            SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
            SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
            SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
            SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
            return SamplingApi.llama_sample_token(_ctx, candidates);
        }

        /// <summary>
        /// Sample a token from the model.
        /// </summary>
        /// <param name="top_k">The top-k sampling parameter.</param>
        /// <param name="top_p">The top-p sampling parameter.</param>
        /// <param name="temp">The temperature parameter.</param>
        /// <param name="repeat_penalty">The repeat penalty parameter.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter (currently not applied).</param>
        /// <param name="presence_penalty">The presence penalty parameter (currently not applied).</param>
        /// <returns>The sampled token.</returns>
        public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty,
            float frequency_penalty = .0f, float presence_penalty = .0f)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            // Fixed-size window: the most recent evaluated tokens right-aligned, zero-padded at the front.
            llama_token[] tokens_data = new llama_token[_last_n_tokens_size];
            int available = Math.Min(_eval_tokens.Count, _last_n_tokens_size);
            int destination = _last_n_tokens_size - available;
            foreach (var recent in _eval_tokens.Skip(_eval_tokens.Count - available))
            {
                tokens_data[destination++] = recent;
            }

            return SampleInternal(tokens_data, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty,
                frequency_penalty, presence_penalty);
        }

        /// <summary>
        /// Drop everything after the first <paramref name="keep"/> evaluated tokens, and remove the
        /// same number of rows (as far as available) from the end of the logits queue.
        /// </summary>
        private void TruncateEvalState(int keep)
        {
            int removed = _eval_tokens.Count - keep;
            if (removed <= 0)
            {
                return;
            }

            // Queue<T> can only remove from the front, so rebuild the queues from their leading items.
            _eval_tokens = new Queue<llama_token>(_eval_tokens.Take(keep).ToList());
            int logitsToKeep = Math.Max(0, _eval_logits.Count - removed);
            _eval_logits = new Queue<float[]>(_eval_logits.Take(logitsToKeep).ToList());
        }

        /// <summary>
        /// Reuse the already-evaluated state when it shares a prefix with <paramref name="tokens"/>;
        /// otherwise reset the state when <paramref name="reset"/> is set.
        /// </summary>
        /// <returns>The tokens that still have to be evaluated.</returns>
        private List<llama_token> ReuseEvaluatedPrefix(List<llama_token> tokens, bool reset)
        {
            if (reset && _eval_tokens.Count > 0)
            {
                // The last token is never matched so that at least one token remains to be evaluated.
                int longest_prefix = LongestTokenPrefix(_eval_tokens, tokens.Take(tokens.Count - 1));
                if (longest_prefix > 0)
                {
                    if (_verbose)
                    {
                        Logger.Default.Info("Llama.generate: prefix-match hit");
                    }
                    reset = false;
                    tokens = tokens.Skip(longest_prefix).ToList();
                    TruncateEvalState(longest_prefix);
                }
            }

            if (reset)
            {
                Reset();
            }
            return tokens;
        }

        /// <summary>
        /// Create a generator of tokens from a prompt. The sequence is endless; the caller decides
        /// when to stop enumerating.
        /// </summary>
        /// <example>
        /// <code>
        /// var llama = new LLamaModel("models/ggml-7b.bin");
        /// var tokens = llama.Tokenize("Hello, world!");
        /// foreach (var token in llama.Generate(tokens, top_k: 40, top_p: 0.95f, temp: 1.0f, repeat_penalty: 1.1f))
        /// {
        ///     Console.WriteLine(llama.DeTokenize(new[] { token }));
        /// }
        /// </code>
        /// </example>
        /// <param name="tokens">The prompt tokens.</param>
        /// <param name="top_k">The top-k sampling parameter.</param>
        /// <param name="top_p">The top-p sampling parameter.</param>
        /// <param name="temp">The temperature parameter.</param>
        /// <param name="repeat_penalty">The repeat penalty parameter.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="reset">Whether to reset the model state (a matching evaluated prefix is reused instead of reset).</param>
        /// <returns>The generated tokens, one at a time.</returns>
        public IEnumerable<llama_token> Generate(IEnumerable<llama_token> tokens, int top_k, float top_p, float temp,
            float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            // Materialize once: the input may be a lazy sequence.
            List<llama_token> pending = ReuseEvaluatedPrefix(tokens.ToList(), reset);

            while (true)
            {
                Eval(pending);
                var token = Sample(top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
                yield return token;
                // Only the freshly sampled token needs to be evaluated in the next step.
                pending = new List<llama_token> { token };
            }
        }

        /// <summary>
        /// Embed a string.
        /// </summary>
        /// <param name="input">The string to embed.</param>
        /// <returns>An embedding object.</returns>
        /// <exception cref="RuntimeError">If the model was not created with embedding enabled.</exception>
        public unsafe Embedding CreateEmbedding(string input)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            if (!_params.embedding)
            {
                throw new RuntimeError("Llama model must be created with embedding=True to call this method");
            }

            if (_verbose)
            {
                NativeApi.llama_reset_timings(_ctx);
            }

            var tokens = Tokenize(input);
            Reset();
            Eval(tokens);
            int n_tokens = tokens.Count;

            var embeddingPtr = NativeApi.llama_get_embeddings(_ctx);
            int cnt = NativeApi.llama_n_embd(_ctx);
            float[] embedding = new float[cnt];
            for (int i = 0; i < cnt; i++)
            {
                embedding[i] = embeddingPtr[i];
            }

            if (_verbose)
            {
                NativeApi.llama_print_timings(_ctx);
            }

            return new Embedding("list", _model_path, new[] { new EmbeddingData(0, "embedding", embedding) },
                new EmbeddingUsage(n_tokens, n_tokens));
        }

        /// <summary>
        /// Embed a string and return only the embedding vector.
        /// </summary>
        /// <param name="input">The string to embed.</param>
        /// <returns>The embedding vector of the input.</returns>
        public float[] Embed(string input)
        {
            return CreateEmbedding(input).Data[0].Embedding;
        }

        /// <summary>
        /// Core completion loop shared by the streaming and non-streaming entry points.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <param name="stream">Whether to yield a <see cref="CompletionChunk"/> for each generation step.</param>
        /// <returns>
        /// IEnumerable of Completion and CompletionChunk: zero or more chunks (only when
        /// <paramref name="stream"/> is set) followed by exactly one final <see cref="Completion"/>.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// If the requested tokens exceed the context window, or logprobs are requested from a model
        /// created without logits_all.
        /// </exception>
        private IEnumerable<object> CreateCompletionInternal(string prompt, string? suffix = null, int max_tokens = 16,
            float temperature = 0.8f, float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null,
            float frequency_penalty = .0f, float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40,
            bool stream = false)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            string completionId = $"cmpl-{Guid.NewGuid()}";
            // Creation time as Unix seconds (DateTime.Now.Millisecond is only the 0-999 millisecond component).
            int created = (int)DateTimeOffset.UtcNow.ToUnixTimeSeconds();
            List<llama_token> completionTokens = new List<llama_token>();
            var promptTokens = Tokenize($" {prompt}");
            string text = "";
            int returnedCharacters = 0;
            stop ??= Array.Empty<string>();

            if (_verbose)
            {
                NativeApi.llama_reset_timings(_ctx);
            }

            if (promptTokens.Count + max_tokens > NativeApi.llama_n_ctx(_ctx))
            {
                throw new ArgumentException($"Requested tokens exceed context window of {NativeApi.llama_n_ctx(_ctx)}");
            }

            if (logprobs != -1 && !_params.logits_all)
            {
                throw new ArgumentException("logprobs is not supported for models created with logits_all=False");
            }

            if (_cache is not null)
            {
                try
                {
                    // TODO(Rinne): revise it since it will compare reference instead of elements.
                    var cacheItem = _cache[promptTokens.ToArray()];
                    // Only restore the cached state when it covers more of the prompt than the live state.
                    var cachePrefixLen = LongestTokenPrefix(cacheItem.EvalTokens, promptTokens);
                    var evalPrefixLen = LongestTokenPrefix(_eval_tokens.AsEnumerable(), promptTokens);
                    if (cachePrefixLen > evalPrefixLen)
                    {
                        LoadState(cacheItem);
                        if (_verbose)
                        {
                            Logger.Default.Info("Llama._create_completion: cache hit");
                        }
                    }
                }
                catch (KeyNotFoundException)
                {
                    if (_verbose)
                    {
                        Logger.Default.Warn("Llama._create_completion: cache miss");
                    }
                }
            }

            string finishReason = "length";
            int multibyteFix = 0;

            // A private copy is passed because the list is cleared and refilled on every step below.
            List<llama_token> tokens = ReuseEvaluatedPrefix(new List<llama_token>(promptTokens), true);

            while (true)
            {
                Eval(tokens);
                var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
                tokens.Clear();
                tokens.Add(token);

                if (token == NativeApi.llama_token_eos())
                {
                    text = DeTokenize(completionTokens);
                    finishReason = "stop";
                    break;
                }

                completionTokens.Add(token);
                string allText = DeTokenize(completionTokens);

                // Look at the last (up to) three characters for a lead-byte pattern whose sequence
                // would extend past the end of the text.
                // NOTE(review): the patterns are UTF-8 byte markers but are tested against UTF-16
                // chars here - confirm this is the intended check.
                int cut = Math.Min(3, allText.Length);
                for (int i = allText.Length - cut; i < allText.Length; i++)
                {
                    var c = (int)allText[i];
                    int k = allText.Length - i; // distance from the end of the text: cut .. 1
                    foreach (var (num, pattern) in _numAndPatterns)
                    {
                        if (num > k && (pattern & c) == pattern)
                        {
                            multibyteFix = num - k;
                        }
                    }
                }

                // Stop incomplete sequences from passing.
                if (multibyteFix > 0)
                {
                    multibyteFix--;
                    continue;
                }

                string? firstStop = stop.FirstOrDefault(s => allText.Contains(s));
                if (firstStop is not null)
                {
                    // Ordinal search matches the ordinal semantics of string.Contains above.
                    text = allText.Substring(0, allText.IndexOf(firstStop, StringComparison.Ordinal));
                    finishReason = "stop";
                    break;
                }

                if (stream)
                {
                    var start = returnedCharacters;

                    // Hold back the longest tail of the text that could still grow into a stop string.
                    int longest = 0;
                    foreach (var s in stop)
                    {
                        for (int i = s.Length; i > 0; i--)
                        {
                            if (allText.EndsWith(s.Substring(0, i), StringComparison.Ordinal))
                            {
                                if (i > longest)
                                {
                                    longest = i;
                                }
                                break;
                            }
                        }
                    }

                    text = allText.Substring(0, allText.Length - longest);
                    // Only the part that has not been returned yet; empty when the text shrank.
                    string delta = start < text.Length ? text.Substring(start) : "";
                    returnedCharacters += delta.Length;
                    yield return new CompletionChunk(completionId, "text_completion", created, _model_path,
                        new CompletionChoice[] { new CompletionChoice(delta, 0, null, finishReason) });
                }

                if (completionTokens.Count >= max_tokens)
                {
                    text = allText;
                    finishReason = "length";
                    break;
                }
            }

            if (_cache is not null)
            {
                if (_verbose)
                {
                    Logger.Default.Info("Llama._create_completion: cache save");
                }
                _cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
            }

            if (stream)
            {
                string remaining = returnedCharacters < text.Length ? text.Substring(returnedCharacters) : "";
                yield return new CompletionChunk(completionId, "text_completion", created, _model_path,
                    new CompletionChoice[] { new CompletionChoice(remaining, 0, null, finishReason) });
            }

            string textStr = text;
            if (echo)
            {
                textStr = prompt + textStr;
            }
            if (suffix is not null)
            {
                textStr = textStr + suffix;
            }

            CompletionLogprobs? logProbs = null;
            if (logprobs != -1)
            {
                int textOffset = 0;
                List<int> textOffsets = new();
                List<float> tokenLogprobs = new();
                List<string> tokenStrs = new();
                List<Dictionary<string, float>> topLogprobs = new();

                var allTokens = promptTokens.Concat(completionTokens).ToArray();
                // Materialize every row once; each row is indexed by token id.
                List<float[]> allLogProbs = _eval_logits.Select(row => LogitsToLogprobs(row).ToArray()).ToList();

                int pairCount = Math.Min(allTokens.Length, allLogProbs.Count);
                for (int idx = 0; idx < pairCount; idx++)
                {
                    llama_token token = allTokens[idx];
                    string tokenStr = DeTokenize(new[] { token });
                    float[] logProbsToken = allLogProbs[idx];

                    textOffsets.Add(textOffset);
                    textOffset += tokenStr.Length;
                    tokenStrs.Add(tokenStr);

                    // Log-probability of the token itself (looked up by token id, not by sorted rank).
                    float tokenLogprob = logProbsToken[token];
                    tokenLogprobs.Add(tokenLogprob);

                    var topLogprob = new Dictionary<string, float>();
                    var ranked = logProbsToken.Select((value, id) => (value, id))
                        .OrderByDescending(p => p.value)
                        .Take(logprobs);
                    foreach (var (value, id) in ranked)
                    {
                        // Different token ids may detokenize to the same string; keep the most likely one.
                        string key = DeTokenize(new[] { id });
                        if (!topLogprob.ContainsKey(key))
                        {
                            topLogprob[key] = value;
                        }
                    }
                    topLogprob[tokenStr] = tokenLogprob;
                    topLogprobs.Add(topLogprob);
                }

                logProbs = new(textOffsets.ToArray(), tokenLogprobs.ToArray(), tokenStrs.ToArray(), topLogprobs.ToArray());
            }

            if (_verbose)
            {
                NativeApi.llama_print_timings(_ctx);
            }

            yield return new Completion(completionId, "text_completion", created, _model_path,
                new CompletionChoice[] { new CompletionChoice(textStr, 0, logProbs, finishReason) },
                new CompletionUsage(promptTokens.Count, completionTokens.Count, promptTokens.Count + completionTokens.Count));
        }

        /// <summary>
        /// Generate text from a prompt and yield return the result.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns>The completion chunks, in generation order.</returns>
        public IEnumerable<CompletionChunk> CreateCompletionStream(string prompt, string? suffix = null,
            int max_tokens = 128, float temperature = 0.8f, float top_p = 0.95f, int logprobs = -1, bool echo = false,
            string[]? stop = null, float frequency_penalty = .0f, float presence_penalty = .0f,
            float repeat_penalty = 1.1f, int top_k = 40)
        {
            var results = CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo,
                stop, frequency_penalty, presence_penalty, repeat_penalty, top_k, true);

            // The internal sequence ends with a final Completion object, which is not a chunk.
            foreach (var chunk in results.OfType<CompletionChunk>())
            {
                yield return chunk;
            }
        }

        /// <summary>
        /// Generate text from a prompt.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns>The finished completion.</returns>
        public Completion CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128,
            float temperature = 0.8f, float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null,
            float frequency_penalty = .0f, float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            // Without streaming the internal sequence consists of the final Completion only.
            return CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k, false).OfType<Completion>().First();
        }

        /// <summary>
        /// Generate text from a prompt. Shorthand for <see cref="CreateCompletion"/>.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns>The finished completion.</returns>
        public Completion Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null,
            float frequency_penalty = .0f, float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }

        /// <summary>
        /// Generate text from a prompt and yield return the result. Shorthand for
        /// <see cref="CreateCompletionStream"/>.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns>The completion chunks, in generation order.</returns>
        public IEnumerable<CompletionChunk> StreamCall(string prompt, string? suffix = null, int max_tokens = 128,
            float temperature = 0.8f, float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null,
            float frequency_penalty = .0f, float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            return CreateCompletionStream(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }

        /// <summary>
        /// Wrap a text completion into a chat completion with a single assistant message.
        /// </summary>
        private ChatCompletion ConvertTextCompletionToChat(Completion completion)
        {
            return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
                new[]
                {
                    new ChatCompletionChoice(0,
                        new ChatCompletionMessage("assistant", completion.Choices[0].Text, null),
                        completion.Choices[0].FinishReason)
                },
                completion.Usage);
        }

        /// <summary>
        /// Convert text completion chunks into chat completion chunks. A chunk that only carries the
        /// "assistant" role is emitted before the first content chunk.
        /// </summary>
        private IEnumerable<ChatCompletionChunk> ConvertTextCompletionChunksToChat(IEnumerable<CompletionChunk> chunks)
        {
            bool isFirst = true;
            foreach (var chunk in chunks)
            {
                if (isFirst)
                {
                    yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk",
                        chunk.Created,
                        new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta("assistant", null), null) });
                    isFirst = false;
                }

                yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk",
                    chunk.Created,
                    new[]
                    {
                        new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta(null, chunk.Choices[0].Text),
                            chunk.Choices[0].FinishReason)
                    });
            }
        }

        /// <summary>
        /// Build the "### Human:" / "### Assistant:" prompt for a chat history together with the stop
        /// strings (the two role markers followed by the caller-supplied ones).
        /// </summary>
        private static (string Prompt, string[] Stop) BuildChatPrompt(IEnumerable<ChatCompletionMessage> messages,
            string[]? stop)
        {
            string GetRole(ChatCompletionMessage message)
            {
                return message.Role == "user" ? "Human" : "Assistant";
            }

            string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
            string prompt = chatHistory + "### Assistant:";
            string[] promptStop = new[] { "### Assistant:", "### Human:" }
                .Concat(stop ?? Array.Empty<string>())
                .ToArray();
            return (prompt, promptStop);
        }

        /// <summary>
        /// Generate a chat completion from a list of messages.
        /// </summary>
        /// <param name="messages">A list of messages to generate a response for.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <returns>The chat completion.</returns>
        public ChatCompletion CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f,
            float top_p = .95f, int top_k = 40, string[]? stop = null, int max_tokens = 256,
            float presence_penalty = .0f, float frequency_penalty = .0f, float repeat_penalty = 1.1f)
        {
            var (prompt, promptStop) = BuildChatPrompt(messages, stop);
            var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k,
                max_tokens: max_tokens, repeat_penalty: repeat_penalty, presence_penalty: presence_penalty,
                frequency_penalty: frequency_penalty);
            return ConvertTextCompletionToChat(completion);
        }

        /// <summary>
        /// Generate a chat completion from a list of messages and yield return the result.
        /// </summary>
        /// <param name="messages">A list of messages to generate a response for.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="presence_penalty">The presence penalty parameter.</param>
        /// <param name="frequency_penalty">The frequency penalty parameter.</param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <returns>The chat completion chunks, in generation order.</returns>
        public IEnumerable<ChatCompletionChunk> CreateChatCompletionStream(IEnumerable<ChatCompletionMessage> messages,
            float temperature = .2f, float top_p = .95f, int top_k = 40, string[]? stop = null, int max_tokens = 256,
            float presence_penalty = .0f, float frequency_penalty = .0f, float repeat_penalty = 1.1f)
        {
            var (prompt, promptStop) = BuildChatPrompt(messages, stop);
            var completion = StreamCall(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k,
                max_tokens: max_tokens, repeat_penalty: repeat_penalty, presence_penalty: presence_penalty,
                frequency_penalty: frequency_penalty);
            return ConvertTextCompletionChunksToChat(completion);
        }

        /// <summary>
        /// Capture the current state: copies of the evaluated tokens and logits plus the native
        /// state bytes reported by llama_copy_state_data.
        /// </summary>
        /// <returns>The saved state.</returns>
        /// <exception cref="RuntimeError">If more bytes were copied than the reported state size.</exception>
        public LLamaState SaveState()
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            ulong stateSize = NativeApi.llama_get_state_size(_ctx);
            byte[] llamaState = new byte[stateSize];
            ulong nBytes = NativeApi.llama_copy_state_data(_ctx, llamaState);
            if (nBytes > stateSize)
            {
                throw new RuntimeError("Failed to copy llama state data");
            }

            // Keep only the bytes that were actually written.
            byte[] llamaStateCompact = new byte[nBytes];
            Array.Copy(llamaState, llamaStateCompact, (long)nBytes);

            if (_verbose)
            {
                Logger.Default.Info($"Llama.save_state: saving {nBytes} bytes of llama state");
            }

            return new LLamaState(new Queue<llama_token>(_eval_tokens), new Queue<float[]>(_eval_logits),
                llamaStateCompact, (int)nBytes);
        }

        /// <summary>
        /// Restore a state previously produced by <see cref="SaveState"/>.
        /// </summary>
        /// <param name="state">The state to load.</param>
        /// <exception cref="RuntimeError">If the native call does not consume exactly <c>state.Size</c> bytes.</exception>
        public void LoadState(LLamaState state)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            _eval_tokens = new Queue<llama_token>(state.EvalTokens);
            _eval_logits = new Queue<float[]>(state.EvalLogits);
            if (NativeApi.llama_set_state_data(_ctx, state.State) != (ulong)state.Size)
            {
                throw new RuntimeError($"Failed to set llama state data");
            }
        }

        /// <summary>
        /// Convert logits to log-probabilities (log-softmax).
        /// </summary>
        /// <param name="logits">The raw logits.</param>
        /// <returns>One log-probability per input logit, in the same order.</returns>
        private static IEnumerable<float> LogitsToLogprobs(IEnumerable<float> logits)
        {
            float[] values = logits.ToArray();
            if (values.Length == 0)
            {
                return values;
            }

            // Subtracting the maximum keeps Math.Exp from overflowing; the result is mathematically unchanged.
            float max = values.Max();
            double sumExps = 0.0;
            foreach (float value in values)
            {
                sumExps += Math.Exp(value - max);
            }
            double logSumExps = Math.Log(sumExps);

            float[] logprobs = new float[values.Length];
            for (int i = 0; i < values.Length; i++)
            {
                logprobs[i] = (float)(values[i] - max - logSumExps);
            }
            return logprobs;
        }

        /// <summary>
        /// Length of the longest common prefix of two token sequences.
        /// </summary>
        /// <param name="a">The first token sequence.</param>
        /// <param name="b">The second token sequence.</param>
        /// <returns>The number of leading positions at which both sequences hold the same token.</returns>
        internal static int LongestTokenPrefix(IEnumerable<llama_token> a, IEnumerable<llama_token> b)
        {
            int longestPrefix = 0;
            foreach (var (x, y) in a.Zip(b, (x, y) => (x, y)))
            {
                if (x != y)
                {
                    break;
                }
                longestPrefix++;
            }
            return longestPrefix;
        }
    }
}