- using LLama.Exceptions;
- using LLama.Native;
- using System;
- using System.Collections.Generic;
- using System.Configuration;
- using System.Diagnostics;
- using System.IO;
- using System.Linq;
- using System.Runtime.CompilerServices;
- using System.Text;
- using LLama.Types;
- using System.Runtime.InteropServices;
- using System.Text.RegularExpressions;
-
- namespace LLama
- {
- using llama_token = Int32;
- /// <summary>
- /// High-level wrapper of a llama.cpp model for inference.
- /// </summary>
- public class LLamaModel
- {
- private string _model_path;
- LLamaContextParams _params;
- private int _n_threads;
- private int _n_batch;
- private int _last_n_tokens_size;
- private string? _lora_base;
- private string? _lora_path;
- private bool _verbose;
-
- private Queue<llama_token> _eval_tokens;
- private Queue<float[]> _eval_logits;
- private LLamaCache? _cache;
- private SafeLLamaContextHandle _ctx;
-
- private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };
-
- /// <summary>
- /// Load a llama.cpp model from the path.
- /// </summary>
- /// <remarks>Note that the API is still unstable and the order of the parameters is likely to
- /// change in the future. It's recommended to specify parameters by name when
- /// building your app. We use the cpp-style parameter names here because it makes it
- /// easier to search the llama.cpp docs.</remarks>
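- /// <example>
- /// For example (with a placeholder model path), constructing with named parameters:
- /// <code>
- /// var model = new LLamaModel("path/to/ggml-model-q4_0.bin", n_ctx: 1024, seed: 42, verbose: false);
- /// </code>
- /// </example>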
- /// <param name="model_path">Path to the model.</param>
- /// <param name="n_ctx">Maximum context size.</param>
- /// <param name="n_parts">Number of parts to split the model into. If -1, the number of parts is automatically determined.</param>
- /// <param name="seed">Random seed. 0 for random.</param>
- /// <param name="f16_kv">Use half-precision for key/value cache.</param>
- /// <param name="logits_all">Return logits for all tokens, not just the last token.</param>
- /// <param name="vocab_only">Only load the vocabulary no weights.</param>
- /// <param name="use_mmap">Use mmap if possible.</param>
- /// <param name="use_mlock">Force the system to keep the model in RAM.</param>
- /// <param name="embedding">Embedding mode only.</param>
- /// <param name="n_threads">Number of threads to use. If is not specified, the number of threads is automatically determined.</param>
- /// <param name="n_batch">Maximum number of prompt tokens to batch together when calling llama_eval.</param>
- /// <param name="last_n_tokens_size">Maximum number of tokens to keep in the last_n_tokens deque.</param>
- /// <param name="lora_base">Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
- /// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
- /// <param name="verbose">Print verbose output to stderr.</param>
- public LLamaModel(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
- bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
- bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
- int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
- {
- _verbose = verbose;
- _model_path = model_path;
-
- _params = NativeApi.llama_context_default_params();
- _params.n_ctx = n_ctx;
- _params.n_parts = n_parts;
- _params.seed = seed;
- _params.f16_kv = f16_kv;
- _params.logits_all = logits_all;
- _params.vocab_only = vocab_only;
- _params.use_mmap = lora_path is null ? use_mmap : false;
- _params.use_mlock = use_mlock;
- _params.embedding = embedding;
-
- _last_n_tokens_size = last_n_tokens_size;
- _n_batch = Math.Min(n_ctx, n_batch);
-
- _eval_tokens = new Queue<int>(capacity: n_ctx);
- _eval_logits = new Queue<float[]>(logits_all ? n_ctx : 1);
-
- _cache = null;
-
- _n_threads = n_threads;
- if(_n_threads == -1)
- {
- _n_threads = Math.Max(Environment.ProcessorCount / 2, 1);
- }
-
- _lora_base = lora_base;
- _lora_path = lora_path;
-
- if(!File.Exists(model_path) && !Directory.Exists(model_path))
- {
- throw new FileNotFoundException($"Model path does not exist: {model_path}");
- }
-
- // Create the native context and wrap it in a SafeHandle so that it is released deterministically.
- _ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(model_path, _params));
-
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-
- if(_lora_path is not null)
- {
- if(NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
- {
- throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
- }
- }
-
- if (_verbose)
- {
- #if NET6_0_OR_GREATER
- Logger.Default.Info(Marshal.PtrToStringUTF8(NativeApi.llama_print_system_info()));
- #endif
- }
- }
-
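- /// <summary>
- /// Create a copy of another <see cref="LLamaModel"/>. The native context handle and settings are shared,
- /// while the evaluated token and logit queues are copied.
- /// </summary>
- /// <param name="other">The model to copy from.</param>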
- public LLamaModel(LLamaModel other)
- {
- _ctx = other._ctx;
- _model_path = other._model_path;
- _params = other._params;
- _last_n_tokens_size = other._last_n_tokens_size;
- _n_threads = other._n_threads;
- _n_batch = other._n_batch;
- _verbose = other._verbose;
- _lora_base = other._lora_base;
- _lora_path = other._lora_path;
- _eval_logits = new Queue<float[]>(other._eval_logits);
- _eval_tokens = new Queue<llama_token>(other._eval_tokens);
- }
-
- /// <summary>
- /// Tokenize a string.
- /// </summary>
- /// <param name="text">The utf-8 encoded string to tokenize.</param>
- /// <returns>A list of tokens.</returns>
- /// <exception cref="RuntimeError">If the tokenization failed.</exception>
- public List<llama_token> Tokenize(string text)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- var n_ctx = NativeApi.llama_n_ctx(_ctx);
- var tokens = new llama_token[n_ctx];
- var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
- if(n_tokens < 0)
- {
- throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
- }
- return tokens.Take(n_tokens).ToList();
- }
-
- /// <summary>
- /// Detokenize a list of tokens.
- /// </summary>
- /// <param name="tokens">The list of tokens to detokenize.</param>
- /// <returns>The detokenized string.</returns>
- public string DeTokenize(IEnumerable<llama_token> tokens)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- string output = "";
- foreach(var token in tokens)
- {
- #if NET6_0_OR_GREATER
- output += Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
- #else
- output += Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, token));
- #endif
- }
- return output;
- }
-
- /// <summary>
- /// Set the cache.
- /// </summary>
- /// <param name="cache">The cache to set.</param>
- public void SetCache(LLamaCache? cache)
- {
- _cache = cache;
- }
-
- /// <summary>
- /// Reset the model state.
- /// </summary>
- public void Reset()
- {
- _eval_tokens.Clear();
- _eval_logits.Clear();
- }
-
- /// <summary>
- /// Evaluate a list of tokens.
- /// </summary>
- /// <param name="tokens">The list of tokens to evaluate.</param>
- /// <exception cref="RuntimeError"></exception>
- public unsafe void Eval(List<llama_token> tokens)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- var n_ctx = NativeApi.llama_n_ctx(_ctx);
- for(int i = 0; i < tokens.Count; i += _n_batch)
- {
- var batch = tokens.Take(Math.Min(tokens.Count, i + _n_batch)).Skip(i);
- llama_token n_past = Math.Min(n_ctx - batch.Count(), _eval_tokens.Count);
- llama_token n_tokens = batch.Count();
- llama_token return_code = NativeApi.llama_eval(
- ctx: _ctx,
- tokens: batch.ToArray(),
- n_tokens: n_tokens,
- n_past: n_past,
- n_threads: _n_threads
- );
- if(return_code != 0)
- {
- throw new RuntimeError($"llama_eval returned {return_code}");
- }
- foreach(var b in batch)
- {
- _eval_tokens.Enqueue(b);
- }
- int rows = _params.logits_all ? n_tokens : 1;
- llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
- var cols = n_vocab;
- var logits_view = NativeApi.llama_get_logits(_ctx);
- for(int j = 0; j < rows; j++)
- {
- float[] logit = new float[cols];
- for(int k = 0; k < cols; k++)
- {
- logit[k] = logits_view[j * cols + k];
- }
- _eval_logits.Enqueue(logit);
- }
- }
- }
-
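- // Applies the repetition penalty to the candidate logits, then either samples greedily (temp == 0)
- // or filters with top-k, tail-free, typical and top-p sampling before temperature-based sampling.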
- private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
- float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- Debug.Assert(_eval_logits.Count > 0);
- llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
- var logits = _eval_logits.Last();
- LLamaTokenData[] data = new LLamaTokenData[n_vocab];
- for(int i = 0; i < n_vocab; i++)
- {
- data[i] = new LLamaTokenData(i, logits[i], .0f);
- }
- ulong size = (ulong)n_vocab;
- bool sorted = false;
- LLamaTokenDataArray candidates = new(data, size, sorted);
- SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
- repeat_penalty);
- //SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
- // frequency_penalty, presence_penalty);
- if(temp == .0f)
- {
- return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
- }
- else
- {
- SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
- SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
- SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
- SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
- SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
- return SamplingApi.llama_sample_token(_ctx, candidates);
- }
- }
-
- /// <summary>
- /// Sample a token from the model.
- /// </summary>
- /// <param name="top_k">The top-k sampling parameter.</param>
- /// <param name="top_p">The top-p sampling parameter.</param>
- /// <param name="temp">The temperature parameter.</param>
- /// <param name="repeat_penalty">The repeat penalty parameter.</param>
- /// <param name="frequency_penalty"></param>
- /// <param name="presence_penalty"></param>
- /// <returns>The sampled token.</returns>
- public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty, float frequency_penalty = .0f,
- float presence_penalty = .0f)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- var last_n_tokens_data = Enumerable.Repeat(0, Math.Max(0, _last_n_tokens_size - _eval_tokens.Count));
- last_n_tokens_data = last_n_tokens_data.Concat(_eval_tokens.ToList()
- .Skip(Math.Max(0, _eval_tokens.Count - _last_n_tokens_size)));
- llama_token[] tokens_data = new llama_token[_last_n_tokens_size];
- int i = 0;
- foreach(var data in last_n_tokens_data)
- {
- if(i < _last_n_tokens_size)
- {
- tokens_data[i++] = data;
- }
- else
- {
- break;
- }
- }
- return SampleInternal(tokens_data, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
- }
-
- /// <summary>
- /// Create a generator of tokens from a prompt.
- /// </summary>
- /// <example>
- /// <code>
- /// var llama = new LLamaModel("models/ggml-7b.bin");
- /// var tokens = llama.Tokenize("Hello, world!");
- /// foreach(var token in llama.Generate(tokens, top_k: 40, top_p: 0.95f, temp: 1.0f, repeat_penalty: 1.1f))
- /// {
- ///     Console.WriteLine(llama.DeTokenize(new[] { token }));
- /// }
- /// </code>
- /// </example>
- /// <param name="tokens"></param>
- /// <param name="top_k"></param>
- /// <param name="top_p"></param>
- /// <param name="temp"></param>
- /// <param name="repeat_penalty"></param>
- /// <param name="frequency_penalty"></param>
- /// <param name="presence_penalty"></param>
- /// <param name="reset"></param>
- /// <returns></returns>
- public IEnumerable<llama_token> Generate(IEnumerable<llama_token> tokens, int top_k, float top_p, float temp,
- float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
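- // If the new prompt shares a prefix with the tokens already evaluated, keep that prefix in the
- // state and only evaluate the remainder instead of resetting.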
- if(reset && _eval_tokens.Count > 0)
- {
- int longest_prefix = 0;
- foreach(var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count() - 1), (x, y) => (x, y)))
- {
- if(a == b)
- {
- longest_prefix += 1;
- }
- else
- {
- break;
- }
- }
- if(longest_prefix > 0)
- {
- if (_verbose)
- {
- Logger.Default.Info("Llama.generate: prefix-match hit");
- }
- reset = false;
- tokens = tokens.Skip(longest_prefix);
- for(int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
- {
- _eval_tokens.Dequeue();
- if(_eval_logits.Count > 0)
- {
- _eval_logits.Dequeue();
- }
- }
- }
- }
-
- if (reset)
- {
- Reset();
- }
-
- while (true)
- {
- Eval(tokens.ToList());
- var token = Sample(top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
- yield return token;
- // Feed the sampled token back in as the next input so generation continues from it.
- tokens = new[] { token };
- }
- }
-
- /// <summary>
- /// Embed a string.
- /// </summary>
- /// <param name="input">The utf-8 encoded string to embed.</param>
- /// <returns>An embedding object.</returns>
- /// <exception cref="RuntimeError"></exception>
- public unsafe Embedding CreateEmbedding(string input)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- if (!_params.embedding)
- {
- throw new RuntimeError("Llama model must be created with embedding=True to call this method");
- }
-
- if (_verbose)
- {
- NativeApi.llama_reset_timings(_ctx);
- }
-
- var tokens = Tokenize(input);
- Reset();
- Eval(tokens);
- int n_tokens = tokens.Count;
- var embeddingPtr = NativeApi.llama_get_embeddings(_ctx);
- int cnt = NativeApi.llama_n_embd(_ctx);
- float[] embedding = new float[cnt];
- for(int i = 0; i < cnt; i++)
- {
- embedding[i] = embeddingPtr[i];
- }
-
- if (_verbose)
- {
- NativeApi.llama_print_timings(_ctx);
- }
-
- return new Embedding("list", _model_path, new[] { new EmbeddingData(0, "embedding", embedding) },
- new EmbeddingUsage(n_tokens, n_tokens));
- }
-
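- /// <summary>
- /// Embed a string and return the raw embedding vector.
- /// </summary>
- /// <param name="input">The string to embed.</param>
- /// <returns>The embedding as a float array.</returns>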
- public float[] Embed(string input)
- {
- return CreateEmbedding(input).Data[0].Embedding;
- }
-
- /// <summary>
- /// Create a completion for the given prompt. When <paramref name="stream"/> is true, the results are
- /// yielded as <see cref="CompletionChunk"/>s; otherwise a single <see cref="Completion"/> is yielded.
- /// </summary>
- /// <param name="prompt"></param>
- /// <param name="suffix"></param>
- /// <param name="max_tokens"></param>
- /// <param name="temperature"></param>
- /// <param name="top_p"></param>
- /// <param name="logprobs"></param>
- /// <param name="echo"></param>
- /// <param name="stop"></param>
- /// <param name="frequency_penalty"></param>
- /// <param name="presence_penalty"></param>
- /// <param name="repeat_penalty"></param>
- /// <param name="top_k"></param>
- /// <param name="stream"></param>
- /// <returns>IEnumerable of Completion and CompletionChunk</returns>
- /// <exception cref="ArgumentException"></exception>
- private IEnumerable<object> CreateCompletionInternal(string prompt, string?suffix = null, int max_tokens = 16, float temperature = 0.8f,
- float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
- float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40, bool stream = false)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- string completionId = $"cmpl-{Guid.NewGuid()}";
- var created = (int)DateTimeOffset.UtcNow.ToUnixTimeSeconds(); // Unix timestamp of when the completion was created.
- List<llama_token> completionTokens = new List<llama_token>();
-
- var promptTokens = Tokenize($" {prompt}");
- string text = "";
- int returnedCharacters = 0;
- if(stop is null)
- {
- stop = new string[0];
- }
-
- if (_verbose)
- {
- NativeApi.llama_reset_timings(_ctx);
- }
-
- if(promptTokens.Count + max_tokens > NativeApi.llama_n_ctx(_ctx))
- {
- throw new ArgumentException($"Requested tokens exceed context window of {NativeApi.llama_n_ctx(_ctx)}");
- }
-
- if(logprobs != -1 && !_params.logits_all)
- {
- throw new ArgumentException("logprobs is not supported for models created with logits_all=False");
- }
-
- if(_cache is not null)
- {
- try
- {
- // TODO(Rinne): revise this; the cache indexer compares array references instead of elements.
- var cacheItem = _cache[promptTokens.ToArray()];
- var cachePrefixLen = LongestTokenPrefix(cacheItem.EvalTokens, promptTokens);
- var evalPrefixLen = LongestTokenPrefix(_eval_tokens, promptTokens);
- if(cachePrefixLen > evalPrefixLen)
- {
- LoadState(cacheItem);
- if (_verbose)
- {
- Logger.Default.Info("Llama._create_completion: cache hit");
- }
- }
- }
- catch (KeyNotFoundException)
- {
- if (_verbose)
- {
- Logger.Default.Warn("Llama._create_completion: cache miss");
- }
- }
- }
-
- string finishReason = "length";
- int multibyteFix = 0;
- bool reset = true;
- List<llama_token> tokens = new(promptTokens);
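- // Same prefix-match logic as Generate: reuse the already evaluated prefix of the prompt if possible.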
- if (reset && _eval_tokens.Count > 0)
- {
- int longest_prefix = 0;
- foreach (var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count - 1), (x, y) => (x, y)))
- {
- if (a == b)
- {
- longest_prefix += 1;
- }
- else
- {
- break;
- }
- }
- if (longest_prefix > 0)
- {
- if (_verbose)
- {
- Logger.Default.Info("Llama.generate: prefix-match hit");
- }
- reset = false;
- tokens = tokens.Skip(longest_prefix).ToList();
- for (int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
- {
- _eval_tokens.Dequeue();
- if (_eval_logits.Count > 0)
- {
- _eval_logits.Dequeue();
- }
- }
- }
- }
-
- if (reset)
- {
- Reset();
- }
- //foreach (var token in Generate(promptTokens, top_k, top_p, temperature, frequency_penalty, presence_penalty, repeat_penalty))
- while(true)
- {
- Eval(tokens);
- var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
- tokens.Clear();
- tokens.Add(token);
- if (token == NativeApi.llama_token_eos())
- {
- text = DeTokenize(completionTokens);
- finishReason = "stop";
- break;
- }
-
- completionTokens.Add(token);
-
- string allText = DeTokenize(completionTokens);
-
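- // Detect an incomplete multi-byte UTF-8 sequence at the end of the text (lead bytes 0xC0/0xE0/0xF0
- // from _numAndPatterns) and, if found, wait for the remaining continuation bytes before emitting output.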
- int cut = Math.Min(3, allText.Length);
- for(int i = allText.Length - cut; i < allText.Length; i++)
- {
- var c = (int)allText[i];
- int k = allText.Length - i; // Distance from the end of the text (1 = last character).
- foreach(var (num, pattern) in _numAndPatterns)
- {
- if(num > k && (pattern & c) == pattern)
- {
- multibyteFix = num - k;
- }
- }
- }
-
- if(multibyteFix > 0)
- {
- multibyteFix--;
- continue;
- }
-
- var anyStop = stop.Where(s => allText.Contains(s));
- if(anyStop.Count() > 0)
- {
- var firstStop = anyStop.First();
- text = allText.Substring(0, allText.IndexOf(firstStop));
- finishReason = "stop";
- break;
- }
-
- if (stream)
- {
- var start = returnedCharacters;
- int longest = 0;
- foreach(var s in stop)
- {
- for(int i = s.Length; i > 0; i--)
- {
- if(allText.EndsWith(s.Substring(0, i)))
- {
- if(i > longest)
- {
- longest = i;
- }
- break;
- }
- }
- }
- text = allText.Substring(0, allText.Length - longest);
- // Yield only the characters that have not been returned yet.
- var newText = start < text.Length ? text.Substring(start) : string.Empty;
- returnedCharacters += newText.Length;
- yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
- {
- new CompletionChoice(newText, 0, null, finishReason)
- });
- }
- }
-
- if(_cache is not null)
- {
- if (_verbose)
- {
- Logger.Default.Info("Llama._create_completion: cache save");
- }
- _cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
- }
-
- if (stream)
- {
- yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
- {
- new CompletionChoice(returnedCharacters < text.Length ? text.Substring(returnedCharacters) : string.Empty, 0, null, finishReason)
- });
- }
-
- string textStr = text;
- if (echo)
- {
- textStr = prompt + textStr;
- }
- if(suffix is not null)
- {
- textStr = textStr + suffix;
- }
-
- CompletionLogprobs? logProbs = null;
- if (logprobs != -1)
- {
- int textOffset = 0;
- List<int> textOffsets = new();
- List<float> tokenLogprobs = new();
- List<string> tokenStrs = new();
- List<Dictionary<string, float>> topLogprobs = new();
-
- var allTokens = promptTokens.Concat(completionTokens).ToArray();
- var allTokenStrs = allTokens.Select(t => DeTokenize(new[] { t }));
- var allLogProbs = _eval_logits.Select(row => LogitsToLogprobs(row));
-
- foreach (var (token, tokenStr, logProbsToken) in allTokens.Zip(allTokenStrs, (x, y) => (x, y))
- .Zip(allLogProbs, (x, y) => (x.x, x.y, y)))
- {
- textOffsets.Add(textOffset);
- textOffset += tokenStr.Length;
- tokenStrs.Add(tokenStr);
- var sortedLogprobs = logProbsToken.Zip(Enumerable.Range(0, logProbsToken.Count()), (x, y) => (x, y))
- .OrderByDescending(x => x.x).ToList();
- tokenLogprobs.Add(logProbsToken.ElementAt(token));
- var topLogprob = sortedLogprobs.Take(logprobs).ToDictionary(t => DeTokenize(new[] { t.y }), t => t.x);
- topLogprob[tokenStr] = logProbsToken.ElementAt(token);
- topLogprobs.Add(topLogprob);
- }
-
- logProbs = new(textOffsets.ToArray(), tokenLogprobs.ToArray(), tokenStrs.ToArray(), topLogprobs.ToArray());
- }
-
- if (_verbose)
- {
- NativeApi.llama_print_timings(_ctx);
- }
-
- yield return new Completion(completionId, "text_completion", created, _model_path, new CompletionChoice[]
- {
- new CompletionChoice(text, 0, logProbs, finishReason)
- }, new CompletionUsage(promptTokens.Count, completionTokens.Count, promptTokens.Count + completionTokens.Count));
- }
-
- /// <summary>
- /// Generate text from a prompt and yield return the result.
- /// </summary>
- /// <param name="prompt">The prompt to generate text from.</param>
- /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
- /// <param name="max_tokens">The maximum number of tokens to generate.</param>
- /// <param name="temperature">The temperature to use for sampling.</param>
- /// <param name="top_p">The top-p value to use for sampling.</param>
- /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
- /// <param name="echo">Whether to echo the prompt.</param>
- /// <param name="stop">A list of strings to stop generation when encountered.</param>
- /// <param name="frequency_penalty"></param>
- /// <param name="presence_penalty"></param>
- /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
- /// <param name="top_k">The top-k value to use for sampling.</param>
- /// <returns></returns>
- public IEnumerable<CompletionChunk> CreateCompletionStream(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
- float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
- float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
- {
- // CreateCompletionInternal yields CompletionChunk objects when stream is true; forward them one by one.
- foreach (var chunk in CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
- frequency_penalty, presence_penalty, repeat_penalty, top_k, true))
- {
- yield return (CompletionChunk)chunk;
- }
- }
-
- /// <summary>
- /// Generate text from a prompt.
- /// </summary>
- /// <param name="prompt">The prompt to generate text from.</param>
- /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
- /// <param name="max_tokens">The maximum number of tokens to generate.</param>
- /// <param name="temperature">The temperature to use for sampling.</param>
- /// <param name="top_p">The top-p value to use for sampling.</param>
- /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
- /// <param name="echo">Whether to echo the prompt.</param>
- /// <param name="stop">A list of strings to stop generation when encountered.</param>
- /// <param name="frequency_penalty"></param>
- /// <param name="presence_penalty"></param>
- /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
- /// <param name="top_k">The top-k value to use for sampling.</param>
- /// <returns></returns>
- public Completion CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
- float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
- float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
- {
- var completion = CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
- frequency_penalty, presence_penalty, repeat_penalty, top_k, false).First();
- return (Completion)completion;
- }
-
- /// <summary>
- /// Generate text from a prompt.
- /// </summary>
- /// <param name="prompt">The prompt to generate text from.</param>
- /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
- /// <param name="max_tokens">The maximum number of tokens to generate.</param>
- /// <param name="temperature">The temperature to use for sampling.</param>
- /// <param name="top_p">The top-p value to use for sampling.</param>
- /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
- /// <param name="echo">Whether to echo the prompt.</param>
- /// <param name="stop">A list of strings to stop generation when encountered.</param>
- /// <param name="frequency_penalty"></param>
- /// <param name="presence_penalty"></param>
- /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
- /// <param name="top_k">The top-k value to use for sampling.</param>
- /// <returns></returns>
- public Completion Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
- float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
- float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
- {
- return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
- frequency_penalty, presence_penalty, repeat_penalty, top_k);
- }
-
- /// <summary>
- /// Generate text from a prompt and yield return the result.
- /// </summary>
- /// <param name="prompt">The prompt to generate text from.</param>
- /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
- /// <param name="max_tokens">The maximum number of tokens to generate.</param>
- /// <param name="temperature">The temperature to use for sampling.</param>
- /// <param name="top_p">The top-p value to use for sampling.</param>
- /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
- /// <param name="echo">Whether to echo the prompt.</param>
- /// <param name="stop">A list of strings to stop generation when encountered.</param>
- /// <param name="frequency_penalty"></param>
- /// <param name="presence_penalty"></param>
- /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
- /// <param name="top_k">The top-k value to use for sampling.</param>
- /// <returns></returns>
- public IEnumerable<CompletionChunk> StreamCall(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
- float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
- float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
- {
- return CreateCompletionStream(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
- frequency_penalty, presence_penalty, repeat_penalty, top_k);
- }
-
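- /// <summary>
- /// Wrap a text <see cref="Completion"/> in the chat completion format, attributing the text to the assistant role.
- /// </summary>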
- private ChatCompletion ConvertTextCompletionToChat(Completion completion)
- {
- return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
- new[] { new ChatCompletionChoice(0, new ChatCompletionMessage("assistant", completion.Choices[0].Text, null),
- completion.Choices[0].FinishReason) }, completion.Usage);
- }
-
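- /// <summary>
- /// Wrap streamed <see cref="CompletionChunk"/>s in the chat completion chunk format: an initial chunk
- /// announces the "assistant" role, then each source chunk is forwarded as a content delta.
- /// </summary>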
- private IEnumerable<ChatCompletionChunk> ConvertTextCompletionChunksToChat(IEnumerable<CompletionChunk> chunks)
- {
- bool isFirst = true;
- foreach(var chunk in chunks)
- {
- if(isFirst)
- {
- yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
- new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta("assistant", null), null) });
- isFirst = false;
- }
- yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
- new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta(null, chunk.Choices[0].Text),
- chunk.Choices[0].FinishReason) });
- }
- }
-
- /// <summary>
- /// Generate a chat completion from a list of messages.
- /// </summary>
- /// <param name="messages">A list of messages to generate a response for.</param>
- /// <param name="temperature">The temperature to use for sampling.</param>
- /// <param name="top_p">The top-p value to use for sampling.</param>
- /// <param name="top_k">The top-k value to use for sampling.</param>
- /// <param name="stop">A list of strings to stop generation when encountered.</param>
- /// <param name="max_tokens">The maximum number of tokens to generate.</param>
- /// <param name="presence_penalty"></param>
- /// <param name="frequency_penalty"></param>
- /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
- /// <returns></returns>
- public ChatCompletion CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
- int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
- float repeat_penalty = 1.1f)
- {
- if(stop is null)
- {
- stop = new string[0];
- }
- string GetRole(ChatCompletionMessage message)
- {
- return message.Role == "user" ? "Human" : "Assistant";
- }
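- // Build an instruction-style prompt: each message is rendered as "### Human:" or "### Assistant:"
- // followed by its content, and the role markers double as stop sequences for the reply.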
- string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
- var prompt = chatHistory + "### Assistant:";
- var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
- var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
- repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
- return ConvertTextCompletionToChat(completion);
- }
-
- /// <summary>
- /// Generate a chat completion from a list of messages and yield return the result.
- /// </summary>
- /// <param name="messages">A list of messages to generate a response for.</param>
- /// <param name="temperature">The temperature to use for sampling.</param>
- /// <param name="top_p">The top-p value to use for sampling.</param>
- /// <param name="top_k">The top-k value to use for sampling.</param>
- /// <param name="stop">A list of strings to stop generation when encountered.</param>
- /// <param name="max_tokens">The maximum number of tokens to generate.</param>
- /// <param name="presence_penalty"></param>
- /// <param name="frequency_penalty"></param>
- /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
- /// <returns></returns>
- public IEnumerable<ChatCompletionChunk> CreateChatCompletionStream(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
- int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
- float repeat_penalty = 1.1f)
- {
- if (stop is null)
- {
- stop = new string[0];
- }
- string GetRole(ChatCompletionMessage message)
- {
- return message.Role == "user" ? "Human" : "Assistant";
- }
- string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
- var prompt = chatHistory + "### Assistant:";
- var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
- var completion = StreamCall(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
- repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
- return ConvertTextCompletionChunksToChat(completion);
- }
-
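- /// <summary>
- /// Capture the current native llama state together with the evaluated token and logit queues.
- /// </summary>
- /// <returns>A <see cref="LLamaState"/> that can later be restored with <see cref="LoadState"/>.</returns>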
- public LLamaState SaveState()
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- ulong stateSize = NativeApi.llama_get_state_size(_ctx);
- byte[] llamaState = new byte[stateSize];
- ulong nBytes = NativeApi.llama_copy_state_data(_ctx, llamaState);
- if(nBytes > stateSize)
- {
- throw new RuntimeError("Failed to copy llama state data");
- }
- byte[] llamaStateCompact = new byte[nBytes];
- Array.Copy(llamaState, llamaStateCompact, (int)nBytes);
- if (_verbose)
- {
- Logger.Default.Info($"Llama.save_state: saving {nBytes} bytes of llama state");
- }
- return new LLamaState(new Queue<llama_token>(_eval_tokens), new Queue<float[]>(_eval_logits),
- llamaStateCompact, (int)nBytes);
- }
-
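- /// <summary>
- /// Restore a state previously captured with <see cref="SaveState"/>.
- /// </summary>
- /// <param name="state">The state to restore.</param>
- /// <exception cref="RuntimeError">Thrown if the native state data could not be applied.</exception>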
- public void LoadState(LLamaState state)
- {
- Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
- _eval_tokens = new Queue<llama_token>(state.EvalTokens);
- _eval_logits = new Queue<float[]>(state.EvalLogits);
- if(NativeApi.llama_set_state_data(_ctx, state.State) != (ulong)state.Size)
- {
- throw new RuntimeError($"Failed to set llama state data");
- }
- }
-
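- /// <summary>
- /// Convert raw logits to log-probabilities, i.e. log(softmax(logits)) = log(exp(x_i) / sum_j exp(x_j)).
- /// </summary>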
- private static IEnumerable<float> LogitsToLogprobs(IEnumerable<float> logits)
- {
- var exps = logits.Select(x => (float)Math.Exp(x));
- var sumExps = exps.Sum();
- return exps.Select(x => (float)Math.Log(x / sumExps));
- }
-
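- /// <summary>
- /// Count how many leading tokens two sequences have in common.
- /// </summary>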
- internal static int LongestTokenPrefix(IEnumerable<llama_token> a, IEnumerable<llama_token> b)
- {
- int longestPrefix = 0;
- foreach(var (x, y) in a.Zip(b, (x, y) => (x, y)))
- {
- if(x == y)
- {
- longestPrefix++;
- }
- else
- {
- break;
- }
- }
- return longestPrefix;
- }
- }
- }