Swapped `StatelessExecutor` to use `llama_decode`

This change removes the last `Eval`-based code path in `StatelessExecutor`: prompts are now evaluated with batched `llama_decode` calls, the decode status is surfaced as a typed `DecodeResult` enum, and failures throw the new `LLamaDecodeError` exception.
```diff
@@ -105,12 +105,7 @@ public class BatchedDecoding
                 if (i_batch[i] < 0)
                     continue;
 
-                var n_vocab = model.VocabCount;
-                LLamaTokenDataArray candidates;
-                unsafe
-                {
-                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
-                }
+                var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));
 
                 candidates.TopK(context.NativeHandle, top_k);
                 candidates.TopP(context.NativeHandle, top_p);
```
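For context, the `unsafe` block deleted above is exactly what `GetLogitsIth` now encapsulates behind a safe API. A minimal sketch of such a wrapper, reconstructed from the removed lines (its placement on `SafeLLamaContextHandle` and the `VocabCount` property are assumptions, not quotes of the library):

```csharp
// Sketch of a safe wrapper over llama_get_logits_ith, as a member of the
// context handle class. Assumes VocabCount exposes the model's vocabulary size.
public Span<float> GetLogitsIth(int i)
{
    var n_vocab = VocabCount;
    unsafe
    {
        // llama_get_logits_ith returns a float* to n_vocab logits for batch index i
        var logits = NativeApi.llama_get_logits_ith(this, i);
        return new Span<float>(logits, n_vocab);
    }
}
```

Callers such as the example above can then build a `LLamaTokenDataArray` from the returned span without touching pointers themselves.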
```diff
@@ -19,6 +19,7 @@ namespace LLama.Unittest
         {
             ContextSize = 60,
             Seed = 1754,
+            BatchSize = 2,
         };
         _weights = LLamaWeights.LoadFromFile(_params);
     }
@@ -60,7 +61,7 @@ namespace LLama.Unittest
     {
         var executor = new StatelessExecutor(_weights, _params);
 
-        const string question = " Question. cats or dogs?\nAnswer: ";
+        const string question = " Question. cats or dogs?\nAnswer:";
 
         // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
         // with a modified context
```
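Setting `BatchSize = 2` means even this short test prompt cannot fit in a single `llama_decode` call, so the test now exercises the chunked prompt evaluation added to `StatelessExecutor` below. A standalone illustration of the resulting chunk sizes (plain arithmetic, not library code):

```csharp
using System;

// How a 7-token prompt is split when BatchSize = 2:
// four llama_decode calls covering 2, 2, 2 and 1 tokens.
const int tokenCount = 7;
const int batchSize = 2;

for (var i = 0; i < tokenCount; i += batchSize)
{
    var n_eval = Math.Min(batchSize, tokenCount - i);
    Console.WriteLine($"decode tokens [{i}..{i + n_eval - 1}] ({n_eval} tokens)");
}
```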
```diff
@@ -1,4 +1,5 @@
 using System;
+using LLama.Native;
 
 namespace LLama.Exceptions;
 
@@ -36,4 +37,23 @@ public class LoadWeightsFailedException
     {
         ModelPath = modelPath;
     }
 }
+
+/// <summary>
+/// `llama_decode` returned a non-zero status code
+/// </summary>
+public class LLamaDecodeError
+    : RuntimeError
+{
+    /// <summary>
+    /// The return status code
+    /// </summary>
+    public DecodeResult ReturnCode { get; }
+
+    /// <inheritdoc />
+    public LLamaDecodeError(DecodeResult returnCode)
+        : base($"llama_decode failed: '{returnCode}'")
+    {
+        ReturnCode = returnCode;
+    }
+}
```
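Because `LLamaDecodeError` derives from `RuntimeError`, existing handlers that catch the base type keep working. A hedged sketch of a call site (the `executor`, `prompt` and `inferenceParams` objects are placeholders set up elsewhere, and the loop must run inside an async method):

```csharp
// Hypothetical call site: a failed llama_decode surfaces as LLamaDecodeError.
try
{
    await foreach (var text in executor.InferAsync(prompt, inferenceParams))
        Console.Write(text);
}
catch (LLamaDecodeError ex)
{
    // Message reads e.g. "llama_decode failed: 'NoKvSlot'"
    Console.Error.WriteLine($"{ex.Message} (ReturnCode = {ex.ReturnCode})");
}
```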
```diff
@@ -293,6 +293,7 @@ namespace LLama
         /// <summary>
         /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
         /// </summary>
+        /// <param name="logits_i">Index of the logits to read, i.e. the batch index of the token being sampled for</param>
         /// <param name="lastTokens"></param>
         /// <param name="logitBias"></param>
         /// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@
         /// <param name="alphaPresence"></param>
         /// <param name="penalizeNL"></param>
         /// <returns></returns>
-        public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
-            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
-            bool penalizeNL = true)
+        public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
+            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
+            bool penalizeNL = true)
         {
-            var logits = NativeHandle.GetLogits();
+            var logits = NativeHandle.GetLogitsIth(logits_i);
 
             // Apply params.logit_bias map
             if (logitBias is not null)
@@ -348,28 +349,23 @@
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-        /// - 0: success<br />
-        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-        /// - < 0: error<br />
-        /// </returns>
-        public int Decode(LLamaBatch batch)
+        public DecodeResult Decode(LLamaBatch batch)
         {
-            return NativeHandle.Decode(batch);
+            if (batch.TokenCount == 0)
+                return 0;
+            if (batch.TokenCount > Params.BatchSize)
+                throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+            return (DecodeResult)NativeHandle.Decode(batch);
         }
 
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
         /// <param name="cancellationToken"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-        /// - 0: success<br />
-        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-        /// - < 0: error<br />
-        /// </returns>
-        public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+        public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
         {
-            return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
+            return Task.Run(() => Decode(batch), cancellationToken);
         }
 
         /// <summary>
```
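Callers that previously compared the raw `int` can now switch on the enum instead. A sketch, assuming a `context` and a populated `batch` as elsewhere in this PR:

```csharp
// Handling the typed result of Decode. 'context' and 'batch' are assumed
// to be a LLamaContext and a filled LLamaBatch.
var result = context.Decode(batch);
switch (result)
{
    case DecodeResult.Ok:
        break; // logits are ready to be sampled

    case DecodeResult.NoKvSlot:
        // A warning rather than a fatal error: shrink the batch or
        // increase the context size and try again.
        break;

    default:
        throw new LLamaDecodeError(result);
}
```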
```diff
@@ -216,7 +216,7 @@ namespace LLama
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     var mu = MirostatMu;
@@ -195,7 +195,7 @@ namespace LLama
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     var mu = MirostatMu;
```
```diff
@@ -5,7 +5,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
-using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ namespace LLama
         private readonly LLamaWeights _weights;
         private readonly IContextParams _params;
         private readonly ILogger? _logger;
+        private readonly LLamaBatch _batch;
 
         /// <summary>
         /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@
             _weights = weights;
             _params = @params;
             _logger = logger;
+            _batch = new LLamaBatch(1);
 
             Context = _weights.CreateContext(_params, logger);
             Context.Dispose();
@@ -71,16 +73,29 @@
             var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
             var lastTokens = new List<LLamaToken>(repeat_last_n);
             for (var i = 0; i < repeat_last_n; i++)
-                lastTokens.Add((LLamaToken)0);
+                lastTokens.Add(0);
 
             // Tokenize the prompt
             var tokens = Context.Tokenize(prompt).ToList();
             lastTokens.AddRange(tokens);
-            var n_past = 1 + tokens.Count;
 
-            // Evaluate the prompt
-            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-                      .ConfigureAwait(false);
+            // Evaluate the prompt, in chunks smaller than the max batch size
+            var n_past = 0;
+            var batchSize = (int)Context.Params.BatchSize;
+            for (var i = 0; i < tokens.Count; i += batchSize)
+            {
+                var n_eval = tokens.Count - i;
+                if (n_eval > batchSize)
+                    n_eval = batchSize;
+
+                _batch.Clear();
+                for (var j = 0; j < n_eval; j++)
+                    _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+                var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
+            }
 
             // Begin loop, evaluating one token at a time
             var mu = (float?)null;
```
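The chunking pattern above can be read in isolation as a small helper. The sketch below is assembled from the added lines (the `EvalPromptAsync` name and its existence as a separate method are hypothetical); note that only the final prompt token requests logits, so llama.cpp computes logits exactly once for the whole prompt:

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Exceptions;
using LLama.Native;

// A standalone sketch of the chunked-evaluation pattern; the
// LLamaContext/LLamaBatch calls mirror the added lines in this PR.
internal static class PromptEvalSketch
{
    /// <summary>Decode a prompt in BatchSize-sized chunks; returns the new n_past.</summary>
    internal static async Task<int> EvalPromptAsync(
        LLamaContext context, LLamaBatch batch, IReadOnlyList<LLamaToken> tokens,
        CancellationToken cancellationToken = default)
    {
        var n_past = 0;
        var batchSize = (int)context.Params.BatchSize;

        for (var i = 0; i < tokens.Count; i += batchSize)
        {
            var n_eval = Math.Min(batchSize, tokens.Count - i);

            batch.Clear();
            for (var j = 0; j < n_eval; j++)
            {
                // Request logits only for the final prompt token.
                var isLast = (i + j) == tokens.Count - 1;
                batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, isLast);
            }

            var returnCode = await context.DecodeAsync(batch, cancellationToken);
            if (returnCode != 0)
                throw new LLamaDecodeError(returnCode);
        }

        return n_past;
    }
}
```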
```diff
@@ -90,12 +105,12 @@
                 LLamaToken id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
                 }
                 else
                 {
                     // Penalize the generated tokens by various penalties
-                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     // Sample a single token
@@ -136,9 +151,12 @@
                     n_past -= n_discard;
                 }
 
-                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                                   .ConfigureAwait(false);
+                // Evaluate with this new token
+                _batch.Clear();
+                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+                var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
             }
         }
     }
```
```diff
@@ -0,0 +1,22 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_decode
+/// </summary>
+public enum DecodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0,
+
+    /// <summary>
+    /// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    /// </summary>
+    NoKvSlot = 1,
+}
```
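The numeric values mirror llama.cpp's convention quoted in the doc comments removed from `Decode`: 0 is success, positive codes are recoverable warnings, negative codes are errors. Hypothetical convenience extensions (not part of this PR) that make the distinction explicit:

```csharp
using LLama.Exceptions;

namespace LLama.Native;

// Hypothetical helpers encoding llama.cpp's convention:
// 0 = success, > 0 = warning (recoverable), < 0 = fatal error.
public static class DecodeResultExtensions
{
    public static bool IsWarning(this DecodeResult result) => (int)result > 0;

    public static bool IsError(this DecodeResult result) => (int)result < 0;

    public static void ThrowIfFailed(this DecodeResult result)
    {
        if (result != DecodeResult.Ok)
            throw new LLamaDecodeError(result);
    }
}
```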