- Added `logits_i` argument to `Context.ApplyPenalty`
- Added a new exception type for `llama_decode` return code
@@ -105,12 +105,7 @@ public class BatchedDecoding
                 if (i_batch[i] < 0)
                     continue;

-                var n_vocab = model.VocabCount;
-                LLamaTokenDataArray candidates;
-                unsafe
-                {
-                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
-                }
+                var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

                 candidates.TopK(context.NativeHandle, top_k);
                 candidates.TopP(context.NativeHandle, top_p);
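The sample no longer needs an `unsafe` block because the span construction moves behind `GetLogitsIth`. A minimal sketch of what such a wrapper looks like, assuming a `VocabCount` property on the handle (the shipped `SafeLLamaContextHandle.GetLogitsIth` may differ in detail):

```csharp
// Hypothetical sketch of the safe wrapper relied on above.
public Span<float> GetLogitsIth(int i)
{
    unsafe
    {
        // llama_get_logits_ith returns a pointer to n_vocab floats for token i
        var logits = NativeApi.llama_get_logits_ith(this, i);
        return new Span<float>(logits, VocabCount); // VocabCount is an assumed property
    }
}
```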
@@ -19,6 +19,7 @@ namespace LLama.Unittest
             {
                 ContextSize = 60,
                 Seed = 1754,
+                BatchSize = 2,
             };
             _weights = LLamaWeights.LoadFromFile(_params);
         }
@@ -60,7 +61,7 @@ namespace LLama.Unittest
         {
             var executor = new StatelessExecutor(_weights, _params);

-            const string question = " Question. cats or dogs?\nAnswer: ";
+            const string question = " Question. cats or dogs?\nAnswer:";

             // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
             // with a modified context
@@ -1,4 +1,5 @@
 using System;
+using LLama.Native;

 namespace LLama.Exceptions;
@@ -36,4 +37,23 @@ public class LoadWeightsFailedException
     {
         ModelPath = modelPath;
     }
 }
+
+/// <summary>
+/// `llama_decode` returned a non-zero status code
+/// </summary>
+public class LLamaDecodeError
+    : RuntimeError
+{
+    /// <summary>
+    /// The return status code
+    /// </summary>
+    public DecodeResult ReturnCode { get; }
+
+    /// <inheritdoc />
+    public LLamaDecodeError(DecodeResult returnCode)
+        : base($"llama_decode failed: '{returnCode}'")
+    {
+        ReturnCode = returnCode;
+    }
+}
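Callers that previously compared raw integer return codes can now catch the typed exception. An illustrative call site (the inference call itself is a placeholder, not part of this PR):

```csharp
try
{
    RunInference(); // hypothetical method that internally calls llama_decode
}
catch (LLamaDecodeError ex) when (ex.ReturnCode == DecodeResult.NoKvSlot)
{
    // Not fatal per llama.cpp: retry with a smaller batch or a larger context
    Console.WriteLine($"KV cache full: {ex.Message}");
}
```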
@@ -293,6 +293,7 @@ namespace LLama
        /// <summary>
        /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
        /// </summary>
+        /// <param name="logits_i">Index within the last batch of the token whose logits will be penalized</param>
        /// <param name="lastTokens"></param>
        /// <param name="logitBias"></param>
        /// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ namespace LLama
        /// <param name="alphaPresence"></param>
        /// <param name="penalizeNL"></param>
        /// <returns></returns>
-        public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
-            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
-            bool penalizeNL = true)
+        public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
+            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
+            bool penalizeNL = true)
        {
-            var logits = NativeHandle.GetLogits();
+            var logits = NativeHandle.GetLogitsIth(logits_i);

            // Apply params.logit_bias map
            if (logitBias is not null)
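For the few callers that use `ApplyPenalty` directly, the new first argument selects which token's logits to operate on. A sketch under the assumption of a single-sequence batch whose last token requested logits (`context` and `batch` are illustrative locals):

```csharp
// Hypothetical direct call: batch.TokenCount - 1 indexes the logits of the
// last token added to the most recent batch (the one flagged for logits).
var candidates = context.ApplyPenalty(batch.TokenCount - 1, lastTokens,
    logitBias: null, repeatLastTokensCount: 64, repeatPenalty: 1.1f);
```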
@@ -348,28 +349,23 @@ namespace LLama
        /// <summary>
        /// </summary>
        /// <param name="batch"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-        /// - 0: success<br />
-        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-        /// - < 0: error<br />
-        /// </returns>
-        public int Decode(LLamaBatch batch)
+        public DecodeResult Decode(LLamaBatch batch)
        {
-            return NativeHandle.Decode(batch);
+            if (batch.TokenCount == 0)
+                return 0;
+            if (batch.TokenCount > Params.BatchSize)
+                throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+            return (DecodeResult)NativeHandle.Decode(batch);
        }

        /// <summary>
        /// </summary>
        /// <param name="batch"></param>
        /// <param name="cancellationToken"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-        /// - 0: success<br />
-        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-        /// - < 0: error<br />
-        /// </returns>
-        public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+        public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
        {
-            return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
+            return Task.Run(() => Decode(batch), cancellationToken);
        }

        /// <summary>
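With the typed return value, call sites can branch on the enum instead of magic integers. A minimal sketch using only names introduced in this diff (`context` and `batch` are illustrative locals):

```csharp
var result = context.Decode(batch);
if (result == DecodeResult.NoKvSlot)
{
    // A warning, not a hard failure: shrink the batch or enlarge the context.
}
else if (result != DecodeResult.Ok)
{
    throw new LLamaDecodeError(result);
}
```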
@@ -216,7 +216,7 @@ namespace LLama
            }
            else
            {
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                var mu = MirostatMu;
@@ -195,7 +195,7 @@ namespace LLama
            }
            else
            {
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                var mu = MirostatMu;
@@ -5,7 +5,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ namespace LLama
        private readonly LLamaWeights _weights;
        private readonly IContextParams _params;
        private readonly ILogger? _logger;
+        private readonly LLamaBatch _batch;

        /// <summary>
        /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ namespace LLama
            _weights = weights;
            _params = @params;
            _logger = logger;
+            _batch = new LLamaBatch(1);

            Context = _weights.CreateContext(_params, logger);
            Context.Dispose();
@@ -71,16 +73,29 @@ namespace LLama
            var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
            var lastTokens = new List<LLamaToken>(repeat_last_n);
            for (var i = 0; i < repeat_last_n; i++)
-                lastTokens.Add((LLamaToken)0);
+                lastTokens.Add(0);

            // Tokenize the prompt
            var tokens = Context.Tokenize(prompt).ToList();
            lastTokens.AddRange(tokens);
-            var n_past = 1 + tokens.Count;

-            // Evaluate the prompt
-            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-                .ConfigureAwait(false);
+            // Evaluate the prompt, in chunks smaller than the max batch size
+            var n_past = 0;
+            var batchSize = (int)Context.Params.BatchSize;
+            for (var i = 0; i < tokens.Count; i += batchSize)
+            {
+                var n_eval = tokens.Count - i;
+                if (n_eval > batchSize)
+                    n_eval = batchSize;
+
+                _batch.Clear();
+                for (var j = 0; j < n_eval; j++)
+                    _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+                var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
+            }

            // Begin loop, evaluating one token at a time
            var mu = (float?)null;
@@ -90,12 +105,12 @@ namespace LLama
                LLamaToken id;
                if (inferenceParams.SamplingPipeline is not null)
                {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
                }
                else
                {
                    // Penalize the generated tokens by various penalties
-                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                    // Sample a single token
@@ -136,9 +151,12 @@ namespace LLama
                    n_past -= n_discard;
                }

-                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                    .ConfigureAwait(false);
+                // Evaluate with this new token
+                _batch.Clear();
+                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+                var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
            }
        }
    }
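The chunked prompt evaluation above follows the standard llama.cpp batching pattern. Condensed into a hypothetical standalone helper (not part of this PR; it assumes the same using directives as the executor file above), it reads:

```csharp
// Hypothetical helper condensing the chunked evaluation pattern above.
static async Task<int> EvalChunkedAsync(LLamaContext ctx, LLamaBatch batch,
    IReadOnlyList<LLamaToken> tokens, int nPast, CancellationToken ct)
{
    var batchSize = (int)ctx.Params.BatchSize;
    for (var i = 0; i < tokens.Count; i += batchSize)
    {
        var nEval = Math.Min(batchSize, tokens.Count - i);

        batch.Clear();
        for (var j = 0; j < nEval; j++)
        {
            // Only the final prompt token needs logits for sampling
            batch.Add(tokens[i + j], nPast++, LLamaSeqId.Zero, i + j == tokens.Count - 1);
        }

        var rc = await ctx.DecodeAsync(batch, ct);
        if (rc != DecodeResult.Ok)
            throw new LLamaDecodeError(rc);
    }
    return nPast;
}
```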
@@ -0,0 +1,22 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_decode
+/// </summary>
+public enum DecodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0,
+
+    /// <summary>
+    /// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    /// </summary>
+    NoKvSlot = 1,
+}
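A small illustrative mapping from the enum to the semantics documented in its XML comments (a hypothetical helper, shown for clarity only):

```csharp
// Hypothetical helper: maps DecodeResult to the documented meaning.
static string Describe(DecodeResult result) => result switch
{
    DecodeResult.Ok => "success",
    DecodeResult.NoKvSlot => "no KV slot: reduce batch size or increase context",
    DecodeResult.Error => "unspecified error",
    _ => $"unknown return code {(int)result}",
};
```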