using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using LLama.Types;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
namespace LLama
{
using llama_token = Int32;
/// <summary>
/// High-level wrapper of a llama.cpp model for inference.
/// </summary>
public class LLamaModel
{
private string _model_path;
LLamaContextParams _params;
private int _n_threads;
private int _n_batch;
private int _last_n_tokens_size;
private string? _lora_base;
private string? _lora_path;
private bool _verbose;
private Queue<llama_token> _eval_tokens;
private Queue<float[]> _eval_logits;
private LLamaCache? _cache;
private SafeLLamaContextHandle _ctx;
private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };
/// <summary>
/// Load a llama.cpp model from the path.
/// </summary>
/// <remarks>
/// Note that the API is still unstable and the order of the parameters is likely to
/// change in the future, so it's recommended to specify parameter names when
/// building your app. We use the cpp-style parameter names here because it makes
/// searching the llama.cpp docs more convenient.
/// </remarks>
/// <param name="model_path">Path to the model.</param>
/// <param name="n_ctx">Maximum context size.</param>
/// <param name="n_parts">Number of parts to split the model into. If -1, the number of parts is determined automatically.</param>
/// <param name="seed">Random seed. 0 for random.</param>
/// <param name="f16_kv">Use half-precision for the key/value cache.</param>
/// <param name="logits_all">Return logits for all tokens, not just the last token.</param>
/// <param name="vocab_only">Only load the vocabulary, no weights.</param>
/// <param name="use_mmap">Use mmap if possible.</param>
/// <param name="use_mlock">Force the system to keep the model in RAM.</param>
/// <param name="embedding">Embedding mode only.</param>
/// <param name="n_threads">Number of threads to use. If not specified, the number of threads is determined automatically.</param>
/// <param name="n_batch">Maximum number of prompt tokens to batch together when calling llama_eval.</param>
/// <param name="last_n_tokens_size">Maximum number of tokens to keep in the last_n_tokens deque.</param>
/// <param name="lora_base">Optional path to a base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
/// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
/// <param name="verbose">Print verbose output to stderr.</param>
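/// <example>
/// A minimal usage sketch; the model path and parameter values below are placeholders, not required defaults:
/// <code>
/// var model = new LLamaModel("models/ggml-7b.bin", n_ctx: 512, seed: 1337, verbose: true);
/// </code>
/// </example>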
public LLamaModel(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
{
_verbose = verbose;
_model_path = model_path;
_params = NativeApi.llama_context_default_params();
_params.n_ctx = n_ctx;
_params.n_parts = n_parts;
_params.seed = seed;
_params.f16_kv = f16_kv;
_params.logits_all = logits_all;
_params.vocab_only = vocab_only;
_params.use_mmap = lora_path is null ? use_mmap : false;
_params.use_mlock = use_mlock;
_params.embedding = embedding;
_last_n_tokens_size = last_n_tokens_size;
_n_batch = Math.Min(n_ctx, n_batch);
_eval_tokens = new Queue<llama_token>(capacity: n_ctx);
_eval_logits = new Queue<float[]>(logits_all ? n_ctx : 1);
_cache = null;
_n_threads = n_threads;
if(_n_threads == -1)
{
_n_threads = Math.Max(Environment.ProcessorCount / 2, 1);
}
_lora_base = lora_base;
_lora_path = lora_path;
if(!File.Exists(model_path) && !Directory.Exists(model_path))
{
throw new FileNotFoundException($"Model path does not exist: {model_path}");
}
// Initialize the native llama.cpp context from the model file and wrap it in a safe handle.
_ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(model_path, _params));
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
if(_lora_path is not null)
{
if(NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
{
throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
}
}
if (_verbose)
{
#if NET6_0_OR_GREATER
Logger.Default.Info(Marshal.PtrToStringUTF8(NativeApi.llama_print_system_info()));
#endif
}
}
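/// <summary>
/// Create a copy of another LLamaModel. The native context handle (and its state) is shared
/// with the source model; only the evaluation token and logit queues are copied.
/// </summary>
/// <param name="other">The model to copy.</param>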
public LLamaModel(LLamaModel other)
{
_ctx = other._ctx;
_model_path = other._model_path;
_params = other._params;
_last_n_tokens_size = other._last_n_tokens_size;
_n_threads = other._n_threads;
_n_batch = other._n_batch;
_verbose = other._verbose;
_lora_base = other._lora_base;
_lora_path = other._lora_path;
_eval_logits = new Queue<float[]>(other._eval_logits);
_eval_tokens = new Queue<llama_token>(other._eval_tokens);
}
/// <summary>
/// Tokenize a string.
/// </summary>
/// <param name="text">The utf-8 encoded string to tokenize.</param>
/// <returns>A list of tokens.</returns>
/// <exception cref="RuntimeError">If the tokenization failed.</exception>
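/// <example>
/// A minimal sketch, assuming <c>model</c> is an already-constructed <see cref="LLamaModel"/>:
/// <code>
/// var tokens = model.Tokenize("Hello, world!");
/// Console.WriteLine(string.Join(", ", tokens));
/// </code>
/// </example>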
public List<llama_token> Tokenize(string text)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var n_ctx = NativeApi.llama_n_ctx(_ctx);
var tokens = new llama_token[n_ctx];
var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
if(n_tokens < 0)
{
throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
}
return tokens.Take(n_tokens).ToList();
}
/// <summary>
/// Detokenize a list of tokens.
/// </summary>
/// <param name="tokens">The list of tokens to detokenize.</param>
/// <returns>The detokenized string.</returns>
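/// <example>
/// A minimal sketch showing the round trip through <see cref="Tokenize"/>, assuming <c>model</c> is an already-constructed <see cref="LLamaModel"/>:
/// <code>
/// var text = model.DeTokenize(model.Tokenize("Hello, world!"));
/// </code>
/// </example>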
public string DeTokenize(IEnumerable<llama_token> tokens)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var output = new StringBuilder();
foreach(var token in tokens)
{
#if NET6_0_OR_GREATER
output.Append(Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)));
#else
output.Append(Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, token)));
#endif
}
return output.ToString();
}
/// <summary>
/// Set the cache.
/// </summary>
/// <param name="cache">The cache to set.</param>
public void SetCache(LLamaCache? cache)
{
_cache = cache;
}
/// <summary>
/// Reset the model state.
/// </summary>
public void Reset()
{
_eval_tokens.Clear();
_eval_logits.Clear();
}
/// <summary>
/// Evaluate a list of tokens.
/// </summary>
/// <param name="tokens">The list of tokens to evaluate.</param>
/// <exception cref="RuntimeError">If llama_eval reports an error.</exception>
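/// <example>
/// A minimal sketch, assuming <c>model</c> is an already-constructed <see cref="LLamaModel"/>. Eval only updates
/// the internal token and logit queues, so a sampling call usually follows:
/// <code>
/// model.Reset();
/// model.Eval(model.Tokenize("The quick brown fox"));
/// </code>
/// </example>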
public unsafe void Eval(List<llama_token> tokens)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var n_ctx = NativeApi.llama_n_ctx(_ctx);
for(int i = 0; i < tokens.Count; i += _n_batch)
{
var batch = tokens.Take(Math.Min(tokens.Count, i + _n_batch)).Skip(i);
llama_token n_past = Math.Min(n_ctx - batch.Count(), _eval_tokens.Count);
llama_token n_tokens = batch.Count();
llama_token return_code = NativeApi.llama_eval(
ctx: _ctx,
tokens: batch.ToArray(),
n_tokens: n_tokens,
n_past: n_past,
n_threads: _n_threads
);
if(return_code != 0)
{
throw new RuntimeError($"llama_eval returned {return_code}");
}
foreach(var b in batch)
{
_eval_tokens.Enqueue(b);
}
int rows = _params.logits_all ? n_tokens : 1;
llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
var cols = n_vocab;
var logits_view = NativeApi.llama_get_logits(_ctx);
for(int j = 0; j < rows; j++)
{
float[] logit = new float[cols];
for(int k = 0; k < cols; k++)
{
logit[k] = logits_view[j * cols + k];
}
_eval_logits.Enqueue(logit);
}
}
}
private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
Debug.Assert(_eval_logits.Count > 0);
llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
var logits = _eval_logits.Last();
LLamaTokenData[] data = new LLamaTokenData[n_vocab];
for(int i = 0; i < n_vocab; i++)
{
data[i] = new LLamaTokenData(i, logits[i], .0f);
}
ulong size = (ulong)n_vocab;
bool sorted = false;
LLamaTokenDataArray candidates = new(data, size, sorted);
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
repeat_penalty);
//SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
// frequency_penalty, presence_penalty);
if(temp == .0f)
{
return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
}
else
{
SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
return SamplingApi.llama_sample_token(_ctx, candidates);
}
}
/// <summary>
/// Sample a token from the model.
/// </summary>
/// <param name="top_k">The top-k sampling parameter.</param>
/// <param name="top_p">The top-p sampling parameter.</param>
/// <param name="temp">The temperature parameter.</param>
/// <param name="repeat_penalty">The repeat penalty parameter.</param>
/// <param name="frequency_penalty">The frequency penalty parameter.</param>
/// <param name="presence_penalty">The presence penalty parameter.</param>
/// <returns>The sampled token.</returns>
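/// <example>
/// A minimal sketch of an Eval/Sample loop, assuming <c>model</c> is an already-constructed <see cref="LLamaModel"/>;
/// the sampling values shown are illustrative only:
/// <code>
/// model.Eval(model.Tokenize("Once upon a time"));
/// var next = model.Sample(top_k: 40, top_p: 0.95f, temp: 0.8f, repeat_penalty: 1.1f);
/// Console.Write(model.DeTokenize(new[] { next }));
/// </code>
/// </example>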
public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty, float frequency_penalty = .0f,
float presence_penalty = .0f)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var last_n_tokens_data = Enumerable.Repeat(0, Math.Max(0, _last_n_tokens_size - _eval_tokens.Count));
last_n_tokens_data = last_n_tokens_data.Concat(_eval_tokens.ToList()
.Skip(Math.Max(0, _eval_tokens.Count - _last_n_tokens_size)));
llama_token[] tokens_data = new llama_token[_last_n_tokens_size];
int i = 0;
foreach(var data in last_n_tokens_data)
{
if(i < _last_n_tokens_size)
{
tokens_data[i++] = data;
}
else
{
break;
}
}
return SampleInternal(tokens_data, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
}
/// <summary>
/// Create a generator of tokens from a prompt.
/// </summary>
/// <example>
/// <code>
/// var llama = new LLamaModel("models/ggml-7b.bin");
/// var tokens = llama.Tokenize("Hello, world!");
/// foreach(var token in llama.Generate(tokens, top_k: 40, top_p: 0.95f, temp: 1.0f, repeat_penalty: 1.1f))
/// {
///     Console.WriteLine(llama.DeTokenize(new[] { token }));
/// }
/// </code>
/// </example>
/// <param name="tokens">The prompt tokens.</param>
/// <param name="top_k">The top-k sampling parameter.</param>
/// <param name="top_p">The top-p sampling parameter.</param>
/// <param name="temp">The temperature parameter.</param>
/// <param name="repeat_penalty">The repeat penalty parameter.</param>
/// <param name="frequency_penalty">The frequency penalty parameter.</param>
/// <param name="presence_penalty">The presence penalty parameter.</param>
/// <param name="reset">Whether to reset the model state before generating.</param>
/// <returns>An enumerable of sampled tokens.</returns>
public IEnumerable<llama_token> Generate(IEnumerable<llama_token> tokens, int top_k, float top_p, float temp,
float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
if(reset && _eval_tokens.Count > 0)
{
int longest_prefix = 0;
foreach(var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count() - 1), (x, y) => (x, y)))
{
if(a == b)
{
longest_prefix += 1;
}
else
{
break;
}
}
if(longest_prefix > 0)
{
if (_verbose)
{
Logger.Default.Info("Llama.generate: prefix-match hit");
}
reset = false;
tokens = tokens.Skip(longest_prefix);
for(int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
{
_eval_tokens.Dequeue();
if(_eval_logits.Count > 0)
{
_eval_logits.Dequeue();
}
}
}
}
if (reset)
{
Reset();
}
while (true)
{
Eval(tokens.ToList());
var token = Sample(top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
yield return token;
// Feed the sampled token back in so the next iteration continues the sequence.
tokens = new llama_token[] { token };
}
}
/// <summary>
/// Embed a string.
/// </summary>
/// <param name="input">The utf-8 encoded string to embed.</param>
/// <returns>An embedding object.</returns>
/// <exception cref="RuntimeError">If the model was not created with embedding mode enabled.</exception>
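/// <example>
/// A minimal sketch; the model path is a placeholder and the model must be constructed with <c>embedding: true</c>:
/// <code>
/// var model = new LLamaModel("models/ggml-7b.bin", embedding: true);
/// var result = model.CreateEmbedding("Hello, world!");
/// Console.WriteLine(result.Data[0].Embedding.Length);
/// </code>
/// </example>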
public unsafe Embedding CreateEmbedding(string input)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
if (!_params.embedding)
{
throw new RuntimeError("Llama model must be created with embedding=True to call this method");
}
if (_verbose)
{
NativeApi.llama_reset_timings(_ctx);
}
var tokens = Tokenize(input);
Reset();
Eval(tokens);
int n_tokens = tokens.Count;
var embeddingPtr = NativeApi.llama_get_embeddings(_ctx);
int cnt = NativeApi.llama_n_embd(_ctx);
float[] embedding = new float[cnt];
for(int i = 0; i < cnt; i++)
{
embedding[i] = embeddingPtr[i];
}
if (_verbose)
{
NativeApi.llama_print_timings(_ctx);
}
return new Embedding("list", _model_path, new[] { new EmbeddingData(0, "embedding", embedding) },
new EmbeddingUsage(n_tokens, n_tokens));
}
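/// <summary>
/// Embed a string and return the raw embedding vector.
/// </summary>
/// <param name="input">The utf-8 encoded string to embed.</param>
/// <returns>The embedding as a float array.</returns>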
public float[] Embed(string input)
{
return CreateEmbedding(input).Data[0].Embedding;
}
/// <returns>IEnumerable of Completion and CompletionChunk</returns>
private IEnumerable