using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using LLama.Types;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Collections;
namespace LLama
{
using llama_token = Int32;
///
/// High-level Wrapper of a llama.cpp model for inference.
///
[Obsolete]
public class LLamaModelV1
{
private string _model_path;
LLamaContextParams _params;
private int _n_threads;
private int _n_batch;
private int _last_n_tokens_size;
private string? _lora_base;
private string? _lora_path;
private bool _verbose;
private Queue _eval_tokens;
private Queue _eval_logits;
private LLamaCache? _cache;
private SafeLLamaContextHandle _ctx;
private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };
///
/// Load a llama.cpp model from the path.
///
/// Note that the API is still unstable. The order of them is likely to
/// be changed in the future. It's recommened to specify the parameter name when
/// building your app. We use the cpp style parameter names here because it introduces
/// convenience for searching the docs.
/// Path to the model.
/// Maximum context size.
/// Number of parts to split the model into. If -1, the number of parts is automatically determined.
/// Random seed. 0 for random.
/// Use half-precision for key/value cache.
/// Return logits for all tokens, not just the last token.
/// Only load the vocabulary no weights.
/// Use mmap if possible.
/// Force the system to keep the model in RAM.
/// Embedding mode only.
/// Number of threads to use. If is not specified, the number of threads is automatically determined.
/// Maximum number of prompt tokens to batch together when calling llama_eval.
/// Maximum number of tokens to keep in the last_n_tokens deque.
/// Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
/// Path to a LoRA file to apply to the model.
/// Print verbose output to stderr.
public LLamaModelV1(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
{
_verbose = verbose;
_model_path = model_path;
_params = NativeApi.llama_context_default_params();
_params.n_ctx = n_ctx;
_params.n_parts = n_parts;
_params.seed = seed;
_params.f16_kv = f16_kv;
_params.logits_all = logits_all;
_params.vocab_only = vocab_only;
_params.use_mmap = lora_path is null ? use_mmap : false;
_params.use_mlock = use_mlock;
_params.embedding = embedding;
_last_n_tokens_size = last_n_tokens_size;
_n_batch = Math.Min(n_ctx, n_batch);
_eval_tokens = new Queue(capacity: n_ctx);
_eval_logits = new Queue(logits_all ? n_ctx : 1);
_cache = null;
_n_threads = n_threads;
if(_n_threads == -1)
{
_n_threads = Math.Max(Environment.ProcessorCount / 2, 1);
}
_lora_base = lora_base;
_lora_path = lora_path;
if(!File.Exists(model_path) && !Directory.Exists(model_path))
{
throw new FileNotFoundException($"Model path does not exist: {model_path}");
}
// Move from heap to stack to prevent the moving.
_ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(Encoding.UTF8.GetString(Encoding.UTF8.GetBytes(model_path)), _params));
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
if(_lora_path is not null)
{
if(NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
{
throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
}
}
if (_verbose)
{
Logger.Default.Info(Utils.PtrToStringUTF8(NativeApi.llama_print_system_info()));
}
}
public LLamaModelV1(LLamaModelV1 other)
{
_ctx = other._ctx;
_model_path = other._model_path;
_params = other._params;
_last_n_tokens_size = other._last_n_tokens_size;
_n_threads = other._n_threads;
_n_batch = other._n_batch;
_verbose = other._verbose;
_lora_base = other._lora_base;
_lora_path = other._lora_path;
_eval_logits = new Queue(other._eval_logits);
_eval_tokens = new Queue(other._eval_tokens);
}
///
/// Tokenize a string.
///
/// The utf-8 encoded string to tokenize.
/// A list of tokens.
/// If the tokenization failed.
public List Tokenize(string text)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var n_ctx = NativeApi.llama_n_ctx(_ctx);
var tokens = new llama_token[n_ctx];
var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
if(n_tokens < 0)
{
throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
}
return tokens.Take(n_tokens).ToList();
}
///
/// Detokenize a list of tokens.
///
/// The list of tokens to detokenize.
/// The detokenized string.
public string DeTokenize(IEnumerable tokens)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
string output = "";
foreach(var token in tokens)
{
output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
}
return output;
}
public string DeTokenize(llama_token token)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
return Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)) ?? "";
}
///
/// Set the cache.
///
/// The cache to set.
public void SetCache(LLamaCache? cache)
{
_cache = cache;
}
///
/// Reset the model state.
///
public void Reset()
{
_eval_tokens.Clear();
_eval_logits.Clear();
}
///
/// Evaluate a list of tokens.
///
/// The list of tokens to evaluate.
///
public unsafe void Eval(List tokens)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var n_ctx = NativeApi.llama_n_ctx(_ctx);
for(int i = 0; i < tokens.Count; i += _n_batch)
{
var batch = tokens.Take(Math.Min(tokens.Count, i + _n_batch)).Skip(i);
llama_token n_past = Math.Min(n_ctx - batch.Count(), _eval_tokens.Count);
llama_token n_tokens = batch.Count();
llama_token return_code = NativeApi.llama_eval(
ctx: _ctx,
tokens: batch.ToArray(),
n_tokens: n_tokens,
n_past: n_past,
n_threads: _n_threads
);
if(return_code != 0)
{
throw new RuntimeError($"llama_eval returned {return_code}");
}
foreach(var b in batch)
{
_eval_tokens.Enqueue(b);
}
int rows = _params.logits_all ? n_tokens : 1;
llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
var cols = n_vocab;
var logits_view = NativeApi.llama_get_logits(_ctx);
for(int j = 0; j < rows; j++)
{
float[] logit = new float[cols];
for(int k = 0; k < cols; k++)
{
logit[k] = logits_view[j * cols + k];
}
_eval_logits.Enqueue(logit);
}
}
}
private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
Debug.Assert(_eval_logits.Count > 0);
llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
var logits = _eval_logits.Last();
LLamaTokenData[] data = new LLamaTokenData[n_vocab];
for(int i = 0; i < n_vocab; i++)
{
data[i] = new LLamaTokenData(i, logits[i], .0f);
}
ulong size = (ulong)n_vocab;
bool sorted = false;
LLamaTokenDataArray candidates = new(data, size, sorted);
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
repeat_penalty);
//SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
// frequency_penalty, presence_penalty);
if(temp == .0f)
{
return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
}
else
{
SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
return SamplingApi.llama_sample_token(_ctx, candidates);
}
}
///
/// Sample a token from the model.
///
/// The top-k sampling parameter.
/// The top-p sampling parameter.
/// The temperature parameter.
/// The repeat penalty parameter.
///
///
/// The sampled token.
public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty, float frequency_penalty = .0f,
float presence_penalty = .0f)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var last_n_tokens_data = Enumerable.Repeat(0, Math.Max(0, _last_n_tokens_size - _eval_tokens.Count));
last_n_tokens_data = last_n_tokens_data.Concat(_eval_tokens.ToList()
.Skip(Math.Max(0, _eval_tokens.Count - _last_n_tokens_size)));
llama_token[] tokens_data = new llama_token[_last_n_tokens_size];
int i = 0;
foreach(var data in last_n_tokens_data)
{
if(i < _last_n_tokens_size)
{
tokens_data[i++] = data;
}
else
{
break;
}
}
return SampleInternal(tokens_data, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
}
///
/// Create a generator of tokens from a prompt.
///
///
/// Examples:
/// var llama = new LlamaModel("models/ggml-7b.bin")
/// var tokens = llama.Tokenize(b"Hello, world!")
/// foreach(var token in llama.Generate(tokens, top_k:40, top_p:0.95, temp:1.0, repeat_penalty:1.1)){
/// Console.WriteLine(llama.DeTokenize(new []{token}));
/// }
///
///
///
///
///
///
///
///
///
///
public IEnumerable Generate(IEnumerable tokens, int top_k, float top_p, float temp,
float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
if(reset && _eval_tokens.Count > 0)
{
int longest_prefix = 0;
foreach(var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count() - 1), (x, y) => (x, y)))
{
if(a == b)
{
longest_prefix += 1;
}
else
{
break;
}
}
if(longest_prefix > 0)
{
if (_verbose)
{
Logger.Default.Info("Llama.generate: prefix-match hit");
}
reset = false;
tokens = tokens.Skip(longest_prefix);
for(int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
{
_eval_tokens.Dequeue();
if(_eval_logits.Count > 0)
{
_eval_logits.Dequeue();
}
}
}
}
if (reset)
{
Reset();
}
while (true)
{
Eval(tokens.ToList());
var token = Sample(top_k, top_p, temp, frequency_penalty, presence_penalty, repeat_penalty);
yield return token;
// TODO(Rinne): verify if the implementation is correct.
}
}
///
/// Embed a string.
///
/// The utf-8 encoded string to embed.
/// An embedding object.
///
public unsafe Embedding CreateEmbedding(string input)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
if (!_params.embedding)
{
throw new RuntimeError("Llama model must be created with embedding=True to call this method");
}
if (_verbose)
{
NativeApi.llama_reset_timings(_ctx);
}
var tokens = Tokenize(input);
Reset();
Eval(tokens);
int n_tokens = tokens.Count;
var embeddingPtr = NativeApi.llama_get_embeddings(_ctx);
int cnt = NativeApi.llama_n_embd(_ctx);
float[] embedding = new float[cnt];
for(int i = 0; i < cnt; i++)
{
embedding[i] = embeddingPtr[i];
}
if (_verbose)
{
NativeApi.llama_print_timings(_ctx);
}
return new Embedding("list", _model_path, new[] { new EmbeddingData(0, "embedding", embedding) },
new EmbeddingUsage(n_tokens, n_tokens));
}
public float[] Embed(string input)
{
return CreateEmbedding(input).Data[0].Embedding;
}
///
///
///
///
///
///
///
///
///
///
///
///
///
///
///
/// IEnumerable of Completion and CompletionChunk
///
private IEnumerable CreateCompletionInternal(string prompt, string?suffix = null, int max_tokens = 16, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
string completionId = $"cmpl-{Guid.NewGuid()}";
var created = DateTime.Now.Millisecond;
List completionTokens = new List();
var promptTokens = Tokenize($" {prompt}");
string text = "";
int returnedCharacters = 0;
if(stop is null)
{
stop = new string[0];
}
if (_verbose)
{
NativeApi.llama_reset_timings(_ctx);
}
if(promptTokens.Count + max_tokens > NativeApi.llama_n_ctx(_ctx))
{
throw new ArgumentException($"Requested tokens exceed context window of {NativeApi.llama_n_ctx(_ctx)}");
}
if(logprobs != -1 && !_params.logits_all)
{
throw new ArgumentException("logprobs is not supported for models created with logits_all=False");
}
if(_cache is not null)
{
try
{
// TODO(Rinne): revise it since it will compare reference instead of elements.
var cacheItem = _cache[promptTokens.ToArray()];
var cachePrefixLen = LongestTokenPrefix(_eval_tokens.AsEnumerable(), promptTokens);
var evalPrefixLen = LongestTokenPrefix(_eval_tokens.AsEnumerable(), promptTokens);
if(cachePrefixLen > evalPrefixLen)
{
LoadState(cacheItem);
if (_verbose)
{
Logger.Default.Info("Llama._create_completion: cache hit");
}
}
}
catch (KeyNotFoundException)
{
if (_verbose)
{
Logger.Default.Warn("Llama._create_completion: cache miss");
}
}
}
string finishReason = "length";
int multibyteFix = 0;
bool reset = true;
List tokens = new(promptTokens);
if (reset && _eval_tokens.Count > 0)
{
int longest_prefix = 0;
foreach (var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count - 1), (x, y) => (x, y)))
{
if (a == b)
{
longest_prefix += 1;
}
else
{
break;
}
}
if (longest_prefix > 0)
{
if (_verbose)
{
Logger.Default.Info("Llama.generate: prefix-match hit");
}
reset = false;
tokens = tokens.Skip(longest_prefix).ToList();
for (int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
{
_eval_tokens.Dequeue();
if (_eval_logits.Count > 0)
{
_eval_logits.Dequeue();
}
}
}
}
if (reset)
{
Reset();
}
//foreach (var token in Generate(promptTokens, top_k, top_p, temperature, frequency_penalty, presence_penalty, repeat_penalty))
string allText = "";
while (true)
{
Eval(tokens);
var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
tokens.Clear();
tokens.Add(token);
if (token == NativeApi.llama_token_eos())
{
text = DeTokenize(completionTokens);
finishReason = "stop";
break;
}
completionTokens.Add(token);
allText = DeTokenize(completionTokens);
int cut = Math.Min(3, allText.Length);
for(int i = allText.Length - cut; i < allText.Length; i++)
{
var c = (int)allText[i];
int k = cut - i;
foreach(var (num, pattern) in _numAndPatterns)
{
if(num > k && (pattern & c) == pattern)
{
multibyteFix = num - k;
}
}
}
if(multibyteFix > 0)
{
multibyteFix--;
continue;
}
var anyStop = stop.Where(s => allText.Contains(s));
if(anyStop.Count() > 0)
{
var firstStop = anyStop.First();
text = allText.Substring(0, allText.IndexOf(firstStop));
finishReason = "stop";
break;
}
var start = returnedCharacters;
int longest = 0;
foreach (var s in stop)
{
for (int i = s.Length; i > 0; i--)
{
if (allText.EndsWith(s.Substring(0, i)))
{
if (i > longest)
{
longest = i;
}
break;
}
}
}
text = allText.Substring(0, allText.Length - longest);
returnedCharacters += text.Skip(start).Count();
yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
{
new CompletionChoice(text.Substring(start), 0, null, finishReason)
});
}
if (_cache is not null)
{
if (_verbose)
{
Logger.Default.Info("Llama._create_completion: cache save");
}
_cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
}
string textStr = text;
if (echo)
{
textStr = prompt + textStr;
}
if(suffix is not null)
{
textStr = textStr + suffix;
}
CompletionLogprobs? logProbs = null;
if (logprobs != -1)
{
int textOffset = 0;
List textOffsets = new();
List tokenLogprobs = new();
List tokenStrs = new();
List> topLogprobs = new();
var allTokens = promptTokens.Concat(completionTokens).ToArray();
var allTokenStrs = allTokens.Select(t => DeTokenize(new[] { t }));
var allLogProbs = _eval_logits.Select(row => LogitsToLogprobs(row));
foreach (var (token, tokenStr, logProbsToken) in allTokens.Zip(allTokenStrs, (x, y) => (x, y))
.Zip(allLogProbs, (x, y) => (x.x, x.y, y)))
{
textOffsets.Add(textOffset);
textOffset += tokenStr.Length;
tokenStrs.Add(tokenStr);
var sortedLogprobs = logProbsToken.Zip(Enumerable.Range(0, logProbsToken.Count()), (x, y) => (x, y))
.OrderByDescending(x => x.x).ToList();
tokenLogprobs.Add(sortedLogprobs[token].x);
var topLogprob = sortedLogprobs.Take(logprobs).ToDictionary(t => DeTokenize(new[] { t.y }), t => t.x);
topLogprob[tokenStr] = sortedLogprobs[token].x;
topLogprobs.Add(topLogprob);
}
logProbs = new(textOffsets.ToArray(), tokenLogprobs.ToArray(), tokenStrs.ToArray(), topLogprobs.ToArray());
}
if (_verbose)
{
NativeApi.llama_print_timings(_ctx);
}
}
///
/// Generate text from a prompt and yield return the result.
///
/// The prompt to generate text from.
/// A suffix to append to the generated text. If None, no suffix is appended.
/// The maximum number of tokens to generate.
/// The temperature to use for sampling.
/// The top-p value to use for sampling.
/// The number of logprobs to return. If None, no logprobs are returned.
/// Whether to echo the prompt.
/// A list of strings to stop generation when encountered.
///
///
/// The penalty to apply to repeated tokens.
/// The top-k value to use for sampling.
///
public IEnumerable CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
return CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
frequency_penalty, presence_penalty, repeat_penalty, top_k);
}
///
/// Generate text from a prompt and yield return the result.
///
/// The prompt to generate text from.
/// A suffix to append to the generated text. If None, no suffix is appended.
/// The maximum number of tokens to generate.
/// The temperature to use for sampling.
/// The top-p value to use for sampling.
/// The number of logprobs to return. If None, no logprobs are returned.
/// Whether to echo the prompt.
/// A list of strings to stop generation when encountered.
///
///
/// The penalty to apply to repeated tokens.
/// The top-k value to use for sampling.
///
public IEnumerable Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
frequency_penalty, presence_penalty, repeat_penalty, top_k);
}
private ChatCompletion ConvertTextCompletionToChat(Completion completion)
{
return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
new[] { new ChatCompletionChoice(0, new ChatCompletionMessage(ChatRole.Assistant, completion.Choices[0].Text),
completion.Choices[0].FinishReason) }, completion.Usage);
}
private IEnumerable ConvertTextCompletionChunksToChat(IEnumerable chunks)
{
bool isFirst = true;
foreach(var chunk in chunks)
{
if(isFirst)
{
yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta("assistant", null), null) });
isFirst = false;
}
yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta(null, chunk.Choices[0].Text),
chunk.Choices[0].FinishReason) });
}
}
///
/// Generate a chat completion from a list of messages and yield return the result.
///
/// A list of messages to generate a response for.
/// The temperature to use for sampling.
/// The top-p value to use for sampling.
/// The top-k value to use for sampling.
/// A list of strings to stop generation when encountered.
/// The maximum number of tokens to generate.
///
///
/// The penalty to apply to repeated tokens.
///
public IEnumerable CreateChatCompletion(IEnumerable messages, float temperature = .2f, float top_p = .95f,
int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
float repeat_penalty = 1.1f)
{
if (stop is null)
{
stop = new string[0];
}
string GetRole(ChatCompletionMessage message)
{
return message.Role == ChatRole.Human ? "Human" : "Assistant";
}
string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
var prompt = chatHistory + "### Assistant:";
prompt = prompt.Substring(Math.Max(0, prompt.Length - max_tokens));
var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
return ConvertTextCompletionChunksToChat(completion);
}
public LLamaState SaveState()
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
ulong stateSize = NativeApi.llama_get_state_size(_ctx);
byte[] llamaState = new byte[stateSize];
ulong nBytes = NativeApi.llama_copy_state_data(_ctx, llamaState);
if(nBytes > stateSize)
{
throw new RuntimeError("Failed to copy llama state data");
}
byte[] llamaStateCompact = new byte[nBytes];
llamaState.Take((int)nBytes).ToArray().CopyTo(llamaStateCompact, 0);
if (_verbose)
{
Logger.Default.Info($"Llama.save_state: saving {nBytes} bytes of llama state");
}
return new LLamaState(new Queue(_eval_tokens), new Queue(_eval_logits),
llamaStateCompact, (int)nBytes);
}
public void LoadState(LLamaState state)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
_eval_tokens = new Queue(state.EvalTokens);
_eval_logits = new Queue(state.EvalLogits);
if(NativeApi.llama_set_state_data(_ctx, state.State) != (ulong)state.Size)
{
throw new RuntimeError($"Failed to set llama state data");
}
}
private static IEnumerable LogitsToLogprobs(IEnumerable logits)
{
var exps = logits.Select(x => (float)Math.Exp(x));
var sumExps = exps.Sum();
return exps.Select(x => (float)Math.Log(x / sumExps));
}
internal static int LongestTokenPrefix(IEnumerable a, IEnumerable b)
{
int longestPrefix = 0;
foreach(var (x, y) in a.Zip(b, (x, y) => (x, y)))
{
if(x == y)
{
longestPrefix++;
}
else
{
break;
}
}
return longestPrefix;
}
}
}