diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index e1e89414..6156ade8 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -1,7 +1,9 @@
 using System.Collections.Generic;
 using LLama.Common;
+using LLama.Control;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Transform;
 
 namespace LLama.Abstractions
 {
@@ -114,5 +116,15 @@ namespace LLama.Abstractions
         /// Set a custom sampling pipeline to use. If this is set All other sampling parameters are ignored!
         /// </summary>
         ISamplingPipeline? SamplingPipeline { get; set; }
+
+        /// <summary>
+        /// Set a custom generation control to use. If this is set, the antiprompt will be ignored!
+        /// </summary>
+        IGenerationControl GenerationControl { get; set; }
+
+        /// <summary>
+        /// Set a custom tokenizer to use.
+        /// </summary>
+        ITokenizer Tokenizer { get; set; }
     }
 }
\ No newline at end of file
diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs
index ef5453a7..4fd8600d 100644
--- a/LLama/Abstractions/ILLamaExecutor.cs
+++ b/LLama/Abstractions/ILLamaExecutor.cs
@@ -18,8 +18,8 @@ namespace LLama.Abstractions
         /// </summary>
         /// <param name="text">Your prompt</param>
         /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
+        /// <param name="cancellationToken">A cancellation token.</param>
         /// <returns></returns>
-        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
+        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default);
     }
 }
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index c1f39550..a47408f1 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -3,6 +3,9 @@ using System;
 using System.Collections.Generic;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Control;
+using LLama.Transform;
+using System.Text;
 
 namespace LLama.Common
 {
@@ -80,6 +83,12 @@ namespace LLama.Common
         /// <inheritdoc />
         public ISamplingPipeline? SamplingPipeline { get; set; }
+
+        /// <inheritdoc />
+        public IGenerationControl GenerationControl { get; set; } = new DefaultGenerationControl();
+
+        /// <inheritdoc />
+        public ITokenizer Tokenizer { get; set; } = new DefaultTokenizer(Encoding.UTF8);
     }
 
     /// <summary>
diff --git a/LLama/AntipromptProcessor.cs b/LLama/Control/AntipromptProcessor.cs
similarity index 99%
rename from LLama/AntipromptProcessor.cs
rename to LLama/Control/AntipromptProcessor.cs
index c18c0915..9b5ff987 100644
--- a/LLama/AntipromptProcessor.cs
+++ b/LLama/Control/AntipromptProcessor.cs
@@ -1,7 +1,7 @@
 using System;
 using System.Collections.Generic;
 
-namespace LLama
+namespace LLama.Control
 {
     /// <summary>
     /// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
diff --git a/LLama/Control/DefaultGenerationControl.cs b/LLama/Control/DefaultGenerationControl.cs
new file mode 100644
index 00000000..fe8f7d94
--- /dev/null
+++ b/LLama/Control/DefaultGenerationControl.cs
@@ -0,0 +1,42 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
+    /// Note that this class has state. The previous outputs fed to it will affect its control.
+    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+    /// </summary>
+    public sealed class DefaultGenerationControl: IGenerationControl
+    {
+        private AntipromptProcessor _antipromptProcessor;
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public DefaultGenerationControl()
+        {
+            _antipromptProcessor = new AntipromptProcessor();
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
+        {
+            _antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
+            return _antipromptProcessor.Add(lastOutputText);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
+        {
+            return false;
+        }
+    }
+}
diff --git a/LLama/Control/IGenerationControl.cs b/LLama/Control/IGenerationControl.cs
new file mode 100644
index 00000000..3e01d284
--- /dev/null
+++ b/LLama/Control/IGenerationControl.cs
@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// Control the text generation of LLama Executors.
+    /// </summary>
+    public interface IGenerationControl
+    {
+        /// <summary>
+        /// Use the last output text to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputText"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);
+
+        /// <summary>
+        /// Use the last output ids to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputIds"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
+    }
+}
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index e0fde1ed..7b7ccd24 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -2,6 +2,7 @@
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
+using LLama.Transform;
 using Microsoft.Extensions.Logging;
 using System;
 using System.Collections.Generic;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 831aceb2..75101674 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -8,6 +8,7 @@ using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Control;
 using Microsoft.Extensions.Logging;
 namespace LLama
 {
@@ -49,7 +50,7 @@ namespace LLama
         /// <inheritdoc />
         public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            // Ensure the context from last time is disposed (it always hould be)
+            // Ensure the context from last time is disposed (it always should be)
             if (!Context.NativeHandle.IsClosed)
                 Context.Dispose();
 
@@ -57,48 +58,53 @@ namespace LLama
             using var context = _weights.CreateContext(_params, _logger);
             Context = context;
 
+            await foreach(var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
+            {
+                yield return item;
+            }
+        }
+
+        public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            // Sanity check inference params
             inferenceParams ??= new InferenceParams();
-            if (inferenceParams.TokensKeep > Context.ContextSize)
-                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
-
-            // Create decoders for the token stream
-            var decoder = new StreamingTokenDecoder(Context);
-            var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
+            if (inferenceParams.TokensKeep > context.ContextSize)
+                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");
 
             // Keep track of the last N tokens emitted
-            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
+            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
             var lastTokens = new List<llama_token>(repeat_last_n);
             for (var i = 0; i < repeat_last_n; i++)
                 lastTokens.Add(0);
 
             // Tokenize the prompt
-            var tokens = Context.Tokenize(prompt).ToList();
+            var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
             lastTokens.AddRange(tokens);
             var n_past = 1 + tokens.Count;
 
             // Evaluate the prompt
-            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
+            await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
                 .ConfigureAwait(false);
 
             // Begin loop, evaluating one token at a time
             var mu = (float?)null;
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
-            for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
+            for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
                 llama_token id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                    id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
                 }
                 else
                 {
                     // Penalize the generated tokens by various penalties
-                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     // Sample a single token
-                    id = Context.Sample(
+                    id = context.Sample(
                         tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                         inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
                         inferenceParams.MinP
@@ -106,12 +112,11 @@
                 }
 
                 // Decode this token into text
-                decoder.Add(id);
-                var decoded = decoder.Read();
+                var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
                 yield return decoded;
 
-                // Check if any of the antiprompts have been generated
-                if (antiprocessor.Add(decoded))
+                // Check if the generation should stop
+                if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
                     break;
 
                 lastTokens.Add(id);
@@ -120,19 +125,19 @@
 
                 // when run out of context
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
-                if (n_past + tokens.Count >= Context.ContextSize)
+                if (n_past + tokens.Count >= context.ContextSize)
                 {
                     var n_left = n_past - inferenceParams.TokensKeep - 1;
                     var n_discard = n_left / 2;
 
-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+                    NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
                 }
 
                 // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
+                n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
                     .ConfigureAwait(false);
             }
         }
diff --git a/LLama/TextCompletion.cs b/LLama/TextCompletion.cs
new file mode 100644
index 00000000..59b6ba98
--- /dev/null
+++ b/LLama/TextCompletion.cs
@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using LLama.Common;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+
+namespace LLama
+{
+    /// <summary>
+    /// A class to execute text completion tasks.
+    /// </summary>
+    public class TextCompletion
+    {
+        public string Execute(string prompt, IInferenceParams? inferenceParams = null)
+        {
+            throw new NotImplementedException();
+        }
+
+        public ChatHistory Execute(ChatHistory prompt, IInferenceParams? inferenceParams = null)
+        {
+            throw new NotImplementedException();
+        }
+
+        public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            throw new NotImplementedException();
+        }
+    }
+}
diff --git a/LLama/Transform/DefaultTokenizer.cs b/LLama/Transform/DefaultTokenizer.cs
new file mode 100644
index 00000000..451c1d85
--- /dev/null
+++ b/LLama/Transform/DefaultTokenizer.cs
@@ -0,0 +1,53 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+    /// <summary>
+    /// The default tokenizer of LLamaSharp. This class should not be inherited.
+    /// Note that this class has state. The previous outputs fed to it will affect its output.
+    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+    /// </summary>
+    public sealed class DefaultTokenizer: ITokenizer
+    {
+        private Encoding _encoding;
+        private StreamingTokenDecoder _tokenDecoder;
+
+        /// <summary>
+        /// Initialize a new tokenizer with the specified encoding.
+        /// </summary>
+        /// <param name="encoding"></param>
+        public DefaultTokenizer(Encoding encoding)
+        {
+            _encoding = encoding;
+            _tokenDecoder = new StreamingTokenDecoder(encoding);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
+        {
+            return context.Tokenize(text, addBos, special);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public string Detokenize(LLamaContext context, int token)
+        {
+            _tokenDecoder.Add(token, context.NativeHandle.ModelHandle);
+            return _tokenDecoder.Read();
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
+        {
+            _tokenDecoder.AddRange(tokens, context.NativeHandle.ModelHandle);
+            return _tokenDecoder.Read();
+        }
+    }
+}
diff --git a/LLama/Transform/ITokenizer.cs b/LLama/Transform/ITokenizer.cs
new file mode 100644
index 00000000..3df9fc46
--- /dev/null
+++ b/LLama/Transform/ITokenizer.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Transform
+{
+    public interface ITokenizer
+    {
+        IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false);
+
+        string Detokenize(LLamaContext context, int token);
+
+        string Detokenize(LLamaContext context, IEnumerable<int> tokens);
+    }
+}
diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/Transform/StreamingTokenDecoder.cs
similarity index 90%
rename from LLama/StreamingTokenDecoder.cs
rename to LLama/Transform/StreamingTokenDecoder.cs
index f82f8c37..0653bd4d 100644
--- a/LLama/StreamingTokenDecoder.cs
+++ b/LLama/Transform/StreamingTokenDecoder.cs
@@ -6,14 +6,14 @@
 using System.Text;
 using LLama.Extensions;
 using LLama.Native;
 
-namespace LLama
+namespace LLama.Transform
 {
     /// <summary>
     /// Decodes a stream of tokens into a stream of characters
     /// </summary>
     public sealed class StreamingTokenDecoder
     {
-        private readonly SafeLlamaModelHandle _weights;
+        private readonly SafeLlamaModelHandle? _weights;
         private readonly Decoder _decoder;
         private readonly List<char> _characters = new();
@@ -29,8 +29,8 @@
         /// </summary>
         /// <param name="encoding">Text encoding to use</param>
         /// <param name="weights">Model weights</param>
-        public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
-            : this(encoding, weights.NativeHandle)
+        public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
+            : this(encoding, weights?.NativeHandle)
         {
         }
 
@@ -69,14 +69,19 @@
         /// Add a single token to the decoder
         /// </summary>
        /// <param name="token"></param>
-        public void Add(int token)
+        public void Add(int token, SafeLlamaModelHandle? weights = null)
         {
+            weights ??= _weights;
+            if(weights is null)
+            {
+                throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
+            }
             var charsArr = ArrayPool<char>.Shared.Rent(16);
             var bytesArr = ArrayPool<byte>.Shared.Rent(16);
             try
             {
                 // Convert this token into bytes
-                var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
+                var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;
 
                 // Convert those bytes into characters
                 var bytesOffset = 0;
@@ -133,10 +138,10 @@
         /// Add all tokens in the given enumerable
         /// </summary>
         /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<int> tokens)
+        public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
         {
             foreach (var item in tokens)
-                Add(item);
+                Add(item, weights);
         }
 
         /// <summary>
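
Usage sketch (not part of the patch): the snippet below shows how the IGenerationControl and ITokenizer hooks introduced above could be plugged into InferenceParams. The StopOnPhrase class, the model path, and the prompt are hypothetical illustrations, and the IEnumerable<int> overload signature is assumed from the token types used elsewhere in this diff.

    using System;
    using System.Collections.Generic;
    using System.Text;
    using LLama;
    using LLama.Abstractions;
    using LLama.Common;
    using LLama.Control;
    using LLama.Transform;

    // Load a model and build a stateless executor (hypothetical model path).
    var parameters = new ModelParams("path/to/model.gguf");
    using var weights = LLamaWeights.LoadFromFile(parameters);
    var executor = new StatelessExecutor(weights, parameters);

    // Plug a custom generation control and the default tokenizer into the inference parameters.
    var inferenceParams = new InferenceParams
    {
        MaxTokens = 128,
        GenerationControl = new StopOnPhrase("User:"),   // hypothetical custom control, defined below
        Tokenizer = new DefaultTokenizer(Encoding.UTF8),
    };

    await foreach (var text in executor.InferAsync("Question: What is an apple?\nAnswer:", inferenceParams))
        Console.Write(text);

    // A custom IGenerationControl that stops once a fixed phrase has appeared in the output.
    sealed class StopOnPhrase : IGenerationControl
    {
        private readonly string _phrase;
        private readonly StringBuilder _seen = new();

        public StopOnPhrase(string phrase) => _phrase = phrase;

        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
        {
            _seen.Append(lastOutputText);
            return _seen.ToString().Contains(_phrase);
        }

        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
        {
            return false; // token-id based stopping is not used in this sketch
        }
    }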