diff --git a/LLama.Examples/Old/ChatSession.cs b/LLama.Examples/Old/ChatSession.cs index 94b18930..185504fe 100644 --- a/LLama.Examples/Old/ChatSession.cs +++ b/LLama.Examples/Old/ChatSession.cs @@ -3,17 +3,17 @@ using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using LLama.Old; +using LLama.OldVersion; namespace LLama.Examples { public class ChatSession { - LLama.Old.ChatSession _session; + LLama.OldVersion.ChatSession _session; public ChatSession(string modelPath, string promptFilePath, string[] antiprompt) { - LLama.Old.LLamaModel model = new(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false)); - _session = new ChatSession(model) + LLama.OldVersion.LLamaModel model = new(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false)); + _session = new ChatSession(model) .WithPromptFile(promptFilePath) .WithAntiprompt(antiprompt); } diff --git a/LLama.Examples/Old/ChatWithLLamaModel.cs b/LLama.Examples/Old/ChatWithLLamaModel.cs index 87ccbb4c..452b5b2d 100644 --- a/LLama.Examples/Old/ChatWithLLamaModel.cs +++ b/LLama.Examples/Old/ChatWithLLamaModel.cs @@ -3,16 +3,16 @@ using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using LLama.Old; +using LLama.OldVersion; namespace LLama.Examples.Old { public class ChatWithLLamaModel { - LLama.Old.LLamaModel _model; + LLama.OldVersion.LLamaModel _model; public ChatWithLLamaModel(string modelPath, string promptFilePath, string[] antiprompt) { - _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: antiprompt.ToList(), + _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: antiprompt.ToList(), repeat_penalty: 1.0f)).WithPromptFile(promptFilePath); } diff --git a/LLama.Examples/Old/GetEmbeddings.cs 
b/LLama.Examples/Old/GetEmbeddings.cs index 8308bd21..a9bf56d9 100644 --- a/LLama.Examples/Old/GetEmbeddings.cs +++ b/LLama.Examples/Old/GetEmbeddings.cs @@ -3,16 +3,16 @@ using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using LLama.Old; +using LLama.OldVersion; namespace LLama.Examples { public class GetEmbeddings { - LLamaEmbedder _embedder; + LLama.OldVersion.LLamaEmbedder _embedder; public GetEmbeddings(string modelPath) { - _embedder = new LLamaEmbedder(new LLamaParams(model: modelPath)); + _embedder = new LLama.OldVersion.LLamaEmbedder(new LLamaParams(model: modelPath)); } public void Run(string text) diff --git a/LLama.Examples/Old/InstructMode.cs b/LLama.Examples/Old/InstructMode.cs index 880cee2b..2b954e3f 100644 --- a/LLama.Examples/Old/InstructMode.cs +++ b/LLama.Examples/Old/InstructMode.cs @@ -3,16 +3,16 @@ using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using LLama.Old; +using LLama.OldVersion; namespace LLama.Examples.Old { public class InstructMode { - LLama.Old.LLamaModel _model; + LLama.OldVersion.LLamaModel _model; public InstructMode(string modelPath, string promptFile) { - _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, + _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPromptFile(promptFile); } diff --git a/LLama.Examples/Old/SaveAndLoadState.cs b/LLama.Examples/Old/SaveAndLoadState.cs index fab3d234..e566a1ec 100644 --- a/LLama.Examples/Old/SaveAndLoadState.cs +++ b/LLama.Examples/Old/SaveAndLoadState.cs @@ -3,16 +3,16 @@ using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using LLama.Old; +using LLama.OldVersion; namespace LLama.Examples { public class SaveAndLoadState: 
IDisposable { - LLama.Old.LLamaModel _model; + LLama.OldVersion.LLamaModel _model; public SaveAndLoadState(string modelPath, string prompt) { - _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, + _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true, repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPrompt(prompt); } diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs index d6ebd644..819ed38c 100644 --- a/LLama.Examples/Program.cs +++ b/LLama.Examples/Program.cs @@ -26,12 +26,12 @@ if(version == 1) Console.WriteLine("The examples for new versions are under working now. We'll soon update the examples." + " Thank you for your support!"); string modelPath = "D:\\development\\llama\\weights\\wizard-vicuna-13B.ggmlv3.q4_1.bin"; - var prompt = File.ReadAllText("Assets/dan.txt").Trim(); - LLamaInstructExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024))); + var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + LLamaInteractExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337))); while (true) { - foreach (var text in ex.Infer(prompt, new SessionParams() { Temperature = 0.6f })) + await foreach (var text in ex.InferAsync(prompt, new SessionParams() { Temperature = 0.6f, AntiPrompts = new List{ "user:" } }, default(CancellationToken))) { Console.Write(text); } diff --git a/LLama.WebAPI/Services/ChatService.cs b/LLama.WebAPI/Services/ChatService.cs index bdedec79..e457e3c2 100644 --- a/LLama.WebAPI/Services/ChatService.cs +++ b/LLama.WebAPI/Services/ChatService.cs @@ -1,4 +1,4 @@ -using LLama.Old; +using LLama.OldVersion; using LLama.WebAPI.Models; namespace LLama.WebAPI.Services; diff --git a/LLama/Abstractions/Params/SessionParams.cs b/LLama/Abstractions/Params/SessionParams.cs index 41a28c21..a18b3b12 100644 --- 
a/LLama/Abstractions/Params/SessionParams.cs +++ b/LLama/Abstractions/Params/SessionParams.cs @@ -20,6 +20,11 @@ namespace LLama.Abstractions.Params /// logit bias for specific tokens /// public Dictionary? LogitBias { get; set; } = null; + + /// + /// Sequences where the model will stop generating further tokens. + /// + public IList AntiPrompts { get; set; } = Array.Empty(); /// /// path to file for saving/loading model eval state /// diff --git a/LLama/ILLamaExecutor.cs b/LLama/ILLamaExecutor.cs index 4e773637..89f9e45e 100644 --- a/LLama/ILLamaExecutor.cs +++ b/LLama/ILLamaExecutor.cs @@ -2,11 +2,14 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading; namespace LLama { public interface ILLamaExecutor { - IEnumerable Infer(string text, SessionParams? sessionParams = null, IEnumerable? antiprompts = null); + IEnumerable Infer(string text, SessionParams? sessionParams = null); + + IAsyncEnumerable InferAsync(string text, SessionParams? sessionParams = null, CancellationToken token = default); } } diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs new file mode 100644 index 00000000..61b03fe9 --- /dev/null +++ b/LLama/LLamaEmbedder.cs @@ -0,0 +1,80 @@ +using LLama.Native; +using System; +using System.Collections.Generic; +using System.Text; +using LLama.Exceptions; +using LLama.Abstractions.Params; +using System.Linq; + +namespace LLama +{ + public class LLamaEmbedder : IDisposable + { + SafeLLamaContextHandle _ctx; + + /// + /// Warning: must ensure the original model has params.embedding = true; + /// + /// + internal LLamaEmbedder(SafeLLamaContextHandle ctx) + { + _ctx = ctx; + } + + public LLamaEmbedder(ModelParams @params) + { + @params.EmbeddingMode = true; + _ctx = Utils.InitLLamaContextFromModelParams(@params); + } + + /// + /// Get the embeddings of the text. + /// + /// + /// Threads used for inference. + /// Add bos to the text. 
+ /// + /// + /// + public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") + { + if (threads == -1) + { + threads = Math.Max(Environment.ProcessorCount / 2, 1); + } + int n_past = 0; + if (addBos) + { + text = text.Insert(0, " "); + } + var embed_inp = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)); + + // TODO(Rinne): deal with log of prompt + + if (embed_inp.Count() > 0) + { + var embed_inp_array = embed_inp.ToArray(); + if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, threads) != 0) + { + throw new RuntimeError("Failed to eval."); + } + } + + int n_embed = NativeApi.llama_n_embd(_ctx); + var embeddings = NativeApi.llama_get_embeddings(_ctx); + if (embeddings == null) + { + return new float[0]; + } + var span = new Span(embeddings, n_embed); + float[] res = new float[n_embed]; + span.CopyTo(res.AsSpan()); + return res; + } + + public void Dispose() + { + _ctx.Dispose(); + } + } +} diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 61d9a192..57763cdc 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -6,7 +6,10 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace LLama { @@ -106,6 +109,121 @@ namespace LLama } } - public abstract IEnumerable Infer(string text, SessionParams? sessionParams = null, IEnumerable? antiprompts = null); + protected abstract bool GetLoopCondition(InferStateArgs args); + protected abstract void PreprocessInputs(string text, InferStateArgs args); + protected abstract bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable? extraOutputs); + protected abstract void InferInternal(SessionParams sessionParams, InferStateArgs args); + public virtual IEnumerable Infer(string text, SessionParams? 
sessionParams = null) + { + if (sessionParams is null) + { + sessionParams = new SessionParams(); + } + + InferStateArgs args = new InferStateArgs() + { + Antiprompts = sessionParams.AntiPrompts, + RemainedTokens = sessionParams.ResponseTokensCount, + ReturnValue = false, + WaitForInput = false, + NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count + }; + + PreprocessInputs(text, args); + + while (GetLoopCondition(args)) + { + InferInternal(sessionParams, args); + + if (args.ReturnValue) + { + foreach (var item in _model.GenerateResult(_embeds)) + { + yield return item; + } + } + + var breakGeneration = PostProcess(sessionParams, args, out var extraOutputs); + if (extraOutputs is not null) + { + foreach (var item in extraOutputs) + { + yield return item; + } + } + if (breakGeneration) + { + break; + } + } + } + public virtual async IAsyncEnumerable InferAsync(string text, SessionParams? sessionParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + // make this delay only to make the async method consistent with what it's expected to be + //await Task.Delay(1); + + if (sessionParams is null) + { + sessionParams = new SessionParams(); + } + + InferStateArgs args = new InferStateArgs() + { + Antiprompts = sessionParams.AntiPrompts, + RemainedTokens = sessionParams.ResponseTokensCount, + ReturnValue = false, + WaitForInput = false, + NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count + }; + + PreprocessInputs(text, args); + + while (GetLoopCondition(args)) + { + if (cancellationToken.IsCancellationRequested) + { + break; + } + + InferInternal(sessionParams, args); + + if (args.ReturnValue) + { + foreach (var item in _model.GenerateResult(_embeds)) + { + yield return item; + } + } + + var breakGeneration = PostProcess(sessionParams, args, out var extraOutputs); + if 
(extraOutputs is not null) + { + foreach (var item in extraOutputs) + { + yield return item; + } + } + if (breakGeneration) + { + break; + } + } + } + + /// + /// State arguments that are used in single inference + /// + protected class InferStateArgs + { + public IList? Antiprompts { get; set; } + /// + /// Number of tokens remaining to be used. (n_remain) + /// + public int RemainedTokens { get; set; } + public bool ReturnValue { get; set; } + public bool WaitForInput { get; set; } + public bool NeedToSaveSession { get; set; } + } } } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 220c2182..7819a5ee 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -11,28 +11,28 @@ namespace LLama public class LLamaInstructExecutor : LLamaExecutorBase { bool _prompt_run = true; - readonly IEnumerable _llama_token_newline; readonly IEnumerable _inp_pfx; readonly IEnumerable _inp_sfx; public LLamaInstructExecutor(LLamaModel model, string inputPrefix = "\n\n### Instruction:\n\n", string inputSuffix = "\n\n### Response:\n\n") : base(model) { - _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding); _inp_pfx = _model.Tokenize(inputPrefix, true); _inp_sfx = _model.Tokenize(inputSuffix, false); } - /// - /// process the text and return the tokens consumed. - /// - /// - /// - /// - /// - /// - protected virtual int ProcessTextBeforeInfer(string text, SessionParams sessionParams) + protected override bool GetLoopCondition(InferStateArgs args) { - if (text.Length > 1) + return args.RemainedTokens != 0 || _prompt_run; + } + protected override void PreprocessInputs(string text, InferStateArgs args) + { + if (_prompt_run) + { + // When running the first input (prompt) in interactive mode, we should specially process it. 
+ text = " " + text; + _embed_inps = _model.Tokenize(text, true).ToList(); + } + else { if (!text.EndsWith("\n")) { @@ -46,153 +46,120 @@ namespace LLama _embed_inps.AddRange(_inp_sfx); - return line_inp.Count(); - } - else - { - return 0; + args.RemainedTokens -= line_inp.Count(); } } - - public override IEnumerable Infer(string text, SessionParams? sessionParams = null, IEnumerable? antiprompts = null) + protected override bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable? extraOutputs) { - if (sessionParams is null) + extraOutputs = null; + if (_embed_inps.Count <= _consumedTokensCount) { - sessionParams = new SessionParams(); - } - // if n_remain < 0, the response will be generated endlessly. - int n_remain = sessionParams.ResponseTokensCount; - bool return_value = false; - bool wait_for_input = false; - bool need_to_save_session = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count; - - if (_prompt_run) - { - // When running the first input (prompt) in inteactive mode, we should specially process it. 
- text = " " + text; - _embed_inps = _model.Tokenize(text, true).ToList(); - } - else - { - n_remain -= ProcessTextBeforeInfer(text, sessionParams); - } - - while (n_remain != 0 || _prompt_run) - { - if (_embeds.Count > 0) + if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - _prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + string last_output = ""; + foreach (var id in _last_n_tokens) { - HandleRunOutOfContext(sessionParams.TokensToKeep); + last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); } - TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); - - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + foreach (var antiprompt in args.Antiprompts) { - _session_tokens.AddRange(_embeds); - _n_session_consumed = _session_tokens.Count; + if (last_output.EndsWith(antiprompt)) + { + args.WaitForInput = true; + return true; + } } } - _embeds.Clear(); - - if (_embed_inps.Count <= _consumedTokensCount && !wait_for_input) + if (_pastTokensCount > 0 && args.WaitForInput) { - var temp = sessionParams.Temperature; - var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; - var top_p = sessionParams.TopK; - var tfs_z = sessionParams.TfsZ; - var typical_p = sessionParams.TypicalP; - var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? 
_model.ContextSize : sessionParams.RepeatLastTokensCount; - var repeat_penalty = sessionParams.RepeatPenalty; - var alpha_presence = sessionParams.PresencePenalty; - var alpha_frequency = sessionParams.FrequencyPenalty; - var mirostat = sessionParams.Mirostat; - var mirostat_tau = sessionParams.MirostatTau; - var mirostat_eta = sessionParams.MirostatEta; - var penalize_nl = sessionParams.PenalizeNL; - - // optionally save the session on first sample (for faster prompt loading next time) - if (!string.IsNullOrEmpty(_pathSession) && need_to_save_session) - { - need_to_save_session = false; - SaveSessionFile(_pathSession); - } - - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, - repeat_penalty, alpha_frequency, alpha_presence, penalize_nl); + extraOutputs = new string[] { "\n> " }; + return true; + } + } - var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p, - tfs_z, typical_p); + if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) + { + args.WaitForInput = true; + } - _last_n_tokens.Enqueue(id); + if (args.RemainedTokens <= 0 && sessionParams.ResponseTokensCount != -1) + { + args.RemainedTokens = sessionParams.ResponseTokensCount; + args.WaitForInput = true; + } + return false; + } + protected override void InferInternal(SessionParams sessionParams, InferStateArgs args) + { + if (_embeds.Count > 0) + { + _prompt_run = false; + if (_pastTokensCount + _embeds.Count > _model.ContextSize) + { + HandleRunOutOfContext(sessionParams.TokensToKeep); + } - _embeds.Add(id); + TryReuseMathingPrefix(); + _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); - n_remain--; - return_value = true; - } - else + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { - while (_embed_inps.Count > _consumedTokensCount) - { - _embeds.Add(_embed_inps[_consumedTokensCount]); - _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); - 
_consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) - { - break; - } - } + _session_tokens.AddRange(_embeds); + _n_session_consumed = _session_tokens.Count; } + } + + _embeds.Clear(); - if (return_value) + if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) + { + var temp = sessionParams.Temperature; + var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; + var top_p = sessionParams.TopK; + var tfs_z = sessionParams.TfsZ; + var typical_p = sessionParams.TypicalP; + var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount; + var repeat_penalty = sessionParams.RepeatPenalty; + var alpha_presence = sessionParams.PresencePenalty; + var alpha_frequency = sessionParams.FrequencyPenalty; + var mirostat = sessionParams.Mirostat; + var mirostat_tau = sessionParams.MirostatTau; + var mirostat_eta = sessionParams.MirostatEta; + var penalize_nl = sessionParams.PenalizeNL; + + // optionally save the session on first sample (for faster prompt loading next time) + if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { - foreach (var item in _model.GenerateResult(_embeds)) - { - yield return item; - } + args.NeedToSaveSession = false; + SaveSessionFile(_pathSession); } - if (_embed_inps.Count <= _consumedTokensCount) - { - if (antiprompts is not null && antiprompts.Count() > 0) - { - string last_output = ""; - foreach (var id in _last_n_tokens) - { - last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); - } + var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, + repeat_penalty, alpha_frequency, alpha_presence, penalize_nl); - foreach (var antiprompt in antiprompts) - { - if (last_output.EndsWith(antiprompt)) - { - wait_for_input = true; - break; - } - } - } + var id = _model.Sample(tokenDataArray, temp, mirostat, 
mirostat_tau, mirostat_eta, top_k, top_p, + tfs_z, typical_p); - if (_pastTokensCount > 0 && wait_for_input) - { - yield return "\n> "; - break; - } - } + _last_n_tokens.Enqueue(id); - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) - { - wait_for_input = true; - } + _embeds.Add(id); - if (n_remain <= 0 && sessionParams.ResponseTokensCount != -1) + args.RemainedTokens--; + args.ReturnValue = true; + } + else + { + while (_embed_inps.Count > _consumedTokensCount) { - n_remain = sessionParams.ResponseTokensCount; - wait_for_input = true; + _embeds.Add(_embed_inps[_consumedTokensCount]); + _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); + _consumedTokensCount++; + if (_embeds.Count >= _model.Params.BatchSize) + { + break; + } } } } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index e2671467..8b30edd3 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -3,7 +3,10 @@ using LLama.Native; using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace LLama { @@ -22,43 +25,16 @@ namespace LLama } /// - /// process the text and return the tokens consumed. + /// Define whether to continue the loop to generate responses. /// - /// - /// - /// - /// /// - protected virtual int ProcessTextBeforeInfer(string text, SessionParams sessionParams) + protected override bool GetLoopCondition(InferStateArgs args) { - if (text.Length > 1) - { - if (!text.EndsWith("\n")) - { - text += "\n"; - } - var line_inp = _model.Tokenize(text, false); - _embed_inps.AddRange(line_inp); - return line_inp.Count(); - } - else - { - return 0; - } + return args.RemainedTokens != 0 && !args.WaitForInput || _prompt_run; } - public override IEnumerable Infer(string text, SessionParams? sessionParams = null, IEnumerable? 
antiprompts = null) + protected override void PreprocessInputs(string text, InferStateArgs args) { - if (sessionParams is null) - { - sessionParams = new SessionParams(); - } - // if n_remain < 0, the response will be generated endlessly. - int n_remain = sessionParams.ResponseTokensCount; - bool return_value = false; - bool wait_for_input = false; - bool need_to_save_session = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count; - if (_prompt_run) { // When running the first input (prompt) in inteactive mode, we should specially process it. @@ -67,135 +43,143 @@ namespace LLama } else { - n_remain -= ProcessTextBeforeInfer(text, sessionParams); + if (!text.EndsWith("\n")) + { + text += "\n"; + } + var line_inp = _model.Tokenize(text, false); + _embed_inps.AddRange(line_inp); + args.RemainedTokens -= line_inp.Count(); } + } - while (n_remain != 0 && !wait_for_input || _prompt_run) + /// + /// Return whether to break the generation. + /// + /// + /// + protected override bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable? 
extraOutputs) + { + extraOutputs = null; + if (_embed_inps.Count <= _consumedTokensCount) { - if (_embeds.Count > 0) + if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - _prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + string last_output = ""; + foreach (var id in _last_n_tokens) { - HandleRunOutOfContext(sessionParams.TokensToKeep); + last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); } - TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); - - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + foreach (var antiprompt in args.Antiprompts) { - _session_tokens.AddRange(_embeds); - _n_session_consumed = _session_tokens.Count; + if (last_output.EndsWith(antiprompt)) + { + args.WaitForInput = true; + break; + } } } - _embeds.Clear(); - - if (_embed_inps.Count <= _consumedTokensCount && !wait_for_input) + if (_pastTokensCount > 0 && args.WaitForInput) { - var temp = sessionParams.Temperature; - var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; - var top_p = sessionParams.TopK; - var tfs_z = sessionParams.TfsZ; - var typical_p = sessionParams.TypicalP; - var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? 
_model.ContextSize : sessionParams.RepeatLastTokensCount; - var repeat_penalty = sessionParams.RepeatPenalty; - var alpha_presence = sessionParams.PresencePenalty; - var alpha_frequency = sessionParams.FrequencyPenalty; - var mirostat = sessionParams.Mirostat; - var mirostat_tau = sessionParams.MirostatTau; - var mirostat_eta = sessionParams.MirostatEta; - var penalize_nl = sessionParams.PenalizeNL; - - // optionally save the session on first sample (for faster prompt loading next time) - if (!string.IsNullOrEmpty(_pathSession) && need_to_save_session) - { - need_to_save_session = false; - SaveSessionFile(_pathSession); - } - - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, - repeat_penalty, alpha_frequency, alpha_presence, penalize_nl); + return true; + } + } - var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p, - tfs_z, typical_p); + if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) + { + extraOutputs = new string[] { " [end of text]\n" }; + return true; + } - _last_n_tokens.Enqueue(id); + if (args.RemainedTokens <= 0 && sessionParams.ResponseTokensCount != -1) + { + args.RemainedTokens = sessionParams.ResponseTokensCount; + args.WaitForInput = true; + } + return false; + } - if (id == NativeApi.llama_token_eos()) - { - id = _llama_token_newline.First(); - if (antiprompts is not null && antiprompts.Count() > 0) - { - var first_antiprompt = _model.Tokenize(antiprompts.First(), false); - _embed_inps.AddRange(first_antiprompt); - } - } + protected override void InferInternal(SessionParams sessionParams, InferStateArgs args) + { + if (_embeds.Count > 0) + { + _prompt_run = false; + if (_pastTokensCount + _embeds.Count > _model.ContextSize) + { + HandleRunOutOfContext(sessionParams.TokensToKeep); + } - _embeds.Add(id); + TryReuseMathingPrefix(); + _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); - n_remain--; - return_value = true; - 
} - else + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { - while (_embed_inps.Count > _consumedTokensCount) - { - _embeds.Add(_embed_inps[_consumedTokensCount]); - _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); - _consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) - { - break; - } - } + _session_tokens.AddRange(_embeds); + _n_session_consumed = _session_tokens.Count; } + } + + _embeds.Clear(); - if (return_value) + if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) + { + var temp = sessionParams.Temperature; + var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK; + var top_p = sessionParams.TopK; + var tfs_z = sessionParams.TfsZ; + var typical_p = sessionParams.TypicalP; + var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount; + var repeat_penalty = sessionParams.RepeatPenalty; + var alpha_presence = sessionParams.PresencePenalty; + var alpha_frequency = sessionParams.FrequencyPenalty; + var mirostat = sessionParams.Mirostat; + var mirostat_tau = sessionParams.MirostatTau; + var mirostat_eta = sessionParams.MirostatEta; + var penalize_nl = sessionParams.PenalizeNL; + + // optionally save the session on first sample (for faster prompt loading next time) + if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { - foreach (var item in _model.GenerateResult(_embeds)) - { - yield return item; - } + args.NeedToSaveSession = false; + SaveSessionFile(_pathSession); } - if (_embed_inps.Count <= _consumedTokensCount) - { - if (antiprompts is not null && antiprompts.Count() > 0) - { - string last_output = ""; - foreach (var id in _last_n_tokens) - { - last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); - } + var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n, + repeat_penalty, 
alpha_frequency, alpha_presence, penalize_nl); - foreach (var antiprompt in antiprompts) - { - if (last_output.EndsWith(antiprompt)) - { - wait_for_input = true; - break; - } - } - } + var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p, + tfs_z, typical_p); - if (_pastTokensCount > 0 && wait_for_input) + _last_n_tokens.Enqueue(id); + + if (id == NativeApi.llama_token_eos()) + { + id = _llama_token_newline.First(); + if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - break; + var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false); + _embed_inps.AddRange(first_antiprompt); } } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos()) - { - yield return " [end of text]\n"; - break; - } + _embeds.Add(id); - if (n_remain <= 0 && sessionParams.ResponseTokensCount != -1) + args.RemainedTokens--; + args.ReturnValue = true; + } + else + { + while (_embed_inps.Count > _consumedTokensCount) { - n_remain = sessionParams.ResponseTokensCount; - wait_for_input = true; + _embeds.Add(_embed_inps[_consumedTokensCount]); + _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); + _consumedTokensCount++; + if (_embeds.Count >= _model.Params.BatchSize) + { + break; + } } } } diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs index 6c7219e5..221e5ce2 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaModel.cs @@ -1,7 +1,7 @@ using LLama.Abstractions.Params; using LLama.Exceptions; using LLama.Native; -using LLama.Old; +using LLama.OldVersion; using LLama.Types; using LLama.Extensions; using System; diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index 3d24cfe8..8a5035ee 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -8,7 +8,7 @@ AnyCPU;x64 True - 0.3.0 + 0.4.0 Yaohui Liu, Haiping Chen SciSharp STACK true @@ -21,7 +21,7 @@ The .NET binding of LLama.cpp, providing APIs to run the model and deploy it on Web. 
For model weights to run, please go to https://github.com/SciSharp/LLamaSharp for more information. - LLamaSharp 0.3.0 supports loading and saving session state, tokenization and detokenization. Besides, since 0.3.0, `LLamaModelV1` is dropped. + LLamaSharp 0.4.0 supports better APIs than v0.3.0. Note that many breaking changes were made in this version. APIs of v0.3.0 were moved to the LLama.OldVersion namespace. MIT packages @@ -41,6 +41,7 @@ + diff --git a/LLama/Old/ChatSession.cs b/LLama/OldVersion/ChatSession.cs similarity index 98% rename from LLama/Old/ChatSession.cs rename to LLama/OldVersion/ChatSession.cs index c1b4ca2d..d6e9bfc6 100644 --- a/LLama/Old/ChatSession.cs +++ b/LLama/OldVersion/ChatSession.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.IO; using System.Text; -namespace LLama.Old +namespace LLama.OldVersion { public class ChatSession where T : IChatModel { diff --git a/LLama/Old/IChatModel.cs b/LLama/OldVersion/IChatModel.cs similarity index 95% rename from LLama/Old/IChatModel.cs rename to LLama/OldVersion/IChatModel.cs index 3324292a..7fbd898b 100644 --- a/LLama/Old/IChatModel.cs +++ b/LLama/OldVersion/IChatModel.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace LLama.Old +namespace LLama.OldVersion { public interface IChatModel { diff --git a/LLama/Old/LLamaEmbedder.cs b/LLama/OldVersion/LLamaEmbedder.cs similarity index 98% rename from LLama/Old/LLamaEmbedder.cs rename to LLama/OldVersion/LLamaEmbedder.cs index 9de2f58e..823c4437 100644 --- a/LLama/Old/LLamaEmbedder.cs +++ b/LLama/OldVersion/LLamaEmbedder.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; using System.Text; using LLama.Exceptions; -namespace LLama.Old +namespace LLama.OldVersion { public class LLamaEmbedder : IDisposable { diff --git a/LLama/Old/LLamaModel.cs b/LLama/OldVersion/LLamaModel.cs similarity index 99% rename from LLama/Old/LLamaModel.cs rename to LLama/OldVersion/LLamaModel.cs index 0c9488bb..55ecc843 100644 --- 
a/LLama/Old/LLamaModel.cs +++ b/LLama/OldVersion/LLamaModel.cs @@ -9,7 +9,7 @@ using System.IO; using System.Linq; using System.Text; -namespace LLama.Old +namespace LLama.OldVersion { using llama_token = Int32; public class LLamaModel : IChatModel, IDisposable diff --git a/LLama/Old/LLamaParams.cs b/LLama/OldVersion/LLamaParams.cs similarity index 99% rename from LLama/Old/LLamaParams.cs rename to LLama/OldVersion/LLamaParams.cs index 7c349c83..a2d677d8 100644 --- a/LLama/Old/LLamaParams.cs +++ b/LLama/OldVersion/LLamaParams.cs @@ -1,7 +1,7 @@ using System; using System.Collections.Generic; -namespace LLama.Old +namespace LLama.OldVersion { using llama_token = Int32; public struct LLamaParams diff --git a/LLama/Old/LLamaTypes.cs b/LLama/OldVersion/LLamaTypes.cs similarity index 98% rename from LLama/Old/LLamaTypes.cs rename to LLama/OldVersion/LLamaTypes.cs index ae1d4af7..d0bd4ad7 100644 --- a/LLama/Old/LLamaTypes.cs +++ b/LLama/OldVersion/LLamaTypes.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace LLama.Old +namespace LLama.OldVersion { public enum ChatRole { diff --git a/LLama/Old/Utils.cs b/LLama/OldVersion/Utils.cs similarity index 99% rename from LLama/Old/Utils.cs rename to LLama/OldVersion/Utils.cs index f7203802..4916a20d 100644 --- a/LLama/Old/Utils.cs +++ b/LLama/OldVersion/Utils.cs @@ -8,7 +8,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.IO; -namespace LLama.Old +namespace LLama.OldVersion { using llama_token = Int32; internal static class Utils