| @@ -3,17 +3,17 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using LLama.Old; | |||
| using LLama.OldVersion; | |||
| namespace LLama.Examples | |||
| { | |||
| public class ChatSession | |||
| { | |||
| LLama.Old.ChatSession<LLama.Old.LLamaModel> _session; | |||
| LLama.OldVersion.ChatSession<LLama.OldVersion.LLamaModel> _session; | |||
| public ChatSession(string modelPath, string promptFilePath, string[] antiprompt) | |||
| { | |||
| LLama.Old.LLamaModel model = new(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false)); | |||
| _session = new ChatSession<LLama.Old.LLamaModel>(model) | |||
| LLama.OldVersion.LLamaModel model = new(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false)); | |||
| _session = new ChatSession<LLama.OldVersion.LLamaModel>(model) | |||
| .WithPromptFile(promptFilePath) | |||
| .WithAntiprompt(antiprompt); | |||
| } | |||
| @@ -3,16 +3,16 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using LLama.Old; | |||
| using LLama.OldVersion; | |||
| namespace LLama.Examples.Old | |||
| { | |||
| public class ChatWithLLamaModel | |||
| { | |||
| LLama.Old.LLamaModel _model; | |||
| LLama.OldVersion.LLamaModel _model; | |||
| public ChatWithLLamaModel(string modelPath, string promptFilePath, string[] antiprompt) | |||
| { | |||
| _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: antiprompt.ToList(), | |||
| _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: antiprompt.ToList(), | |||
| repeat_penalty: 1.0f)).WithPromptFile(promptFilePath); | |||
| } | |||
| @@ -3,16 +3,16 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using LLama.Old; | |||
| using LLama.OldVersion; | |||
| namespace LLama.Examples | |||
| { | |||
| public class GetEmbeddings | |||
| { | |||
| LLamaEmbedder _embedder; | |||
| LLama.OldVersion.LLamaEmbedder _embedder; | |||
| public GetEmbeddings(string modelPath) | |||
| { | |||
| _embedder = new LLamaEmbedder(new LLamaParams(model: modelPath)); | |||
| _embedder = new LLama.OldVersion.LLamaEmbedder(new LLamaParams(model: modelPath)); | |||
| } | |||
| public void Run(string text) | |||
| @@ -3,16 +3,16 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using LLama.Old; | |||
| using LLama.OldVersion; | |||
| namespace LLama.Examples.Old | |||
| { | |||
| public class InstructMode | |||
| { | |||
| LLama.Old.LLamaModel _model; | |||
| LLama.OldVersion.LLamaModel _model; | |||
| public InstructMode(string modelPath, string promptFile) | |||
| { | |||
| _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, | |||
| _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, | |||
| repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPromptFile(promptFile); | |||
| } | |||
| @@ -3,16 +3,16 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using LLama.Old; | |||
| using LLama.OldVersion; | |||
| namespace LLama.Examples | |||
| { | |||
| public class SaveAndLoadState: IDisposable | |||
| { | |||
| LLama.Old.LLamaModel _model; | |||
| LLama.OldVersion.LLamaModel _model; | |||
| public SaveAndLoadState(string modelPath, string prompt) | |||
| { | |||
| _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, | |||
| _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, | |||
| repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPrompt(prompt); | |||
| } | |||
| @@ -26,12 +26,12 @@ if(version == 1) | |||
| Console.WriteLine("The examples for new versions are under working now. We'll soon update the examples." + | |||
| " Thank you for your support!"); | |||
| string modelPath = "D:\\development\\llama\\weights\\wizard-vicuna-13B.ggmlv3.q4_1.bin"; | |||
| var prompt = File.ReadAllText("Assets/dan.txt").Trim(); | |||
| LLamaInstructExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024))); | |||
| var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | |||
| LLamaInteractExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337))); | |||
| while (true) | |||
| { | |||
| foreach (var text in ex.Infer(prompt, new SessionParams() { Temperature = 0.6f })) | |||
| await foreach (var text in ex.InferAsync(prompt, new SessionParams() { Temperature = 0.6f, AntiPrompts = new List<string>{ "user:" } }, default(CancellationToken))) | |||
| { | |||
| Console.Write(text); | |||
| } | |||
| @@ -1,4 +1,4 @@ | |||
| using LLama.Old; | |||
| using LLama.OldVersion; | |||
| using LLama.WebAPI.Models; | |||
| namespace LLama.WebAPI.Services; | |||
| @@ -20,6 +20,11 @@ namespace LLama.Abstractions.Params | |||
| /// logit bias for specific tokens | |||
| /// </summary> | |||
| public Dictionary<llama_token, float>? LogitBias { get; set; } = null; | |||
| /// <summary> | |||
| /// Sequences where the model will stop generating further tokens. | |||
| /// </summary> | |||
| public IList<string> AntiPrompts { get; set; } = Array.Empty<string>(); | |||
| /// <summary> | |||
| /// path to file for saving/loading model eval state | |||
| /// </summary> | |||
| @@ -2,11 +2,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Threading; | |||
| namespace LLama | |||
| { | |||
| public interface ILLamaExecutor | |||
| { | |||
| IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null); | |||
| IEnumerable<string> Infer(string text, SessionParams? sessionParams = null); | |||
| IAsyncEnumerable<string> InferAsync(string text, SessionParams? sessionParams = null, CancellationToken token = default); | |||
| } | |||
| } | |||
| @@ -0,0 +1,80 @@ | |||
| using LLama.Native; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| using LLama.Abstractions.Params; | |||
| using System.Linq; | |||
| namespace LLama | |||
| { | |||
| public class LLamaEmbedder : IDisposable | |||
| { | |||
| SafeLLamaContextHandle _ctx; | |||
| /// <summary> | |||
/// Warning: the caller must ensure that the context was created with params.embedding = true. | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| internal LLamaEmbedder(SafeLLamaContextHandle ctx) | |||
| { | |||
| _ctx = ctx; | |||
| } | |||
| public LLamaEmbedder(ModelParams @params) | |||
| { | |||
| @params.EmbeddingMode = true; | |||
| _ctx = Utils.InitLLamaContextFromModelParams(@params); | |||
| } | |||
| /// <summary> | |||
| /// Get the embeddings of the text. | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <param name="threads">Threads used for inference.</param> | |||
| /// <param name="addBos">Add bos to the text.</param> | |||
| /// <param name="encoding"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") | |||
| { | |||
| if (threads == -1) | |||
| { | |||
| threads = Math.Max(Environment.ProcessorCount / 2, 1); | |||
| } | |||
| int n_past = 0; | |||
| if (addBos) | |||
| { | |||
| text = text.Insert(0, " "); | |||
| } | |||
| var embed_inp = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)); | |||
| // TODO(Rinne): deal with log of prompt | |||
| if (embed_inp.Count() > 0) | |||
| { | |||
| var embed_inp_array = embed_inp.ToArray(); | |||
| if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, threads) != 0) | |||
| { | |||
| throw new RuntimeError("Failed to eval."); | |||
| } | |||
| } | |||
| int n_embed = NativeApi.llama_n_embd(_ctx); | |||
| var embeddings = NativeApi.llama_get_embeddings(_ctx); | |||
| if (embeddings == null) | |||
| { | |||
| return new float[0]; | |||
| } | |||
| var span = new Span<float>(embeddings, n_embed); | |||
| float[] res = new float[n_embed]; | |||
| span.CopyTo(res.AsSpan()); | |||
| return res; | |||
| } | |||
| public void Dispose() | |||
| { | |||
| _ctx.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -6,7 +6,10 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| namespace LLama | |||
| { | |||
| @@ -106,6 +109,121 @@ namespace LLama | |||
| } | |||
| } | |||
| public abstract IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null); | |||
| protected abstract bool GetLoopCondition(InferStateArgs args); | |||
| protected abstract void PreprocessInputs(string text, InferStateArgs args); | |||
| protected abstract bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable<string>? extraOutputs); | |||
| protected abstract void InferInternal(SessionParams sessionParams, InferStateArgs args); | |||
| public virtual IEnumerable<string> Infer(string text, SessionParams? sessionParams = null) | |||
| { | |||
| if (sessionParams is null) | |||
| { | |||
| sessionParams = new SessionParams(); | |||
| } | |||
| InferStateArgs args = new InferStateArgs() | |||
| { | |||
| Antiprompts = sessionParams.AntiPrompts, | |||
| RemainedTokens = sessionParams.ResponseTokensCount, | |||
| ReturnValue = false, | |||
| WaitForInput = false, | |||
| NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count | |||
| }; | |||
| PreprocessInputs(text, args); | |||
| while (GetLoopCondition(args)) | |||
| { | |||
| InferInternal(sessionParams, args); | |||
| if (args.ReturnValue) | |||
| { | |||
| foreach (var item in _model.GenerateResult(_embeds)) | |||
| { | |||
| yield return item; | |||
| } | |||
| } | |||
| var breakGeneration = PostProcess(sessionParams, args, out var extraOutputs); | |||
| if (extraOutputs is not null) | |||
| { | |||
| foreach (var item in extraOutputs) | |||
| { | |||
| yield return item; | |||
| } | |||
| } | |||
| if (breakGeneration) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| public virtual async IAsyncEnumerable<string> InferAsync(string text, SessionParams? sessionParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | |||
| { | |||
| cancellationToken.ThrowIfCancellationRequested(); | |||
// This delay exists only to make the async method yield as callers would expect; it is currently disabled. | |||
| //await Task.Delay(1); | |||
| if (sessionParams is null) | |||
| { | |||
| sessionParams = new SessionParams(); | |||
| } | |||
| InferStateArgs args = new InferStateArgs() | |||
| { | |||
| Antiprompts = sessionParams.AntiPrompts, | |||
| RemainedTokens = sessionParams.ResponseTokensCount, | |||
| ReturnValue = false, | |||
| WaitForInput = false, | |||
| NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count | |||
| }; | |||
| PreprocessInputs(text, args); | |||
| while (GetLoopCondition(args)) | |||
| { | |||
| if (cancellationToken.IsCancellationRequested) | |||
| { | |||
| break; | |||
| } | |||
| InferInternal(sessionParams, args); | |||
| if (args.ReturnValue) | |||
| { | |||
| foreach (var item in _model.GenerateResult(_embeds)) | |||
| { | |||
| yield return item; | |||
| } | |||
| } | |||
| var breakGeneration = PostProcess(sessionParams, args, out var extraOutputs); | |||
| if (extraOutputs is not null) | |||
| { | |||
| foreach (var item in extraOutputs) | |||
| { | |||
| yield return item; | |||
| } | |||
| } | |||
| if (breakGeneration) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
/// State arguments used within a single inference run. | |||
| /// </summary> | |||
| protected class InferStateArgs | |||
| { | |||
| public IList<string>? Antiprompts { get; set; } | |||
| /// <summary> | |||
/// Number of tokens remaining to be generated. (n_remain) | |||
| /// </summary> | |||
| public int RemainedTokens { get; set; } | |||
| public bool ReturnValue { get; set; } | |||
| public bool WaitForInput { get; set; } | |||
| public bool NeedToSaveSession { get; set; } | |||
| } | |||
| } | |||
| } | |||
| @@ -11,28 +11,28 @@ namespace LLama | |||
| public class LLamaInstructExecutor : LLamaExecutorBase | |||
| { | |||
| bool _prompt_run = true; | |||
| readonly IEnumerable<llama_token> _llama_token_newline; | |||
| readonly IEnumerable<llama_token> _inp_pfx; | |||
| readonly IEnumerable<llama_token> _inp_sfx; | |||
| public LLamaInstructExecutor(LLamaModel model, string inputPrefix = "\n\n### Instruction:\n\n", | |||
| string inputSuffix = "\n\n### Response:\n\n") : base(model) | |||
| { | |||
| _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding); | |||
| _inp_pfx = _model.Tokenize(inputPrefix, true); | |||
| _inp_sfx = _model.Tokenize(inputSuffix, false); | |||
| } | |||
| /// <summary> | |||
| /// process the text and return the tokens consumed. | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <param name="sessionParams"></param> | |||
| /// <param name="encoding"></param> | |||
| /// <param name="is_antiprompt"></param> | |||
| /// <returns></returns> | |||
| protected virtual int ProcessTextBeforeInfer(string text, SessionParams sessionParams) | |||
| protected override bool GetLoopCondition(InferStateArgs args) | |||
| { | |||
| if (text.Length > 1) | |||
| return args.RemainedTokens != 0 || _prompt_run; | |||
| } | |||
| protected override void PreprocessInputs(string text, InferStateArgs args) | |||
| { | |||
| if (_prompt_run) | |||
| { | |||
// When running the first input (prompt) in interactive mode, it needs special processing. | |||
| text = " " + text; | |||
| _embed_inps = _model.Tokenize(text, true).ToList(); | |||
| } | |||
| else | |||
| { | |||
| if (!text.EndsWith("\n")) | |||
| { | |||
| @@ -46,153 +46,120 @@ namespace LLama | |||
| _embed_inps.AddRange(_inp_sfx); | |||
| return line_inp.Count(); | |||
| } | |||
| else | |||
| { | |||
| return 0; | |||
| args.RemainedTokens -= line_inp.Count(); | |||
| } | |||
| } | |||
| public override IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null) | |||
| protected override bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable<string>? extraOutputs) | |||
| { | |||
| if (sessionParams is null) | |||
| extraOutputs = null; | |||
| if (_embed_inps.Count <= _consumedTokensCount) | |||
| { | |||
| sessionParams = new SessionParams(); | |||
| } | |||
| // if n_remain < 0, the response will be generated endlessly. | |||
| int n_remain = sessionParams.ResponseTokensCount; | |||
| bool return_value = false; | |||
| bool wait_for_input = false; | |||
| bool need_to_save_session = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count; | |||
| if (_prompt_run) | |||
| { | |||
// When running the first input (prompt) in interactive mode, it needs special processing. | |||
| text = " " + text; | |||
| _embed_inps = _model.Tokenize(text, true).ToList(); | |||
| } | |||
| else | |||
| { | |||
| n_remain -= ProcessTextBeforeInfer(text, sessionParams); | |||
| } | |||
| while (n_remain != 0 || _prompt_run) | |||
| { | |||
| if (_embeds.Count > 0) | |||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | |||
| { | |||
| _prompt_run = false; | |||
| if (_pastTokensCount + _embeds.Count > _model.ContextSize) | |||
| string last_output = ""; | |||
| foreach (var id in _last_n_tokens) | |||
| { | |||
| HandleRunOutOfContext(sessionParams.TokensToKeep); | |||
| last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); | |||
| } | |||
| TryReuseMathingPrefix(); | |||
| _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); | |||
| if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | |||
| foreach (var antiprompt in args.Antiprompts) | |||
| { | |||
| _session_tokens.AddRange(_embeds); | |||
| _n_session_consumed = _session_tokens.Count; | |||
| if (last_output.EndsWith(antiprompt)) | |||
| { | |||
| args.WaitForInput = true; | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| _embeds.Clear(); | |||
| if (_embed_inps.Count <= _consumedTokensCount && !wait_for_input) | |||
| if (_pastTokensCount > 0 && args.WaitForInput) | |||
| { | |||
| var temp = sessionParams.Temperature; | |||
| var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; | |||
var top_p = sessionParams.TopP; | |||
| var tfs_z = sessionParams.TfsZ; | |||
| var typical_p = sessionParams.TypicalP; | |||
| var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount; | |||
| var repeat_penalty = sessionParams.RepeatPenalty; | |||
| var alpha_presence = sessionParams.PresencePenalty; | |||
| var alpha_frequency = sessionParams.FrequencyPenalty; | |||
| var mirostat = sessionParams.Mirostat; | |||
| var mirostat_tau = sessionParams.MirostatTau; | |||
| var mirostat_eta = sessionParams.MirostatEta; | |||
| var penalize_nl = sessionParams.PenalizeNL; | |||
| // optionally save the session on first sample (for faster prompt loading next time) | |||
| if (!string.IsNullOrEmpty(_pathSession) && need_to_save_session) | |||
| { | |||
| need_to_save_session = false; | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, | |||
| repeat_penalty, alpha_frequency, alpha_presence, penalize_nl); | |||
| extraOutputs = new string[] { "\n> " }; | |||
| return true; | |||
| } | |||
| } | |||
| var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p, | |||
| tfs_z, typical_p); | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) | |||
| { | |||
| args.WaitForInput = true; | |||
| } | |||
| _last_n_tokens.Enqueue(id); | |||
| if (args.RemainedTokens <= 0 && sessionParams.ResponseTokensCount != -1) | |||
| { | |||
| args.RemainedTokens = sessionParams.ResponseTokensCount; | |||
| args.WaitForInput = true; | |||
| } | |||
| return false; | |||
| } | |||
| protected override void InferInternal(SessionParams sessionParams, InferStateArgs args) | |||
| { | |||
| if (_embeds.Count > 0) | |||
| { | |||
| _prompt_run = false; | |||
| if (_pastTokensCount + _embeds.Count > _model.ContextSize) | |||
| { | |||
| HandleRunOutOfContext(sessionParams.TokensToKeep); | |||
| } | |||
| _embeds.Add(id); | |||
| TryReuseMathingPrefix(); | |||
| _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); | |||
| n_remain--; | |||
| return_value = true; | |||
| } | |||
| else | |||
| if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | |||
| { | |||
| while (_embed_inps.Count > _consumedTokensCount) | |||
| { | |||
| _embeds.Add(_embed_inps[_consumedTokensCount]); | |||
| _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | |||
| _consumedTokensCount++; | |||
| if (_embeds.Count >= _model.Params.BatchSize) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| _session_tokens.AddRange(_embeds); | |||
| _n_session_consumed = _session_tokens.Count; | |||
| } | |||
| } | |||
| _embeds.Clear(); | |||
| if (return_value) | |||
| if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) | |||
| { | |||
| var temp = sessionParams.Temperature; | |||
| var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; | |||
var top_p = sessionParams.TopP; | |||
| var tfs_z = sessionParams.TfsZ; | |||
| var typical_p = sessionParams.TypicalP; | |||
| var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount; | |||
| var repeat_penalty = sessionParams.RepeatPenalty; | |||
| var alpha_presence = sessionParams.PresencePenalty; | |||
| var alpha_frequency = sessionParams.FrequencyPenalty; | |||
| var mirostat = sessionParams.Mirostat; | |||
| var mirostat_tau = sessionParams.MirostatTau; | |||
| var mirostat_eta = sessionParams.MirostatEta; | |||
| var penalize_nl = sessionParams.PenalizeNL; | |||
| // optionally save the session on first sample (for faster prompt loading next time) | |||
| if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) | |||
| { | |||
| foreach (var item in _model.GenerateResult(_embeds)) | |||
| { | |||
| yield return item; | |||
| } | |||
| args.NeedToSaveSession = false; | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| if (_embed_inps.Count <= _consumedTokensCount) | |||
| { | |||
| if (antiprompts is not null && antiprompts.Count() > 0) | |||
| { | |||
| string last_output = ""; | |||
| foreach (var id in _last_n_tokens) | |||
| { | |||
| last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); | |||
| } | |||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, | |||
| repeat_penalty, alpha_frequency, alpha_presence, penalize_nl); | |||
| foreach (var antiprompt in antiprompts) | |||
| { | |||
| if (last_output.EndsWith(antiprompt)) | |||
| { | |||
| wait_for_input = true; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p, | |||
| tfs_z, typical_p); | |||
| if (_pastTokensCount > 0 && wait_for_input) | |||
| { | |||
| yield return "\n> "; | |||
| break; | |||
| } | |||
| } | |||
| _last_n_tokens.Enqueue(id); | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) | |||
| { | |||
| wait_for_input = true; | |||
| } | |||
| _embeds.Add(id); | |||
| if (n_remain <= 0 && sessionParams.ResponseTokensCount != -1) | |||
| args.RemainedTokens--; | |||
| args.ReturnValue = true; | |||
| } | |||
| else | |||
| { | |||
| while (_embed_inps.Count > _consumedTokensCount) | |||
| { | |||
| n_remain = sessionParams.ResponseTokensCount; | |||
| wait_for_input = true; | |||
| _embeds.Add(_embed_inps[_consumedTokensCount]); | |||
| _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | |||
| _consumedTokensCount++; | |||
| if (_embeds.Count >= _model.Params.BatchSize) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -3,7 +3,10 @@ using LLama.Native; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| namespace LLama | |||
| { | |||
| @@ -22,43 +25,16 @@ namespace LLama | |||
| } | |||
| /// <summary> | |||
| /// process the text and return the tokens consumed. | |||
/// Decide whether the generation loop should continue. | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <param name="sessionParams"></param> | |||
| /// <param name="encoding"></param> | |||
| /// <param name="is_antiprompt"></param> | |||
| /// <returns></returns> | |||
| protected virtual int ProcessTextBeforeInfer(string text, SessionParams sessionParams) | |||
| protected override bool GetLoopCondition(InferStateArgs args) | |||
| { | |||
| if (text.Length > 1) | |||
| { | |||
| if (!text.EndsWith("\n")) | |||
| { | |||
| text += "\n"; | |||
| } | |||
| var line_inp = _model.Tokenize(text, false); | |||
| _embed_inps.AddRange(line_inp); | |||
| return line_inp.Count(); | |||
| } | |||
| else | |||
| { | |||
| return 0; | |||
| } | |||
| return args.RemainedTokens != 0 && !args.WaitForInput || _prompt_run; | |||
| } | |||
| public override IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null) | |||
| protected override void PreprocessInputs(string text, InferStateArgs args) | |||
| { | |||
| if (sessionParams is null) | |||
| { | |||
| sessionParams = new SessionParams(); | |||
| } | |||
| // if n_remain < 0, the response will be generated endlessly. | |||
| int n_remain = sessionParams.ResponseTokensCount; | |||
| bool return_value = false; | |||
| bool wait_for_input = false; | |||
| bool need_to_save_session = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count; | |||
| if (_prompt_run) | |||
| { | |||
// When running the first input (prompt) in interactive mode, it needs special processing. | |||
| @@ -67,135 +43,143 @@ namespace LLama | |||
| } | |||
| else | |||
| { | |||
| n_remain -= ProcessTextBeforeInfer(text, sessionParams); | |||
| if (!text.EndsWith("\n")) | |||
| { | |||
| text += "\n"; | |||
| } | |||
| var line_inp = _model.Tokenize(text, false); | |||
| _embed_inps.AddRange(line_inp); | |||
| args.RemainedTokens -= line_inp.Count(); | |||
| } | |||
| } | |||
| while (n_remain != 0 && !wait_for_input || _prompt_run) | |||
| /// <summary> | |||
| /// Return whether to break the generation. | |||
| /// </summary> | |||
| /// <param name="args"></param> | |||
| /// <returns></returns> | |||
| protected override bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable<string>? extraOutputs) | |||
| { | |||
| extraOutputs = null; | |||
| if (_embed_inps.Count <= _consumedTokensCount) | |||
| { | |||
| if (_embeds.Count > 0) | |||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | |||
| { | |||
| _prompt_run = false; | |||
| if (_pastTokensCount + _embeds.Count > _model.ContextSize) | |||
| string last_output = ""; | |||
| foreach (var id in _last_n_tokens) | |||
| { | |||
| HandleRunOutOfContext(sessionParams.TokensToKeep); | |||
| last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); | |||
| } | |||
| TryReuseMathingPrefix(); | |||
| _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); | |||
| if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | |||
| foreach (var antiprompt in args.Antiprompts) | |||
| { | |||
| _session_tokens.AddRange(_embeds); | |||
| _n_session_consumed = _session_tokens.Count; | |||
| if (last_output.EndsWith(antiprompt)) | |||
| { | |||
| args.WaitForInput = true; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| _embeds.Clear(); | |||
| if (_embed_inps.Count <= _consumedTokensCount && !wait_for_input) | |||
| if (_pastTokensCount > 0 && args.WaitForInput) | |||
| { | |||
| var temp = sessionParams.Temperature; | |||
| var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; | |||
var top_p = sessionParams.TopP; | |||
| var tfs_z = sessionParams.TfsZ; | |||
| var typical_p = sessionParams.TypicalP; | |||
| var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount; | |||
| var repeat_penalty = sessionParams.RepeatPenalty; | |||
| var alpha_presence = sessionParams.PresencePenalty; | |||
| var alpha_frequency = sessionParams.FrequencyPenalty; | |||
| var mirostat = sessionParams.Mirostat; | |||
| var mirostat_tau = sessionParams.MirostatTau; | |||
| var mirostat_eta = sessionParams.MirostatEta; | |||
| var penalize_nl = sessionParams.PenalizeNL; | |||
| // optionally save the session on first sample (for faster prompt loading next time) | |||
| if (!string.IsNullOrEmpty(_pathSession) && need_to_save_session) | |||
| { | |||
| need_to_save_session = false; | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, | |||
| repeat_penalty, alpha_frequency, alpha_presence, penalize_nl); | |||
| return true; | |||
| } | |||
| } | |||
| var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p, | |||
| tfs_z, typical_p); | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) | |||
| { | |||
| extraOutputs = new string[] { " [end of text]\n" }; | |||
| return true; | |||
| } | |||
| _last_n_tokens.Enqueue(id); | |||
| if (args.RemainedTokens <= 0 && sessionParams.ResponseTokensCount != -1) | |||
| { | |||
| args.RemainedTokens = sessionParams.ResponseTokensCount; | |||
| args.WaitForInput = true; | |||
| } | |||
| return false; | |||
| } | |||
| if (id == NativeApi.llama_token_eos()) | |||
| { | |||
| id = _llama_token_newline.First(); | |||
| if (antiprompts is not null && antiprompts.Count() > 0) | |||
| { | |||
| var first_antiprompt = _model.Tokenize(antiprompts.First(), false); | |||
| _embed_inps.AddRange(first_antiprompt); | |||
| } | |||
| } | |||
| protected override void InferInternal(SessionParams sessionParams, InferStateArgs args) | |||
| { | |||
| if (_embeds.Count > 0) | |||
| { | |||
| _prompt_run = false; | |||
| if (_pastTokensCount + _embeds.Count > _model.ContextSize) | |||
| { | |||
| HandleRunOutOfContext(sessionParams.TokensToKeep); | |||
| } | |||
| _embeds.Add(id); | |||
| TryReuseMathingPrefix(); | |||
| _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); | |||
| n_remain--; | |||
| return_value = true; | |||
| } | |||
| else | |||
| if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) | |||
| { | |||
| while (_embed_inps.Count > _consumedTokensCount) | |||
| { | |||
| _embeds.Add(_embed_inps[_consumedTokensCount]); | |||
| _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | |||
| _consumedTokensCount++; | |||
| if (_embeds.Count >= _model.Params.BatchSize) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| _session_tokens.AddRange(_embeds); | |||
| _n_session_consumed = _session_tokens.Count; | |||
| } | |||
| } | |||
| _embeds.Clear(); | |||
| if (return_value) | |||
| if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) | |||
| { | |||
| var temp = sessionParams.Temperature; | |||
| var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; | |||
var top_p = sessionParams.TopP; | |||
| var tfs_z = sessionParams.TfsZ; | |||
| var typical_p = sessionParams.TypicalP; | |||
| var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount; | |||
| var repeat_penalty = sessionParams.RepeatPenalty; | |||
| var alpha_presence = sessionParams.PresencePenalty; | |||
| var alpha_frequency = sessionParams.FrequencyPenalty; | |||
| var mirostat = sessionParams.Mirostat; | |||
| var mirostat_tau = sessionParams.MirostatTau; | |||
| var mirostat_eta = sessionParams.MirostatEta; | |||
| var penalize_nl = sessionParams.PenalizeNL; | |||
| // optionally save the session on first sample (for faster prompt loading next time) | |||
| if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) | |||
| { | |||
| foreach (var item in _model.GenerateResult(_embeds)) | |||
| { | |||
| yield return item; | |||
| } | |||
| args.NeedToSaveSession = false; | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| if (_embed_inps.Count <= _consumedTokensCount) | |||
| { | |||
| if (antiprompts is not null && antiprompts.Count() > 0) | |||
| { | |||
| string last_output = ""; | |||
| foreach (var id in _last_n_tokens) | |||
| { | |||
| last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); | |||
| } | |||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, | |||
| repeat_penalty, alpha_frequency, alpha_presence, penalize_nl); | |||
| foreach (var antiprompt in antiprompts) | |||
| { | |||
| if (last_output.EndsWith(antiprompt)) | |||
| { | |||
| wait_for_input = true; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p, | |||
| tfs_z, typical_p); | |||
| if (_pastTokensCount > 0 && wait_for_input) | |||
| _last_n_tokens.Enqueue(id); | |||
| if (id == NativeApi.llama_token_eos()) | |||
| { | |||
| id = _llama_token_newline.First(); | |||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | |||
| { | |||
| break; | |||
| var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false); | |||
| _embed_inps.AddRange(first_antiprompt); | |||
| } | |||
| } | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) | |||
| { | |||
| yield return " [end of text]\n"; | |||
| break; | |||
| } | |||
| _embeds.Add(id); | |||
| if (n_remain <= 0 && sessionParams.ResponseTokensCount != -1) | |||
| args.RemainedTokens--; | |||
| args.ReturnValue = true; | |||
| } | |||
| else | |||
| { | |||
| while (_embed_inps.Count > _consumedTokensCount) | |||
| { | |||
| n_remain = sessionParams.ResponseTokensCount; | |||
| wait_for_input = true; | |||
| _embeds.Add(_embed_inps[_consumedTokensCount]); | |||
| _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); | |||
| _consumedTokensCount++; | |||
| if (_embeds.Count >= _model.Params.BatchSize) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,7 +1,7 @@ | |||
| using LLama.Abstractions.Params; | |||
| using LLama.Exceptions; | |||
| using LLama.Native; | |||
| using LLama.Old; | |||
| using LLama.OldVersion; | |||
| using LLama.Types; | |||
| using LLama.Extensions; | |||
| using System; | |||
| @@ -8,7 +8,7 @@ | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| <AllowUnsafeBlocks>True</AllowUnsafeBlocks> | |||
| <Version>0.3.0</Version> | |||
| <Version>0.4.0</Version> | |||
| <Authors>Yaohui Liu, Haiping Chen</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| @@ -21,7 +21,7 @@ | |||
The .NET binding of LLama.cpp, providing APIs to run the model and deploy it on the web. For the model weights required to run it, please see https://github.com/SciSharp/LLamaSharp for more information. | |||
| </Description> | |||
| <PackageReleaseNotes> | |||
| LLamaSharp 0.3.0 supports loading and saving session state, tokenization and detokenization. Besides, since 0.3.0, `LLamaModelV1` is dropped. | |||
LLamaSharp 0.4.0 provides improved APIs compared with v0.3.0. Note that this version contains many breaking changes; the v0.3.0 APIs were moved to the LLama.OldVersion namespace. | |||
| </PackageReleaseNotes> | |||
| <PackageLicenseExpression>MIT</PackageLicenseExpression> | |||
| <PackageOutputPath>packages</PackageOutputPath> | |||
| @@ -41,6 +41,7 @@ | |||
| <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'"> | |||
| <PackageReference Include="IsExternalInit" Version="1.0.3" PrivateAssets="all" /> | |||
| <PackageReference Include="System.Memory" Version="4.5.4" PrivateAssets="all" /> | |||
| <PackageReference Include="System.Linq.Async" VersionOverride="[6.0.1, )" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -3,7 +3,7 @@ using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| namespace LLama.Old | |||
| namespace LLama.OldVersion | |||
| { | |||
| public class ChatSession<T> where T : IChatModel | |||
| { | |||
| @@ -2,7 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace LLama.Old | |||
| namespace LLama.OldVersion | |||
| { | |||
| public interface IChatModel | |||
| { | |||
| @@ -4,7 +4,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| namespace LLama.Old | |||
| namespace LLama.OldVersion | |||
| { | |||
| public class LLamaEmbedder : IDisposable | |||
| { | |||
| @@ -9,7 +9,7 @@ using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace LLama.Old | |||
| namespace LLama.OldVersion | |||
| { | |||
| using llama_token = Int32; | |||
| public class LLamaModel : IChatModel, IDisposable | |||
| @@ -1,7 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| namespace LLama.Old | |||
| namespace LLama.OldVersion | |||
| { | |||
| using llama_token = Int32; | |||
| public struct LLamaParams | |||
| @@ -2,7 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace LLama.Old | |||
| namespace LLama.OldVersion | |||
| { | |||
| public enum ChatRole | |||
| { | |||
| @@ -8,7 +8,7 @@ using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using System.IO; | |||
| namespace LLama.Old | |||
| namespace LLama.OldVersion | |||
| { | |||
| using llama_token = Int32; | |||
| internal static class Utils | |||