
Swapped `StatelessExecutor` to use `llama_decode`!

- Added `logits_i` argument to `Context.ApplyPenalty`
- Added a new exception type for `llama_decode` return code
tags/v0.10.0
Martin Evans 1 year ago
commit a2e29d393c
8 changed files with 90 additions and 38 deletions
  1. LLama.Examples/Examples/BatchedDecoding.cs (+1, -6)
  2. LLama.Unittest/StatelessExecutorTest.cs (+2, -1)
  3. LLama/Exceptions/RuntimeError.cs (+20, -0)
  4. LLama/LLamaContext.cs (+14, -18)
  5. LLama/LLamaInstructExecutor.cs (+1, -1)
  6. LLama/LLamaInteractExecutor.cs (+1, -1)
  7. LLama/LLamaStatelessExecutor.cs (+29, -11)
  8. LLama/Native/DecodeResult.cs (+22, -0)
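Taken together, these changes mean callers construct an `LLamaBatch`, check a typed `DecodeResult`, and pass an explicit logits index when sampling. A rough sketch of that flow (not part of this commit; the helper name and batch construction are purely illustrative):

```csharp
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama;
using LLama.Exceptions;
using LLama.Native;

internal static class DecodeSketch
{
    // Evaluate a tokenized prompt in a single batch and return penalized
    // candidates for the last token. Assumes the prompt fits in one batch.
    public static async Task<LLamaTokenDataArray> EvaluatePromptAsync(LLamaContext context, IReadOnlyList<LLamaToken> tokens)
    {
        var batch = new LLamaBatch(1);
        for (var i = 0; i < tokens.Count; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, i == tokens.Count - 1); // request logits only for the final token

        // Decode now reports status as a DecodeResult rather than a raw int
        var result = await context.DecodeAsync(batch);
        if (result != DecodeResult.Ok)
            throw new LLamaDecodeError(result);

        // ApplyPenalty now takes the index of the token whose logits to use
        return context.ApplyPenalty(batch.TokenCount - 1, tokens);
    }
}
```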

LLama.Examples/Examples/BatchedDecoding.cs (+1, -6)

@@ -105,12 +105,7 @@ public class BatchedDecoding
 if (i_batch[i] < 0)
     continue;

-var n_vocab = model.VocabCount;
-LLamaTokenDataArray candidates;
-unsafe
-{
-    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
-}
+var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

 candidates.TopK(context.NativeHandle, top_k);
 candidates.TopP(context.NativeHandle, top_p);

LLama.Unittest/StatelessExecutorTest.cs (+2, -1)

@@ -19,6 +19,7 @@ namespace LLama.Unittest
 {
     ContextSize = 60,
     Seed = 1754,
+    BatchSize = 2,
 };
 _weights = LLamaWeights.LoadFromFile(_params);
 }
@@ -60,7 +61,7 @@ namespace LLama.Unittest
 {
     var executor = new StatelessExecutor(_weights, _params);

-const string question = " Question. cats or dogs?\nAnswer: ";
+const string question = " Question. cats or dogs?\nAnswer:";

 // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
 // with a modified context


LLama/Exceptions/RuntimeError.cs (+20, -0)

@@ -1,4 +1,5 @@
 using System;
+using LLama.Native;

 namespace LLama.Exceptions;

@@ -36,4 +37,23 @@ public class LoadWeightsFailedException
 {
     ModelPath = modelPath;
 }
 }
+
+/// <summary>
+/// `llama_decode` returned a non-zero status code
+/// </summary>
+public class LLamaDecodeError
+    : RuntimeError
+{
+    /// <summary>
+    /// The return status code
+    /// </summary>
+    public DecodeResult ReturnCode { get; }
+
+    /// <inheritdoc />
+    public LLamaDecodeError(DecodeResult returnCode)
+        : base($"llama_decode failed: '{returnCode}'")
+    {
+        ReturnCode = returnCode;
+    }
+}
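With the typed exception in place, callers that used to inspect raw return codes can catch `LLamaDecodeError` instead. A hedged example (the executor, prompt, and `MaxTokens` value are placeholders):

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;

internal static class DecodeErrorHandling
{
    public static async Task GenerateAsync(StatelessExecutor executor, string prompt)
    {
        try
        {
            await foreach (var text in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 64 }))
                Console.Write(text);
        }
        catch (LLamaDecodeError ex) when (ex.ReturnCode == DecodeResult.NoKvSlot)
        {
            // Recoverable: llama_decode could not find a KV slot, so retry with a
            // smaller batch or a larger context rather than treating it as fatal.
            Console.Error.WriteLine($"Decode failed: {ex.ReturnCode}");
        }
    }
}
```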

LLama/LLamaContext.cs (+14, -18)

@@ -293,6 +293,7 @@ namespace LLama
 /// <summary>
 /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
 /// </summary>
+/// <param name="logits_i"></param>
 /// <param name="lastTokens"></param>
 /// <param name="logitBias"></param>
 /// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ namespace LLama
 /// <param name="alphaPresence"></param>
 /// <param name="penalizeNL"></param>
 /// <returns></returns>
-public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
-    int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
-    bool penalizeNL = true)
+public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
+    int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
+    bool penalizeNL = true)
 {
-    var logits = NativeHandle.GetLogits();
+    var logits = NativeHandle.GetLogitsIth(logits_i);

     // Apply params.logit_bias map
     if (logitBias is not null)
@@ -348,28 +349,23 @@ namespace LLama
 /// <summary>
 /// </summary>
 /// <param name="batch"></param>
-/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-/// - 0: success<br />
-/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-/// - &lt; 0: error<br />
-/// </returns>
-public int Decode(LLamaBatch batch)
+public DecodeResult Decode(LLamaBatch batch)
 {
-    return NativeHandle.Decode(batch);
+    if (batch.TokenCount == 0)
+        return 0;
+    if (batch.TokenCount > Params.BatchSize)
+        throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+    return (DecodeResult)NativeHandle.Decode(batch);
 }

 /// <summary>
 /// </summary>
 /// <param name="batch"></param>
 /// <param name="cancellationToken"></param>
-/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-/// - 0: success<br />
-/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-/// - &lt; 0: error<br />
-/// </returns>
-public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
 {
-    return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
+    return Task.Run(() => Decode(batch), cancellationToken);
 }

 /// <summary>
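Since `Decode` now returns the enum (throwing `ArgumentException` for oversized batches and returning `Ok` for empty ones), callers can branch on the result instead of comparing magic integers. A minimal sketch, with the helper name invented for illustration:

```csharp
using LLama;
using LLama.Exceptions;
using LLama.Native;

internal static class DecodeResultHandling
{
    // Decode a batch, treating NoKvSlot as a recoverable condition.
    public static bool TryDecode(LLamaContext context, LLamaBatch batch)
    {
        var result = context.Decode(batch);
        return result switch
        {
            DecodeResult.Ok => true,
            DecodeResult.NoKvSlot => false, // try a smaller batch or a larger context
            _ => throw new LLamaDecodeError(result),
        };
    }
}
```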


LLama/LLamaInstructExecutor.cs (+1, -1)

@@ -216,7 +216,7 @@ namespace LLama
 }
 else
 {
-    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

     var mu = MirostatMu;


LLama/LLamaInteractExecutor.cs (+1, -1)

@@ -195,7 +195,7 @@ namespace LLama
 }
 else
 {
-    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

     var mu = MirostatMu;


LLama/LLamaStatelessExecutor.cs (+29, -11)

@@ -5,7 +5,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
-using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ namespace LLama
 private readonly LLamaWeights _weights;
 private readonly IContextParams _params;
 private readonly ILogger? _logger;
+private readonly LLamaBatch _batch;

 /// <summary>
 /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ namespace LLama
 _weights = weights;
 _params = @params;
 _logger = logger;
+_batch = new LLamaBatch(1);

 Context = _weights.CreateContext(_params, logger);
 Context.Dispose();
@@ -71,16 +73,29 @@ namespace LLama
 var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
 var lastTokens = new List<LLamaToken>(repeat_last_n);
 for (var i = 0; i < repeat_last_n; i++)
-    lastTokens.Add((LLamaToken)0);
+    lastTokens.Add(0);

 // Tokenize the prompt
 var tokens = Context.Tokenize(prompt).ToList();
 lastTokens.AddRange(tokens);
-var n_past = 1 + tokens.Count;

-// Evaluate the prompt
-await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-    .ConfigureAwait(false);
+// Evaluate the prompt, in chunks smaller than the max batch size
+var n_past = 0;
+var batchSize = (int)Context.Params.BatchSize;
+for (var i = 0; i < tokens.Count; i += batchSize)
+{
+    var n_eval = tokens.Count - i;
+    if (n_eval > batchSize)
+        n_eval = batchSize;
+
+    _batch.Clear();
+    for (var j = 0; j < n_eval; j++)
+        _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+    var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+    if (returnCode != 0)
+        throw new LLamaDecodeError(returnCode);
+}

 // Begin loop, evaluating one token at a time
 var mu = (float?)null;
@@ -90,12 +105,12 @@ namespace LLama
 LLamaToken id;
 if (inferenceParams.SamplingPipeline is not null)
 {
-    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
 }
 else
 {
     // Penalize the generated tokens by various penalties
-    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+    var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

     // Sample a single token
@@ -136,9 +151,12 @@ namespace LLama
     n_past -= n_discard;
 }

-// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-    .ConfigureAwait(false);
+// Evaluate with this new token
+_batch.Clear();
+_batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+var returnCode = await context.DecodeAsync(_batch, cancellationToken);
+if (returnCode != 0)
+    throw new LLamaDecodeError(returnCode);
 }
 }
 }
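For reference, an end-to-end usage sketch of the executor this commit rewires: prompts longer than `BatchSize` are now decoded in chunks internally. The model path and parameter values below are placeholders:

```csharp
using System;
using LLama;
using LLama.Common;

var parameters = new ModelParams("models/model.gguf") // hypothetical path
{
    ContextSize = 512,
    BatchSize = 64, // prompts longer than this are decoded in multiple llama_decode calls
};

using var weights = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(weights, parameters);

await foreach (var token in executor.InferAsync("Question. cats or dogs?\nAnswer:", new InferenceParams { MaxTokens = 32 }))
    Console.Write(token);
```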


LLama/Native/DecodeResult.cs (+22, -0)

@@ -0,0 +1,22 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_decode
+/// </summary>
+public enum DecodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0,
+
+    /// <summary>
+    /// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    /// </summary>
+    NoKvSlot = 1,
+}
