using LLama.Exceptions;
using LLama.Native;
using LLama.OldVersion;
using LLama.Extensions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.IO;
using LLama.Common;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// The abstraction of a LLama model, which holds the context in the native library.
    /// </summary>
    public class LLamaModel : IDisposable
    {
        // TODO: expose more properties.
        ILLamaLogger? _logger;
        Encoding _encoding;
        SafeLLamaContextHandle _ctx;

        // Mirostat "mu" state used by the Sample overload that takes no explicit mu.
        // It is kept across calls because mirostat updates mu after every sampled token;
        // re-seeding it on each call (as before) meant the feedback loop never adapted.
        // null => seeded with 2 * mirostatTau on first use.
        float? _mirostatMu;

        /// <summary>
        /// The context size, queried from the native context when the model is constructed.
        /// </summary>
        public int ContextSize { get; }

        /// <summary>
        /// The model params set for this model.
        /// </summary>
        public ModelParams Params { get; set; }

        /// <summary>
        /// The native handle, which is used to be passed to the native APIs. Please avoid using it
        /// unless you know what is the usage of the Native API.
        /// </summary>
        public SafeLLamaContextHandle NativeHandle => _ctx;

        /// <summary>
        /// The encoding set for this model to deal with text input.
        /// </summary>
        public Encoding Encoding => _encoding;

        /// <summary>
        /// Create the native context from <paramref name="Params"/>.
        /// </summary>
        /// <param name="Params">Model params.</param>
        /// <param name="encoding">Name of the encoding used to deal with text input (passed to <see cref="Encoding.GetEncoding(string)"/>).</param>
        /// <param name="logger">The logger; may be null.</param>
        public LLamaModel(ModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
        {
            _logger = logger;
            this.Params = Params;
            _encoding = Encoding.GetEncoding(encoding);
            _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
            _ctx = Utils.InitLLamaContextFromModelParams(this.Params);
            ContextSize = NativeApi.llama_n_ctx(_ctx);
        }

        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The text to tokenize.</param>
        /// <param name="addBos">Whether to add a bos to the text.</param>
        /// <returns>The tokens produced by the native tokenizer.</returns>
        public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
        {
            // TODO: reconsider whether to convert to array here.
            return Utils.Tokenize(_ctx, text, addBos, _encoding);
        }

        /// <summary>
        /// Detokenize the tokens to text.
        /// </summary>
        /// <param name="tokens">The tokens to convert.</param>
        /// <returns>The concatenated text of every token, decoded with <see cref="Encoding"/>.</returns>
        public string DeTokenize(IEnumerable<llama_token> tokens)
        {
            StringBuilder sb = new();
            foreach (var token in tokens)
            {
                sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding));
            }
            return sb.ToString();
        }

        /// <summary>
        /// Save the state to specified path.
        /// </summary>
        /// <param name="filename">Destination file; overwritten if it exists.</param>
        public void SaveState(string filename)
        {
            File.WriteAllBytes(filename, GetStateData());
        }

        /// <summary>
        /// Get the state data as a byte array.
        /// </summary>
        /// <returns>A buffer of <c>llama_get_state_size</c> bytes filled by <c>llama_copy_state_data</c>.</returns>
        public byte[] GetStateData()
        {
            var stateSize = NativeApi.llama_get_state_size(_ctx);
            byte[] stateMemory = new byte[stateSize];
            NativeApi.llama_copy_state_data(_ctx, stateMemory);
            return stateMemory;
        }

        /// <summary>
        /// Load the state from specified path.
        /// </summary>
        /// <param name="filename">File previously written by <see cref="SaveState(string)"/>.</param>
        /// <exception cref="RuntimeError">The file size does not match the native state size.</exception>
        public void LoadState(string filename)
        {
            var stateMemory = File.ReadAllBytes(filename);
            LoadState(stateMemory);
        }

        /// <summary>
        /// Load the state from memory.
        /// </summary>
        /// <param name="stateData">State bytes, e.g. from <see cref="GetStateData"/>.</param>
        /// <exception cref="RuntimeError">The data length does not match the native state size.</exception>
        public void LoadState(byte[] stateData)
        {
            // Compare as ulong: the previous (int) cast of the native size would truncate/wrap
            // if it ever exceeded int.MaxValue, making the check unreliable.
            ulong stateSize = (ulong)NativeApi.llama_get_state_size(_ctx);
            if ((ulong)stateData.Length != stateSize)
            {
                throw new RuntimeError("Failed to validate state size.");
            }
            NativeApi.llama_set_state_data(_ctx, stateData);
        }

        /// <summary>
        /// Perform the sampling. Please don't use it unless you fully know what it does.
        /// The mirostat "mu" state is kept inside this model instance and carried across calls.
        /// </summary>
        /// <param name="candidates">Candidate tokens, typically from <see cref="ApplyPenalty"/>.</param>
        /// <param name="temperature">Sampling temperature; a value &lt;= 0 selects greedy sampling.</param>
        /// <param name="mirostat">Which mirostat variant to use, if any.</param>
        /// <param name="mirostatTau">Mirostat target entropy (tau).</param>
        /// <param name="mirostatEta">Mirostat learning rate (eta).</param>
        /// <param name="topK">Top-k limit; a value &lt;= 0 keeps the whole vocabulary.</param>
        /// <param name="topP">Top-p (nucleus) threshold.</param>
        /// <param name="tfsZ">Tail-free sampling parameter.</param>
        /// <param name="typicalP">Locally typical sampling parameter.</param>
        /// <returns>The sampled token.</returns>
        public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
            float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
        {
            return Sample(candidates, ref _mirostatMu, temperature, mirostat, mirostatTau, mirostatEta, topK, topP, tfsZ, typicalP);
        }

        /// <summary>
        /// Perform the sampling with caller-owned mirostat state. Please don't use it unless you fully know what it does.
        /// </summary>
        /// <param name="candidates">Candidate tokens, typically from <see cref="ApplyPenalty"/>.</param>
        /// <param name="mirostatMu">
        /// Mirostat "mu" state. If null it is seeded with <c>2 * mirostatTau</c>; after a mirostat sample it holds the
        /// updated value, so pass the same variable on every call of one generation. Untouched for non-mirostat sampling.
        /// </param>
        /// <param name="temperature">Sampling temperature; a value &lt;= 0 selects greedy sampling.</param>
        /// <param name="mirostat">Which mirostat variant to use, if any.</param>
        /// <param name="mirostatTau">Mirostat target entropy (tau).</param>
        /// <param name="mirostatEta">Mirostat learning rate (eta).</param>
        /// <param name="topK">Top-k limit; a value &lt;= 0 keeps the whole vocabulary.</param>
        /// <param name="topP">Top-p (nucleus) threshold.</param>
        /// <param name="tfsZ">Tail-free sampling parameter.</param>
        /// <param name="typicalP">Locally typical sampling parameter.</param>
        /// <returns>The sampled token.</returns>
        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostatMu, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
            float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
        {
            if (temperature <= 0)
            {
                // Greedy sampling
                return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
            }

            if (mirostat == MiroStateType.MiroState)
            {
                const int mirostat_m = 100;
                float mu = mirostatMu ?? 2.0f * mirostatTau;
                SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                var id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
                mirostatMu = mu;
                return id;
            }

            if (mirostat == MiroStateType.MiroState2)
            {
                float mu = mirostatMu ?? 2.0f * mirostatTau;
                SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                var id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
                mirostatMu = mu;
                return id;
            }

            // Temperature sampling.
            // A non-positive topK means "no top-k limit": keep the whole vocabulary instead of handing 0/negative
            // to the native filter. NOTE(review): native top-k appears to clamp k up to min_keep (1 here), which
            // would silently make sampling greedy — confirm against the bundled llama.cpp.
            int k = topK <= 0 ? NativeApi.llama_n_vocab(_ctx) : topK;
            SamplingApi.llama_sample_top_k(_ctx, candidates, k, 1);
            SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
            SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
            SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
            SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
            return SamplingApi.llama_sample_token(_ctx, candidates);
        }

        /// <summary>
        /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
        /// Builds one candidate per vocabulary entry from the current logits, applies <paramref name="logitBias"/>,
        /// then the repetition and frequency/presence penalties over the most recent tokens.
        /// </summary>
        /// <param name="lastTokens">Previously generated/evaluated tokens; only the most recent ones are used. Enumerated once.</param>
        /// <param name="logitBias">Optional per-token bias added to the logits before penalties.</param>
        /// <param name="repeatLastTokensCount">How many recent tokens to penalize (clamped to [0, ContextSize]).</param>
        /// <param name="repeatPenalty">Repetition penalty.</param>
        /// <param name="alphaFrequency">Frequency penalty.</param>
        /// <param name="alphaPresence">Presence penalty.</param>
        /// <param name="penalizeNL">When false, the newline token's logit is restored to its pre-penalty value.</param>
        /// <returns>The penalized candidates, unsorted.</returns>
        public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
            bool penalizeNL = true)
        {
            var n_vocab = NativeApi.llama_n_vocab(_ctx);
            var logits = Utils.GetLogits(_ctx, n_vocab);

            // Apply params.logit_bias map
            if (logitBias is not null)
            {
                foreach (var (key, value) in logitBias)
                {
                    logits[key] += value;
                }
            }

            // One candidate per vocabulary entry, built by value from the logits, so candidates[i] is token i.
            var candidates = new LLamaTokenData[n_vocab];
            for (llama_token token_id = 0; token_id < n_vocab; token_id++)
            {
                candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
            }
            LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates, (ulong)candidates.Length, false);

            // Apply penalties
            llama_token nlToken = NativeApi.llama_token_nl();
            float nl_logit = logits[nlToken];

            // Materialize once: lastTokens may be a lazy sequence and was previously enumerated three times.
            var lastTokensArray = lastTokens.ToArray();
            int lastTokensCount = lastTokensArray.Length;
            // Clamp at 0 so a negative repeatLastTokensCount can't become a huge (ulong) size for native code.
            var last_n_repeat = Math.Max(0, Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize));
            var lastNTokens = lastTokensArray.Skip(lastTokensCount - last_n_repeat).ToArray();

            SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, lastNTokens, (ulong)last_n_repeat, repeatPenalty);
            SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, lastNTokens, (ulong)last_n_repeat, alphaFrequency, alphaPresence);

            if (!penalizeNL)
            {
                // The candidates were copied by value out of `logits`, so writing nl_logit back into `logits`
                // (the previous behavior) never reached the returned candidates. Restore it on the candidate itself.
                // NOTE(review): assumes LLamaTokenDataArray wraps `candidates` without copying, that the two penalty
                // passes above adjust logits in place without reordering (so index nlToken is still token nlToken),
                // and that they leave p at its initial 0.0f — confirm against SamplingApi.
                candidates[nlToken] = new LLamaTokenData(nlToken, nl_logit, 0.0f);
            }
            return candidates_p;
        }

        /// <summary>
        /// Evaluate <paramref name="tokens"/> in batches of at most <c>Params.BatchSize</c> tokens.
        /// </summary>
        /// <param name="tokens">Tokens to evaluate.</param>
        /// <param name="pastTokensCount">Number of tokens already in the context before this call.</param>
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError">The native evaluation reported a failure.</exception>
        public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
        {
            // Guard against a non-positive batch size, which previously made this loop spin forever (i += 0).
            int batchSize = Math.Max(1, Params.BatchSize);
            int total = tokens.Length;
            for (int i = 0; i < total; i += batchSize)
            {
                int n_eval = Math.Min(total - i, batchSize);
                if (Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
                {
                    _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                    throw new RuntimeError("Failed to eval.");
                }
                pastTokensCount += n_eval;
            }
            return pastTokensCount;
        }

        /// <summary>
        /// Lazily convert each token id to its text piece using this model's context and encoding.
        /// </summary>
        /// <param name="ids">Token ids to convert; enumerated as the result is enumerated.</param>
        /// <returns>One string per token, in order.</returns>
        internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
        {
            foreach (var id in ids)
            {
                yield return Utils.TokenToString(id, _ctx, _encoding);
            }
        }

        /// <summary>
        /// Release the native context handle.
        /// </summary>
        public void Dispose()
        {
            _ctx.Dispose();
        }
    }
}