From a2e29d393c1d23c9b1e163087815e206203083e3 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sat, 20 Jan 2024 21:18:35 +0000
Subject: [PATCH] Swapped `StatelessExecutor` to use `llama_decode`!

- Added `logits_i` argument to `Context.ApplyPenalty`
- Added a new exception type for `llama_decode` return code
---
 LLama.Examples/Examples/BatchedDecoding.cs |  7 +---
 LLama.Unittest/StatelessExecutorTest.cs    |  3 +-
 LLama/Exceptions/RuntimeError.cs           | 20 +++++++++++
 LLama/LLamaContext.cs                      | 32 ++++++++---------
 LLama/LLamaInstructExecutor.cs             |  2 +-
 LLama/LLamaInteractExecutor.cs             |  2 +-
 LLama/LLamaStatelessExecutor.cs            | 40 ++++++++++++++++------
 LLama/Native/DecodeResult.cs               | 22 ++++++++++++
 8 files changed, 90 insertions(+), 38 deletions(-)
 create mode 100644 LLama/Native/DecodeResult.cs

diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 1d55ff12..37a02201 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -105,12 +105,7 @@ public class BatchedDecoding
                 if (i_batch[i] < 0)
                     continue;

-                var n_vocab = model.VocabCount;
-                LLamaTokenDataArray candidates;
-                unsafe
-                {
-                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
-                }
+                var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

                 candidates.TopK(context.NativeHandle, top_k);
                 candidates.TopP(context.NativeHandle, top_p);
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 8d4be20c..cfe49973 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -19,6 +19,7 @@ namespace LLama.Unittest
             {
                 ContextSize = 60,
                 Seed = 1754,
+                BatchSize = 2,
             };
             _weights = LLamaWeights.LoadFromFile(_params);
         }
@@ -60,7 +61,7 @@ namespace LLama.Unittest
         {
             var executor = new StatelessExecutor(_weights, _params);

-            const string question = " Question. cats or dogs?\nAnswer: ";
+            const string question = " Question. cats or dogs?\nAnswer:";

             // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
             // with a modified context
diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs
index c56d78ff..0feb5366 100644
--- a/LLama/Exceptions/RuntimeError.cs
+++ b/LLama/Exceptions/RuntimeError.cs
@@ -1,4 +1,5 @@
 using System;
+using LLama.Native;

 namespace LLama.Exceptions;

@@ -36,4 +37,23 @@ public class LoadWeightsFailedException
     {
         ModelPath = modelPath;
     }
+}
+
+/// <summary>
+/// `llama_decode` returned a non-zero status code
+/// </summary>
+public class LLamaDecodeError
+    : RuntimeError
+{
+    /// <summary>
+    /// The return status code
+    /// </summary>
+    public DecodeResult ReturnCode { get; }
+
+    /// <inheritdoc />
+    public LLamaDecodeError(DecodeResult returnCode)
+        : base($"llama_decode failed: '{returnCode}'")
+    {
+        ReturnCode = returnCode;
+    }
 }
\ No newline at end of file
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index ea745d02..33c8d726 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -293,6 +293,7 @@ namespace LLama
         /// <summary>
         /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
         /// </summary>
+        /// <param name="logits_i"></param>
         /// <param name="lastTokens"></param>
         /// <param name="logitBias"></param>
         /// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ namespace LLama
         /// <param name="alphaPresence"></param>
         /// <param name="penalizeNL"></param>
         /// <returns></returns>
-        public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
-            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
-            bool penalizeNL = true)
+        public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
+            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
+            bool penalizeNL = true)
         {
-            var logits = NativeHandle.GetLogits();
+            var logits = NativeHandle.GetLogitsIth(logits_i);

             // Apply params.logit_bias map
             if (logitBias is not null)
@@ -348,28 +349,23 @@ namespace LLama
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br/>
-        ///  - 0: success<br/>
-        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br/>
-        ///  - < 0: error<br/>
-        /// </returns>
-        public int Decode(LLamaBatch batch)
+        public DecodeResult Decode(LLamaBatch batch)
         {
-            return NativeHandle.Decode(batch);
+            if (batch.TokenCount == 0)
+                return 0;
+            if (batch.TokenCount > Params.BatchSize)
+                throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+            return (DecodeResult)NativeHandle.Decode(batch);
         }

         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
         /// <param name="cancellationToken"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br/>
-        ///  - 0: success<br/>
-        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br/>
-        ///  - < 0: error<br/>
-        /// </returns>
-        public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+        public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
         {
-            return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
+            return Task.Run(() => Decode(batch), cancellationToken);
         }

         /// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index b763145e..993019f1 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -216,7 +216,7 @@ namespace LLama
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                     var mu = MirostatMu;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 11973a27..2e72c7ae 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -195,7 +195,7 @@ namespace LLama
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                     var mu = MirostatMu;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index e03fe7a1..0587f148 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -5,7 +5,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
-using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ namespace LLama
         private readonly LLamaWeights _weights;
         private readonly IContextParams _params;
         private readonly ILogger? _logger;
+        private readonly LLamaBatch _batch;

         /// <summary>
         /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ namespace LLama
             _weights = weights;
             _params = @params;
             _logger = logger;
+            _batch = new LLamaBatch(1);

             Context = _weights.CreateContext(_params, logger);
             Context.Dispose();
@@ -71,16 +73,29 @@ namespace LLama
             var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
             var lastTokens = new List<LLamaToken>(repeat_last_n);
             for (var i = 0; i < repeat_last_n; i++)
-                lastTokens.Add((LLamaToken)0);
+                lastTokens.Add(0);

             // Tokenize the prompt
             var tokens = Context.Tokenize(prompt).ToList();
             lastTokens.AddRange(tokens);
-            var n_past = 1 + tokens.Count;
-
-            // Evaluate the prompt
-            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-                .ConfigureAwait(false);
+
+            // Evaluate the prompt, in chunks smaller than the max batch size
+            var n_past = 0;
+            var batchSize = (int)Context.Params.BatchSize;
+            for (var i = 0; i < tokens.Count; i += batchSize)
+            {
+                var n_eval = tokens.Count - i;
+                if (n_eval > batchSize)
+                    n_eval = batchSize;
+
+                _batch.Clear();
+                for (var j = 0; j < n_eval; j++)
+                    _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+                var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
+            }

             // Begin loop, evaluating one token at a time
             var mu = (float?)null;
@@ -90,12 +105,12 @@ namespace LLama
                 LLamaToken id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
                 }
                 else
                 {
                     // Penalize the generated tokens by various penalties
-                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                     // Sample a single token
@@ -136,9 +151,12 @@ namespace LLama
                     n_past -= n_discard;
                 }

-                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                    .ConfigureAwait(false);
+                // Evaluate with this new token
+                _batch.Clear();
+                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+                var returnCode = await context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
             }
         }
     }
diff --git a/LLama/Native/DecodeResult.cs b/LLama/Native/DecodeResult.cs
new file mode 100644
index 00000000..61056dd9
--- /dev/null
+++ b/LLama/Native/DecodeResult.cs
@@ -0,0 +1,22 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_decode
+/// </summary>
+public enum DecodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0,
+
+    /// <summary>
+    /// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    /// </summary>
+    NoKvSlot = 1,
+}
\ No newline at end of file
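
Not part of the commit: a rough sketch of how the reworked surface might be called from user code. It relies only on members visible in the hunks above (Tokenize, LLamaBatch.Add, Decode returning DecodeResult, LLamaDecodeError, GetLogitsIth/ApplyPenalty taking a logits index); the DecodeExample/DecodePrompt names are invented for illustration, and the prompt is assumed to fit within the configured BatchSize (otherwise chunk it the way LLamaStatelessExecutor now does).

using System.Linq;
using LLama;
using LLama.Exceptions;
using LLama.Native;

static class DecodeExample
{
    // Feed a prompt through the context in a single batch and return the index
    // of the logits produced for the final prompt token.
    public static int DecodePrompt(LLamaContext context, string prompt)
    {
        var tokens = context.Tokenize(prompt).ToList();
        var batch = new LLamaBatch(1);

        // Queue every prompt token, requesting logits only for the last one.
        for (var i = 0; i < tokens.Count; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, i == tokens.Count - 1);

        // Decode now reports status through the DecodeResult enum rather than a bare int.
        var result = context.Decode(batch);
        if (result != DecodeResult.Ok)
            throw new LLamaDecodeError(result);

        // This index can then be passed to ApplyPenalty / GetLogitsIth for sampling.
        return batch.TokenCount - 1;
    }
}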