* Sampling pipeline changes:
  - Modified `ISamplingPipeline` to accept a `ReadOnlySpan<float>` of logits directly. This moves the responsibility for copying the logits into the pipeline.
  - Added a flag to `BaseSamplingPipeline` indicating whether a logit copy is necessary, so that the copy is skipped in most cases.
* Fixed `RestoreProtectedTokens` not working if logit processing is skipped.
* Pipeline restructuring:
  - Implemented a new greedy sampling pipeline (always samples the most likely token).
  - Moved `Grammar` into `BaseSamplingPipeline`.
  - Removed the "protected tokens" concept from `BaseSamplingPipeline`, because it was introducing a lot of incidental complexity.
  - Implemented newline logit save/restore in `DefaultSamplingPipeline` (the only place protected tokens were used).
* Implemented pipelines for mirostat v1 and v2.

tags/0.11.0
| @@ -91,8 +91,7 @@ public class BatchedExecutorFork | |||
| // Sample one token | |||
| var ctx = _conversation.Executor.Context.NativeHandle; | |||
| var logitsCopy = _conversation.Sample().ToArray(); | |||
| var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>()); | |||
| var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>()); | |||
| _sampler.Accept(ctx, token); | |||
| _decoder.Add(token); | |||
| @@ -88,7 +88,7 @@ public class BatchedExecutorRewind | |||
/// <summary>
/// Sample the next token for the given conversation and record it in this node's token history.
/// </summary>
/// <param name="conversation">Conversation whose current logits are sampled</param>
/// <returns>The token selected by <c>Sampler</c></returns>
public LLamaToken Sample(Conversation conversation)
{
    // The sampling pipeline takes the logits as a read-only span, so they are passed
    // straight through without first copying them into a temporary array.
    var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());

    _tokens.Add(token);
    return token;
}
| @@ -100,14 +100,12 @@ public class BatchedExecutorRewind | |||
| for (var i = 0; i < _tokens.Count - n_rewind; i++) | |||
| decoder.Add(_tokens[i]); | |||
| Console.ForegroundColor = ConsoleColor.Green; | |||
| Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")); | |||
| AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]"); | |||
| for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++) | |||
| decoder.Add(_tokens[i]); | |||
| Console.ForegroundColor = ConsoleColor.DarkRed; | |||
| Console.WriteLine(decoder.Read().ReplaceLineEndings(" ")); | |||
| AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]"); | |||
| } | |||
| public LLamaToken GetToken(int index) | |||
| @@ -147,8 +147,9 @@ namespace LLama | |||
| } | |||
| /// <summary> | |||
| /// Get the state data as an opaque handle | |||
| /// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/> | |||
| /// </summary> | |||
| /// <remarks>Use <see cref="SaveState"/> if you intend to save this state to disk.</remarks> | |||
| /// <returns></returns> | |||
| public State GetState() | |||
| { | |||
| @@ -1,6 +1,4 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Collections.Generic; | |||
| using LLama.Native; | |||
| namespace LLama.Sampling; | |||
| @@ -11,84 +9,28 @@ namespace LLama.Sampling; | |||
| public abstract class BaseSamplingPipeline | |||
| : ISamplingPipeline | |||
| { | |||
| private int _savedLogitsCount; | |||
| private (LLamaToken index, float logit)[]? _savedLogits; | |||
| /// <inheritdoc/> | |||
| public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| var protectedLogits = GetProtectedTokens(ctx); | |||
| _savedLogitsCount = protectedLogits.Count; | |||
| _savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount); | |||
| try | |||
| { | |||
| // Save the values of protected logits | |||
| for (var i = 0; i < protectedLogits.Count; i++) | |||
| { | |||
| var index = protectedLogits[i]; | |||
| var value = logits[(int)index]; | |||
| _savedLogits[i] = (index, value); | |||
| } | |||
| // Process raw logits | |||
| ProcessLogits(ctx, logits, lastTokens); | |||
| // Automatically restore saved logit values after processing | |||
| RestoreProtectedTokens(logits); | |||
| // Convert logits into token candidates | |||
| var candidates = LLamaTokenDataArray.Create(logits); | |||
| // Process token data array | |||
| return ProcessTokenDataArray(ctx, candidates, lastTokens); | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits); | |||
| _savedLogits = null; | |||
| _savedLogitsCount = 0; | |||
| } | |||
| } | |||
| /// <inheritdoc /> | |||
| public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token); | |||
| #region protected tokens | |||
| /// <summary> | |||
| /// Get all of the "protected" tokens that cannot be changed by ProcessLogits | |||
| /// Grammar to constrain valid tokens | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx); | |||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| /// <summary> | |||
| /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits | |||
| /// </summary> | |||
| /// <param name="logits"></param> | |||
| protected void RestoreProtectedTokens(Span<float> logits) | |||
| /// <inheritdoc/> | |||
| public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| if (_savedLogits == null) | |||
| return; | |||
| // Apply processing to raw logit values | |||
| logits = ProcessLogits(ctx, logits, lastTokens); | |||
| // The array may be bigger than necessary, get a span of the valid bit | |||
| var saved = _savedLogits.AsSpan(0, _savedLogitsCount); | |||
| // Restore the values of protected logits | |||
| for (var i = 0; i < saved.Length; i++) | |||
| logits[(int)saved[i].index] = saved[i].logit; | |||
| // Process token data array to select a final token | |||
| var candidates = LLamaTokenDataArray.Create(logits); | |||
| candidates.ApplyGrammar(ctx, Grammar); | |||
| return ProcessTokenDataArray(ctx, candidates, lastTokens); | |||
| } | |||
| /// <summary> | |||
| /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits | |||
| /// </summary> | |||
| /// <param name="candidates"></param> | |||
| protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) | |||
| /// <inheritdoc /> | |||
| public virtual void Accept(SafeLLamaContextHandle ctx, LLamaToken token) | |||
| { | |||
| if (_savedLogits == null || _savedLogits.Length == 0) | |||
| return; | |||
| candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount)); | |||
| Grammar?.AcceptToken(ctx, token); | |||
| } | |||
| #endregion | |||
| /// <summary> | |||
| /// Process the raw logit values | |||
| @@ -96,7 +38,7 @@ public abstract class BaseSamplingPipeline | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||
| protected abstract ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||
| /// <summary> | |||
| /// Process the LLamaTokenDataArray and select a single token | |||
| @@ -16,15 +16,10 @@ public sealed class DefaultSamplingPipeline | |||
| /// </summary> | |||
| public Dictionary<int, float> LogitBias { get; } = new(); | |||
| /// <summary> | |||
| /// Grammar to constrain valid tokens | |||
| /// </summary> | |||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| /// <summary> | |||
| /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 | |||
| /// </summary> | |||
| public float RepeatPenalty { get; set; } = 1.1f; | |||
| public float RepeatPenalty { get; set; } | |||
| /// <summary> | |||
| /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br /> | |||
| @@ -43,7 +38,7 @@ public sealed class DefaultSamplingPipeline | |||
| _alphaFreq = value; | |||
| } | |||
| } | |||
| private float _alphaFreq = 0.1f; | |||
| private float _alphaFreq; | |||
| /// <summary> | |||
| /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br /> | |||
| @@ -62,7 +57,7 @@ public sealed class DefaultSamplingPipeline | |||
| _alphaPresence = value; | |||
| } | |||
| } | |||
| private float _alphaPresence = 0.1f; | |||
| private float _alphaPresence; | |||
| /// <summary> | |||
| /// Temperature to apply (higher temperature is more "creative") | |||
| @@ -99,33 +94,46 @@ public sealed class DefaultSamplingPipeline | |||
| /// </summary> | |||
| public bool PenalizeNewline { get; set; } = false; | |||
| private readonly LLamaToken[] _newlineToken = new LLamaToken[1]; | |||
| private float[]? _logits; | |||
| /// <inheritdoc /> | |||
| protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx) | |||
| protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| if (PenalizeNewline) | |||
| return Array.Empty<LLamaToken>(); | |||
| // Skip work if possible | |||
| if (LogitBias.Count == 0) | |||
| return logits; | |||
| _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle); | |||
| return _newlineToken; | |||
| } | |||
| // Create a temporary array to hold logits | |||
| if (_logits == null || _logits.Length < logits.Length) | |||
| _logits = new float[logits.Length]; | |||
| /// <inheritdoc /> | |||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
| // Copy logits | |||
| logits.CopyTo(_logits); | |||
| var mutable = _logits.AsSpan(0, logits.Length); | |||
| // Apply logit bias | |||
| foreach (var (key, value) in LogitBias) | |||
| logits[key] += value; | |||
| mutable[key] += value; | |||
| return mutable; | |||
| } | |||
| /// <inheritdoc /> | |||
| protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens) | |||
| { | |||
// Only apply repetition penalties if we really must. Otherwise avoid all this work
if (lastTokens.Length > 0 && (RepeatPenalty != 0 || AlphaFrequency != 0 || AlphaPresence != 0))
{
    // When newlines are exempt from penalties (PenalizeNewline == false) remember the newline
    // logit *before* penalising, so it can be put back afterwards. When newlines are penalised
    // like any other token there is nothing to save.
    var (nlIndex, nlLogit) = PenalizeNewline ? (-1, 0f) : GetNewlineLogit(ctx, candidates);

    // Apply penalties to candidates
    candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);

    // Restore the newline token, so it is not affected by repetition penalties
    if (!PenalizeNewline)
        SetNewlineLogit(ctx, candidates, nlIndex, nlLogit);
}
| // Apply the normal llama.cpp pipeline | |||
| candidates.ApplyGrammar(ctx, Grammar); | |||
| @@ -135,12 +143,52 @@ public sealed class DefaultSamplingPipeline | |||
| candidates.TopP(ctx, TopP); | |||
| candidates.MinP(ctx, MinP); | |||
| candidates.Temperature(ctx, Temperature); | |||
| var id = candidates.SampleToken(ctx); | |||
| return candidates.SampleToken(ctx); | |||
| } | |||
/// <summary>
/// Locate the newline token in the candidates and read its current logit.
/// </summary>
/// <param name="ctx">Context whose model handle is used to look up the newline token</param>
/// <param name="candidates">Candidates to search</param>
/// <returns>
/// The position of the newline token within <paramref name="candidates"/> together with its logit,
/// or <c>(-1, 0)</c> if no candidate has the newline token ID.
/// </returns>
private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
    var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
    var span = candidates.data.Span;

    // Try using the ID as an index. This is only attempted when the ID actually lies inside
    // the span, otherwise indexing would throw instead of falling back to the search below.
    var hint = (int)nlToken;
    if (hint >= 0 && hint < span.Length && span[hint].id == nlToken)
        return (hint, span[hint].logit);

    // Exhaustive search
    for (var i = 0; i < span.Length; i++)
    {
        if (span[i].id == nlToken)
            return (i, span[i].logit);
    }

    // Newline token is not among the candidates
    return (-1, 0);
}
/// <summary>
/// Overwrite the logit of the newline token in the candidates.
/// </summary>
/// <param name="ctx">Context whose model handle is used to look up the newline token</param>
/// <param name="candidates">Candidates to modify</param>
/// <param name="indexHint">Position where the newline token was previously found, or a negative value if unknown</param>
/// <param name="logit">Logit value to assign to the newline token</param>
private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int indexHint, float logit)
{
    var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
    var span = candidates.data.Span;

    // Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed order.
    // The hint is range-checked so that a stale hint falls through to the search rather than throwing.
    if (indexHint >= 0 && indexHint < span.Length && span[indexHint].id == nlToken)
    {
        span[indexHint].logit = logit;
        return;
    }

    // Didn't find it, do an exhaustive search for it
    for (var i = 0; i < span.Length; i++)
    {
        if (span[i].id == nlToken)
        {
            span[i].logit = logit;
            return;
        }
    }

    // Newline token is not among the candidates: nothing to restore
}
| /// <inheritdoc /> | |||
| public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token) | |||
| { | |||
| Grammar?.AcceptToken(ctx, token); | |||
| @@ -0,0 +1,32 @@ | |||
| using System; | |||
| using LLama.Native; | |||
| namespace LLama.Sampling; | |||
/// <summary>
/// A sampling pipeline which always selects the most likely token.
/// The logits are not modified; the token is chosen with <c>SampleTokenGreedy</c>.
/// </summary>
public class GreedySamplingPipeline
    : BaseSamplingPipeline
{
    /// <inheritdoc />
    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        => logits; // no logit-level processing, hand the input span straight back

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        => candidates.SampleTokenGreedy(ctx);

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        // The only state to carry over is the (optional) grammar
        var copy = new GreedySamplingPipeline();
        copy.Grammar = Grammar?.Clone();
        return copy;
    }
}
| @@ -19,7 +19,7 @@ public interface ISamplingPipeline | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A span of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||
| LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||
| /// <summary> | |||
| /// Update the pipeline, with knowledge that a particular token was just accepted | |||
| @@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens) | |||
| public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, List<LLamaToken> lastTokens) | |||
| { | |||
| #if NET5_0_OR_GREATER | |||
| var span = CollectionsMarshal.AsSpan(lastTokens); | |||
| @@ -0,0 +1,71 @@ | |||
| using System; | |||
| using LLama.Native; | |||
| namespace LLama.Sampling; | |||
/// <summary>
/// A sampling pipeline which uses mirostat (v2) to select tokens.
/// The logits are not modified; the token is chosen with <c>SampleTokenMirostat2</c>.
/// </summary>
public class Mirostate2SamplingPipeline
    : BaseSamplingPipeline
{
    // Tau used until the caller assigns a different one
    private const float DefaultTau = 5;

    // Backing fields. `_mu` always starts out at twice the target entropy.
    private float _tau = DefaultTau;
    private float _mu = 2 * DefaultTau;

    /// <summary>
    /// Currently learned mu value. Updated in place every time a token is sampled.
    /// </summary>
    public float Mu => _mu;

    /// <summary>
    /// Target entropy. Assigning a new value also resets <see cref="Mu"/> to twice that value.
    /// </summary>
    public float Tau
    {
        get => _tau;
        set
        {
            _tau = value;
            _mu = 2 * value;
        }
    }

    /// <summary>
    /// Learning rate
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    /// <inheritdoc />
    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        => logits; // no logit-level processing, hand the input span straight back

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        => candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);

    /// <inheritdoc />
    public override void Reset()
    {
        base.Reset();

        // Discard the learned value and start again from twice the target entropy
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        var copy = new Mirostate2SamplingPipeline();
        copy.Grammar = Grammar?.Clone();

        // Fields are copied directly (not through the Tau setter) so the learned mu is preserved
        copy._mu = _mu;
        copy._tau = _tau;
        copy.Eta = Eta;
        return copy;
    }
}
| @@ -0,0 +1,72 @@ | |||
| using System; | |||
| using LLama.Native; | |||
| namespace LLama.Sampling; | |||
/// <summary>
/// A sampling pipeline which uses mirostat (v1) to select tokens.
/// The logits are not modified; the token is chosen with <c>SampleTokenMirostat</c>.
/// </summary>
public class MirostateSamplingPipeline
    : BaseSamplingPipeline
{
    // Fixed `m` argument handed to SampleTokenMirostat
    private const int MirostatM = 100;

    // Tau used until the caller assigns a different one
    private const float DefaultTau = 5;

    // Backing fields. `_mu` always starts out at twice the target entropy.
    private float _tau = DefaultTau;
    private float _mu = 2 * DefaultTau;

    /// <summary>
    /// Currently learned mu value. Updated in place every time a token is sampled.
    /// </summary>
    public float Mu => _mu;

    /// <summary>
    /// Target entropy. Assigning a new value also resets <see cref="Mu"/> to twice that value.
    /// </summary>
    public float Tau
    {
        get => _tau;
        set
        {
            _tau = value;
            _mu = 2 * value;
        }
    }

    /// <summary>
    /// Learning rate
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    /// <inheritdoc />
    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        => logits; // no logit-level processing, hand the input span straight back

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        => candidates.SampleTokenMirostat(ctx, Tau, Eta, MirostatM, ref _mu);

    /// <inheritdoc />
    public override void Reset()
    {
        base.Reset();

        // Discard the learned value and start again from twice the target entropy
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        var copy = new MirostateSamplingPipeline();
        copy.Grammar = Grammar?.Clone();

        // Fields are copied directly (not through the Tau setter) so the learned mu is preserved
        copy._mu = _mu;
        copy._tau = _tau;
        copy.Eta = Eta;
        return copy;
    }
}