Browse Source

`ReadOnlySpan<float>` in ISamplingPipeline (#538)

* - Modified ISamplingPipeline to accept `ReadOnlySpan<float>` of logits directly. This moves responsibility to copy the logits into the pipeline.
 - Added a flag to `BaseSamplingPipeline` indicating if a logit copy is necessary. Skipping it in most cases.

* Fixed `RestoreProtectedTokens` not working if logit processing is skipped

* - Implemented a new greedy sampling pipeline (always sample most likely token)
 - Moved `Grammar` into `BaseSamplingPipeline`
 - Removed "protected tokens" concept from `BaseSamplingPipeline`. Was introducing a lot of incidental complexity.
 - Implemented newline logit save/restore in `DefaultSamplingPipeline` (only place protected tokens was used)

* Implemented pipelines for mirostat v1 and v2
tags/0.11.0
Martin Evans GitHub 1 year ago
parent
commit
91a7967869
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
9 changed files with 271 additions and 108 deletions
  1. +1
    -2
      LLama.Examples/Examples/BatchedExecutorFork.cs
  2. +3
    -5
      LLama.Examples/Examples/BatchedExecutorRewind.cs
  3. +2
    -1
      LLama/LLamaContext.cs
  4. +14
    -72
      LLama/Sampling/BaseSamplingPipeline.cs
  5. +74
    -26
      LLama/Sampling/DefaultSamplingPipeline.cs
  6. +32
    -0
      LLama/Sampling/GreedySamplingPipeline.cs
  7. +2
    -2
      LLama/Sampling/ISamplingPipeline.cs
  8. +71
    -0
      LLama/Sampling/Mirostat2SamplingPipeline.cs
  9. +72
    -0
      LLama/Sampling/MirostatSamplingPipeline.cs

+ 1
- 2
LLama.Examples/Examples/BatchedExecutorFork.cs View File

@@ -91,8 +91,7 @@ public class BatchedExecutorFork


// Sample one token // Sample one token
var ctx = _conversation.Executor.Context.NativeHandle; var ctx = _conversation.Executor.Context.NativeHandle;
var logitsCopy = _conversation.Sample().ToArray();
var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
_sampler.Accept(ctx, token); _sampler.Accept(ctx, token);
_decoder.Add(token); _decoder.Add(token);




+ 3
- 5
LLama.Examples/Examples/BatchedExecutorRewind.cs View File

@@ -88,7 +88,7 @@ public class BatchedExecutorRewind


public LLamaToken Sample(Conversation conversation) public LLamaToken Sample(Conversation conversation)
{ {
var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
_tokens.Add(token); _tokens.Add(token);
return token; return token;
} }
@@ -100,14 +100,12 @@ public class BatchedExecutorRewind
for (var i = 0; i < _tokens.Count - n_rewind; i++) for (var i = 0; i < _tokens.Count - n_rewind; i++)
decoder.Add(_tokens[i]); decoder.Add(_tokens[i]);


Console.ForegroundColor = ConsoleColor.Green;
Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" "));
AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");


for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++) for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
decoder.Add(_tokens[i]); decoder.Add(_tokens[i]);


Console.ForegroundColor = ConsoleColor.DarkRed;
Console.WriteLine(decoder.Read().ReplaceLineEndings(" "));
AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");
} }


public LLamaToken GetToken(int index) public LLamaToken GetToken(int index)


+ 2
- 1
LLama/LLamaContext.cs View File

@@ -147,8 +147,9 @@ namespace LLama
} }


/// <summary> /// <summary>
/// Get the state data as an opaque handle
/// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
/// </summary> /// </summary>
/// <remarks>Use <see cref="SaveState"/> if you intend to save this state to disk.</remarks>
/// <returns></returns> /// <returns></returns>
public State GetState() public State GetState()
{ {


+ 14
- 72
LLama/Sampling/BaseSamplingPipeline.cs View File

@@ -1,6 +1,4 @@
using System; using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native; using LLama.Native;


namespace LLama.Sampling; namespace LLama.Sampling;
@@ -11,84 +9,28 @@ namespace LLama.Sampling;
public abstract class BaseSamplingPipeline public abstract class BaseSamplingPipeline
: ISamplingPipeline : ISamplingPipeline
{ {
private int _savedLogitsCount;
private (LLamaToken index, float logit)[]? _savedLogits;

/// <inheritdoc/>
public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
var protectedLogits = GetProtectedTokens(ctx);
_savedLogitsCount = protectedLogits.Count;
_savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount);
try
{
// Save the values of protected logits
for (var i = 0; i < protectedLogits.Count; i++)
{
var index = protectedLogits[i];
var value = logits[(int)index];
_savedLogits[i] = (index, value);
}

// Process raw logits
ProcessLogits(ctx, logits, lastTokens);

// Automatically restore saved logit values after processing
RestoreProtectedTokens(logits);

// Convert logits into token candidates
var candidates = LLamaTokenDataArray.Create(logits);

// Process token data array
return ProcessTokenDataArray(ctx, candidates, lastTokens);
}
finally
{
ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits);
_savedLogits = null;
_savedLogitsCount = 0;
}
}

/// <inheritdoc />
public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token);

#region protected tokens
/// <summary> /// <summary>
/// Get all of the "protected" tokens that cannot be changed by ProcessLogits
/// Grammar to constrain valid tokens
/// </summary> /// </summary>
/// <returns></returns>
protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx);
public SafeLLamaGrammarHandle? Grammar { get; set; }


/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="logits"></param>
protected void RestoreProtectedTokens(Span<float> logits)
/// <inheritdoc/>
public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{ {
if (_savedLogits == null)
return;
// Apply processing to raw logit values
logits = ProcessLogits(ctx, logits, lastTokens);


// The array may be bigger than necessary, get a span of the valid bit
var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

// Restore the values of protected logits
for (var i = 0; i < saved.Length; i++)
logits[(int)saved[i].index] = saved[i].logit;
// Process token data array to select a final token
var candidates = LLamaTokenDataArray.Create(logits);
candidates.ApplyGrammar(ctx, Grammar);
return ProcessTokenDataArray(ctx, candidates, lastTokens);
} }


/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="candidates"></param>
protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
/// <inheritdoc />
public virtual void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
{ {
if (_savedLogits == null || _savedLogits.Length == 0)
return;

candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
Grammar?.AcceptToken(ctx, token);
} }
#endregion


/// <summary> /// <summary>
/// Process the raw logit values /// Process the raw logit values
@@ -96,7 +38,7 @@ public abstract class BaseSamplingPipeline
/// <param name="ctx">The context being sampled from</param> /// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param> /// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param> /// <param name="lastTokens">A list of tokens recently returned by the model</param>
protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
protected abstract ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);


/// <summary> /// <summary>
/// Process the LLamaTokenDataArray and select a single token /// Process the LLamaTokenDataArray and select a single token


+ 74
- 26
LLama/Sampling/DefaultSamplingPipeline.cs View File

@@ -16,15 +16,10 @@ public sealed class DefaultSamplingPipeline
/// </summary> /// </summary>
public Dictionary<int, float> LogitBias { get; } = new(); public Dictionary<int, float> LogitBias { get; } = new();


/// <summary>
/// Grammar to constrain valid tokens
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }

/// <summary> /// <summary>
/// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
/// </summary> /// </summary>
public float RepeatPenalty { get; set; } = 1.1f;
public float RepeatPenalty { get; set; }


/// <summary> /// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br /> /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
@@ -43,7 +38,7 @@ public sealed class DefaultSamplingPipeline
_alphaFreq = value; _alphaFreq = value;
} }
} }
private float _alphaFreq = 0.1f;
private float _alphaFreq;


/// <summary> /// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br /> /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
@@ -62,7 +57,7 @@ public sealed class DefaultSamplingPipeline
_alphaPresence = value; _alphaPresence = value;
} }
} }
private float _alphaPresence = 0.1f;
private float _alphaPresence;


/// <summary> /// <summary>
/// Temperature to apply (higher temperature is more "creative") /// Temperature to apply (higher temperature is more "creative")
@@ -99,33 +94,46 @@ public sealed class DefaultSamplingPipeline
/// </summary> /// </summary>
public bool PenalizeNewline { get; set; } = false; public bool PenalizeNewline { get; set; } = false;


private readonly LLamaToken[] _newlineToken = new LLamaToken[1];
private float[]? _logits;


/// <inheritdoc /> /// <inheritdoc />
protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{ {
if (PenalizeNewline)
return Array.Empty<LLamaToken>();
// Skip work if possible
if (LogitBias.Count == 0)
return logits;


_newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
return _newlineToken;
}
// Create a temporary array to hold logits
if (_logits == null || _logits.Length < logits.Length)
_logits = new float[logits.Length];


/// <inheritdoc />
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
// Copy logits
logits.CopyTo(_logits);
var mutable = _logits.AsSpan(0, logits.Length);

// Apply logit bias
foreach (var (key, value) in LogitBias) foreach (var (key, value) in LogitBias)
logits[key] += value;
mutable[key] += value;

return mutable;
} }


/// <inheritdoc /> /// <inheritdoc />
protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens) protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
{ {
// Apply penalties to candidates
candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
// Only apply repetition penalty if we really must. Otherwise avoid all this work
if (lastTokens.Length > 0 && (RepeatPenalty != 0 || AlphaFrequency != 0 || AlphaPresence != 0))
{
// Save the logit value for the newline token
var (nlIndex, nlLogit) = PenalizeNewline ? GetNewlineLogit(ctx, candidates) : (-1, 0);

// Apply penalties to candidates
candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);


// Restore protected tokens, so they are not affected by repetition penalties
RestoreProtectedTokens(candidates);
// Restore newline token
if (!PenalizeNewline)
SetNewlineLogit(ctx, candidates, nlIndex, nlLogit);
}


// Apply the normal llama.cpp pipeline // Apply the normal llama.cpp pipeline
candidates.ApplyGrammar(ctx, Grammar); candidates.ApplyGrammar(ctx, Grammar);
@@ -135,12 +143,52 @@ public sealed class DefaultSamplingPipeline
candidates.TopP(ctx, TopP); candidates.TopP(ctx, TopP);
candidates.MinP(ctx, MinP); candidates.MinP(ctx, MinP);
candidates.Temperature(ctx, Temperature); candidates.Temperature(ctx, Temperature);
var id = candidates.SampleToken(ctx);
return candidates.SampleToken(ctx);
}

private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);

// Try using the ID as an index
if (candidates.data.Span[(int)nlToken].id == nlToken)
return ((int)nlToken, candidates.data.Span[(int)nlToken].logit);
// Exhaustive search
var span = candidates.data.Span;
for (var i = 0; i < span.Length; i++)
{
if (span[i].id == nlToken)
return (i, span[i].logit);
}


Grammar?.AcceptToken(ctx, id);
return id;
return (-1, 0);
} }


private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int indexHint, float logit)
{
var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);

// Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed order
if (indexHint >= 0 && candidates.data.Span[indexHint].id == nlToken)
{
candidates.data.Span[indexHint].logit = logit;
return;
}

// Didn't find it, do an exhaustive search for it
var span = candidates.data.Span;
for (var i = 0; i < candidates.data.Length; i++)
{
if (span[i].id == nlToken)
{
span[i].logit = logit;
return;
}
}
}

/// <inheritdoc />
public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token) public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
{ {
Grammar?.AcceptToken(ctx, token); Grammar?.AcceptToken(ctx, token);


+ 32
- 0
LLama/Sampling/GreedySamplingPipeline.cs View File

@@ -0,0 +1,32 @@
using System;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// A sampling pipeline which always selects the most likely token
/// </summary>
public class GreedySamplingPipeline
    : BaseSamplingPipeline
{
    /// <inheritdoc />
    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Greedy selection never modifies logits, so hand the input span straight back
        // (avoids any copy — see the ISamplingPipeline ReadOnlySpan<float> contract).
        return logits;
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Always pick the single most likely candidate.
        return candidates.SampleTokenGreedy(ctx);
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        var clone = new GreedySamplingPipeline();
        clone.Grammar = Grammar?.Clone();
        return clone;
    }
}

+ 2
- 2
LLama/Sampling/ISamplingPipeline.cs View File

@@ -19,7 +19,7 @@ public interface ISamplingPipeline
/// <param name="logits">The logits produced by the model</param> /// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A span of tokens recently returned by the model</param> /// <param name="lastTokens">A span of tokens recently returned by the model</param>
/// <returns></returns> /// <returns></returns>
LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);


/// <summary> /// <summary>
/// Update the pipeline, with knowledge that a particular token was just accepted /// Update the pipeline, with knowledge that a particular token was just accepted
@@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions
/// <param name="logits">The logits produced by the model</param> /// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param> /// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns> /// <returns></returns>
public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens)
public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, List<LLamaToken> lastTokens)
{ {
#if NET5_0_OR_GREATER #if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(lastTokens); var span = CollectionsMarshal.AsSpan(lastTokens);


+ 71
- 0
LLama/Sampling/Mirostat2SamplingPipeline.cs View File

@@ -0,0 +1,71 @@
using System;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// A sampling pipeline which uses mirostat (v2) to select tokens
/// </summary>
/// <remarks>
/// NOTE(review): class is named "Mirostate2..." but the file is Mirostat2SamplingPipeline.cs —
/// looks like a typo in the class name; confirm before it ships in a public API.
/// </remarks>
public class Mirostate2SamplingPipeline
    : BaseSamplingPipeline
{
    private const float DEFAULT_TAU = 5;

    private float _mu = DEFAULT_TAU * 2;
    private float _tau = DEFAULT_TAU;

    /// <summary>
    /// Currently learned mu value
    /// </summary>
    public float Mu => _mu;

    /// <summary>
    /// target entropy
    /// </summary>
    public float Tau
    {
        get => _tau;
        set
        {
            // Changing the target entropy re-seeds mu at 2*tau (mirostat's recommended start)
            _tau = value;
            _mu = value * 2;
        }
    }

    /// <summary>
    /// learning rate
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    /// <inheritdoc />
    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Mirostat works entirely on the candidate array; raw logits pass through untouched.
        return logits;
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Mirostat v2 sampling; _mu is updated in place across calls (the "learning" state).
        return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
    }

    /// <inheritdoc />
    public override void Reset()
    {
        base.Reset();

        // Discard the learned mu and restart from the recommended 2*tau seed
        _mu = Tau * 2;
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        var clone = new Mirostate2SamplingPipeline { Grammar = Grammar?.Clone() };

        // Copy fields directly so the Tau setter does not reset the learned mu
        clone._tau = _tau;
        clone._mu = _mu;
        clone.Eta = Eta;

        return clone;
    }
}

+ 72
- 0
LLama/Sampling/MirostatSamplingPipeline.cs View File

@@ -0,0 +1,72 @@
using System;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// A sampling pipeline which uses mirostat (v1) to select tokens
/// </summary>
/// <remarks>
/// NOTE(review): class is named "Mirostate..." but the file is MirostatSamplingPipeline.cs —
/// looks like a typo in the class name; confirm before it ships in a public API.
/// </remarks>
public class MirostateSamplingPipeline
    : BaseSamplingPipeline
{
    // Number of candidates considered when estimating surprise (llama.cpp's "m" parameter)
    private const int MIROSTAT_M = 100;
    private const float DEFAULT_TAU = 5;

    private float _mu = DEFAULT_TAU * 2;
    private float _tau = DEFAULT_TAU;

    /// <summary>
    /// Currently learned mu value
    /// </summary>
    public float Mu => _mu;

    /// <summary>
    /// target entropy
    /// </summary>
    public float Tau
    {
        get => _tau;
        set
        {
            // Changing the target entropy re-seeds mu at 2*tau (mirostat's recommended start)
            _tau = value;
            _mu = value * 2;
        }
    }

    /// <summary>
    /// learning rate
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    /// <inheritdoc />
    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Mirostat works entirely on the candidate array; raw logits pass through untouched.
        return logits;
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Mirostat v1 sampling; _mu is updated in place across calls (the "learning" state).
        return candidates.SampleTokenMirostat(ctx, Tau, Eta, MIROSTAT_M, ref _mu);
    }

    /// <inheritdoc />
    public override void Reset()
    {
        base.Reset();

        // Discard the learned mu and restart from the recommended 2*tau seed
        _mu = Tau * 2;
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        var clone = new MirostateSamplingPipeline { Grammar = Grammar?.Clone() };

        // Copy fields directly so the Tau setter does not reset the learned mu
        clone._tau = _tau;
        clone._mu = _mu;
        clone.Eta = Eta;

        return clone;
    }
}

Loading…
Cancel
Save