* Modified `ISamplingPipeline` to accept a `ReadOnlySpan<float>` of logits directly, moving the responsibility for copying logits into the pipeline (see the caller-side sketch below).
  - Added a flag to `BaseSamplingPipeline` indicating whether a logit copy is necessary, so the copy is skipped in most cases.
* Fixed `RestoreProtectedTokens` not working when logit processing is skipped.
* Implemented a new greedy sampling pipeline (always samples the most likely token).
  - Moved `Grammar` into `BaseSamplingPipeline`.
  - Removed the "protected tokens" concept from `BaseSamplingPipeline`; it was introducing a lot of incidental complexity.
  - Implemented newline logit save/restore in `DefaultSamplingPipeline` (the only place protected tokens were used).
* Implemented pipelines for mirostat v1 and v2.

Tag: `0.11.0`
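As a caller-side illustration of the signature change (not part of the diff; `sampler`, `ctx` and `conversation` are borrowed from the example hunks below), the `ReadOnlySpan<float>` overload removes a per-token allocation:

```csharp
// Before: Sample took a Span<float>, so callers had to copy the native
// logit buffer into a fresh array on every token.
var logitsCopy = conversation.Sample().ToArray();
var token = sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());

// After: the span returned by Sample() is passed straight through; the
// pipeline itself copies only if it actually needs to mutate the logits.
var token2 = sampler.Sample(ctx, conversation.Sample(), Array.Empty<LLamaToken>());
```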
```diff
@@ -91,8 +91,7 @@ public class BatchedExecutorFork
         // Sample one token
         var ctx = _conversation.Executor.Context.NativeHandle;
-        var logitsCopy = _conversation.Sample().ToArray();
-        var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
+        var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
         _sampler.Accept(ctx, token);
         _decoder.Add(token);
```
```diff
@@ -88,7 +88,7 @@ public class BatchedExecutorRewind
     public LLamaToken Sample(Conversation conversation)
     {
-        var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
+        var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
         _tokens.Add(token);
         return token;
     }
```
```diff
@@ -100,14 +100,12 @@ public class BatchedExecutorRewind
         for (var i = 0; i < _tokens.Count - n_rewind; i++)
             decoder.Add(_tokens[i]);
-        Console.ForegroundColor = ConsoleColor.Green;
-        Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" "));
+        AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");
         for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
             decoder.Add(_tokens[i]);
-        Console.ForegroundColor = ConsoleColor.DarkRed;
-        Console.WriteLine(decoder.Read().ReplaceLineEndings(" "));
+        AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");
     }

     public LLamaToken GetToken(int index)
```
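A side note on the Spectre.Console swap above: `MarkupLine` parses `[` and `]` as markup tags, so model output containing brackets can throw. A hedged hardening, not part of the diff:

```csharp
using Spectre.Console;

// Markup.Escape prevents '[' or ']' in generated text from being
// misinterpreted as Spectre.Console markup tags.
var text = decoder.Read().ReplaceLineEndings(" ");
AnsiConsole.MarkupLine($"[green]{Markup.Escape(text)}[/]");
```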
```diff
@@ -147,8 +147,9 @@ namespace LLama
         }

         /// <summary>
-        /// Get the state data as an opaque handle
+        /// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
         /// </summary>
+        /// <remarks>Use <see cref="SaveState"/> if you intend to save this state to disk.</remarks>
         /// <returns></returns>
         public State GetState()
         {
```
```diff
@@ -1,6 +1,4 @@
 using System;
-using System.Buffers;
-using System.Collections.Generic;
 using LLama.Native;

 namespace LLama.Sampling;
```
```diff
@@ -11,84 +9,28 @@ namespace LLama.Sampling;
 public abstract class BaseSamplingPipeline
     : ISamplingPipeline
 {
-    private int _savedLogitsCount;
-    private (LLamaToken index, float logit)[]? _savedLogits;
-
-    /// <inheritdoc/>
-    public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
-    {
-        var protectedLogits = GetProtectedTokens(ctx);
-        _savedLogitsCount = protectedLogits.Count;
-        _savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount);
-        try
-        {
-            // Save the values of protected logits
-            for (var i = 0; i < protectedLogits.Count; i++)
-            {
-                var index = protectedLogits[i];
-                var value = logits[(int)index];
-                _savedLogits[i] = (index, value);
-            }
-
-            // Process raw logits
-            ProcessLogits(ctx, logits, lastTokens);
-
-            // Automatically restore saved logit values after processing
-            RestoreProtectedTokens(logits);
-
-            // Convert logits into token candidates
-            var candidates = LLamaTokenDataArray.Create(logits);
-
-            // Process token data array
-            return ProcessTokenDataArray(ctx, candidates, lastTokens);
-        }
-        finally
-        {
-            ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits);
-            _savedLogits = null;
-            _savedLogitsCount = 0;
-        }
-    }
-
-    /// <inheritdoc />
-    public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token);
-
-    #region protected tokens
     /// <summary>
-    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
+    /// Grammar to constrain valid tokens
     /// </summary>
-    /// <returns></returns>
-    protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx);
+    public SafeLLamaGrammarHandle? Grammar { get; set; }

-    /// <summary>
-    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
-    /// </summary>
-    /// <param name="logits"></param>
-    protected void RestoreProtectedTokens(Span<float> logits)
+    /// <inheritdoc/>
+    public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        if (_savedLogits == null)
-            return;
-
-        // The array may be bigger than necessary, get a span of the valid bit
-        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
-
-        // Restore the values of protected logits
-        for (var i = 0; i < saved.Length; i++)
-            logits[(int)saved[i].index] = saved[i].logit;
+        // Apply processing to raw logit values
+        logits = ProcessLogits(ctx, logits, lastTokens);
+
+        // Process token data array to select a final token
+        var candidates = LLamaTokenDataArray.Create(logits);
+        candidates.ApplyGrammar(ctx, Grammar);
+        return ProcessTokenDataArray(ctx, candidates, lastTokens);
     }

-    /// <summary>
-    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
-    /// </summary>
-    /// <param name="candidates"></param>
-    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
+    /// <inheritdoc />
+    public virtual void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
     {
-        if (_savedLogits == null || _savedLogits.Length == 0)
-            return;
-
-        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
+        Grammar?.AcceptToken(ctx, token);
     }
-    #endregion

     /// <summary>
     /// Process the raw logit values
@@ -96,7 +38,7 @@ public abstract class BaseSamplingPipeline
     /// <param name="ctx">The context being sampled from</param>
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A list of tokens recently returned by the model</param>
-    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
+    protected abstract ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

     /// <summary>
     /// Process the LLamaTokenDataArray and select a single token
```
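To show what the slimmed-down contract asks of implementors, here is a minimal hypothetical pipeline (not part of the diff; a `TopK` method on `LLamaTokenDataArray` is assumed to exist alongside the `TopP`/`MinP` calls seen later in this diff):

```csharp
using System;
using LLama.Native;
using LLama.Sampling;

// Hypothetical pipeline: keep the K most likely tokens, then sample.
public class TopKOnlySamplingPipeline : BaseSamplingPipeline
{
    public int TopK { get; set; } = 40;

    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Nothing to mutate, so return the caller's span and avoid a copy.
        return logits;
    }

    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // The grammar has already been applied by BaseSamplingPipeline.Sample.
        candidates.TopK(ctx, TopK);
        return candidates.SampleToken(ctx);
    }

    public override ISamplingPipeline Clone()
    {
        return new TopKOnlySamplingPipeline { Grammar = Grammar?.Clone(), TopK = TopK };
    }
}
```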
```diff
@@ -16,15 +16,10 @@ public sealed class DefaultSamplingPipeline
     /// </summary>
     public Dictionary<int, float> LogitBias { get; } = new();

-    /// <summary>
-    /// Grammar to constrain valid tokens
-    /// </summary>
-    public SafeLLamaGrammarHandle? Grammar { get; set; }
-
     /// <summary>
     /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
     /// </summary>
-    public float RepeatPenalty { get; set; } = 1.1f;
+    public float RepeatPenalty { get; set; }

     /// <summary>
     /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
@@ -43,7 +38,7 @@ public sealed class DefaultSamplingPipeline
             _alphaFreq = value;
         }
     }
-    private float _alphaFreq = 0.1f;
+    private float _alphaFreq;

     /// <summary>
     /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
@@ -62,7 +57,7 @@ public sealed class DefaultSamplingPipeline
             _alphaPresence = value;
         }
     }
-    private float _alphaPresence = 0.1f;
+    private float _alphaPresence;

     /// <summary>
     /// Temperature to apply (higher temperature is more "creative")
```
```diff
@@ -99,33 +94,46 @@ public sealed class DefaultSamplingPipeline
     /// </summary>
     public bool PenalizeNewline { get; set; } = false;

-    private readonly LLamaToken[] _newlineToken = new LLamaToken[1];
+    private float[]? _logits;

     /// <inheritdoc />
-    protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
+    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        if (PenalizeNewline)
-            return Array.Empty<LLamaToken>();
-
-        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
-        return _newlineToken;
-    }
+        // Skip work if possible
+        if (LogitBias.Count == 0)
+            return logits;

-    /// <inheritdoc />
-    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
-    {
+        // Create a temporary array to hold logits
+        if (_logits == null || _logits.Length < logits.Length)
+            _logits = new float[logits.Length];
+
+        // Copy logits into the mutable buffer
+        logits.CopyTo(_logits);
+        var mutable = _logits.AsSpan(0, logits.Length);
+
         // Apply logit bias
         foreach (var (key, value) in LogitBias)
-            logits[key] += value;
+            mutable[key] += value;
+
+        return mutable;
     }

     /// <inheritdoc />
     protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        // Apply penalties to candidates
-        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
-
-        // Restore protected tokens, so they are not affected by repetition penalties
-        RestoreProtectedTokens(candidates);
+        // Only apply repetition penalty if we really must. Otherwise avoid all this work
+        if (lastTokens.Length > 0 && (RepeatPenalty != 0 || AlphaFrequency != 0 || AlphaPresence != 0))
+        {
+            // Save the logit value for the newline token, unless it should be penalized like any other token
+            var (nlIndex, nlLogit) = PenalizeNewline ? (-1, 0f) : GetNewlineLogit(ctx, candidates);
+
+            // Apply penalties to candidates
+            candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+
+            // Restore newline token
+            if (!PenalizeNewline)
+                SetNewlineLogit(ctx, candidates, nlIndex, nlLogit);
+        }

         // Apply the normal llama.cpp pipeline
         candidates.ApplyGrammar(ctx, Grammar);
```
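Taken together with the zeroed penalty defaults earlier in this diff, the two fast paths above mean a default-constructed pipeline does no extra per-token work. A hedged configuration sketch:

```csharp
// Default construction: LogitBias is empty, so ProcessLogits returns the
// input span without copying, and all penalties are zero, so the whole
// penalty + newline save/restore block is skipped.
var cheap = new DefaultSamplingPipeline();

// Opting back in: a bias entry triggers exactly one copy (into a buffer
// reused across calls), and a non-zero penalty re-enables the penalty block.
var pipeline = new DefaultSamplingPipeline
{
    RepeatPenalty = 1.1f,    // the old default value
    PenalizeNewline = false, // newline logit is saved/restored around the penalty
};
pipeline.LogitBias[42] = float.NegativeInfinity; // 42: a hypothetical token id to forbid
```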
```diff
@@ -135,12 +143,52 @@ public sealed class DefaultSamplingPipeline
         candidates.TopP(ctx, TopP);
         candidates.MinP(ctx, MinP);
         candidates.Temperature(ctx, Temperature);
-        var id = candidates.SampleToken(ctx);
-        Grammar?.AcceptToken(ctx, id);
-        return id;
+        return candidates.SampleToken(ctx);
+    }
+
+    private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+    {
+        var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
+
+        // Try using the ID as an index
+        if (candidates.data.Span[(int)nlToken].id == nlToken)
+            return ((int)nlToken, candidates.data.Span[(int)nlToken].logit);
+
+        // Exhaustive search
+        var span = candidates.data.Span;
+        for (var i = 0; i < span.Length; i++)
+        {
+            if (span[i].id == nlToken)
+                return (i, span[i].logit);
+        }
+
+        return (-1, 0);
+    }
+
+    private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int indexHint, float logit)
+    {
+        var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
+
+        // Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed the order
+        if (indexHint >= 0 && candidates.data.Span[indexHint].id == nlToken)
+        {
+            candidates.data.Span[indexHint].logit = logit;
+            return;
+        }
+
+        // Didn't find it, do an exhaustive search for it
+        var span = candidates.data.Span;
+        for (var i = 0; i < candidates.data.Length; i++)
+        {
+            if (span[i].id == nlToken)
+            {
+                span[i].logit = logit;
+                return;
+            }
+        }
     }

     /// <inheritdoc />
     public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
     {
         Grammar?.AcceptToken(ctx, token);
```
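The two helpers above rely on `LLamaTokenDataArray.Create` initially storing token `i` at index `i`, so a token's ID doubles as an index hint until something sorts the array. An illustrative fragment (not from the diff):

```csharp
using System.Diagnostics;
using LLama.Native;

var candidates = LLamaTokenDataArray.Create(logits);
var nl = NativeApi.llama_token_nl(ctx.ModelHandle);

// A freshly created candidate array is identity-ordered, so the newline
// token is found in O(1) by indexing with its own ID...
Debug.Assert(candidates.data.Span[(int)nl].id == nl);

// ...but samplers that reorder candidates (e.g. the repetition penalty
// path) can invalidate the hint, hence the O(n) fallback scans above.
```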
```diff
@@ -0,0 +1,32 @@
+using System;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// A sampling pipeline which always selects the most likely token
+/// </summary>
+public class GreedySamplingPipeline
+    : BaseSamplingPipeline
+{
+    /// <inheritdoc />
+    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    {
+        return logits;
+    }
+
+    /// <inheritdoc />
+    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
+    {
+        return candidates.SampleTokenGreedy(ctx);
+    }
+
+    /// <inheritdoc />
+    public override ISamplingPipeline Clone()
+    {
+        return new GreedySamplingPipeline
+        {
+            Grammar = Grammar?.Clone()
+        };
+    }
+}
```
```diff
@@ -19,7 +19,7 @@ public interface ISamplingPipeline
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A span of tokens recently returned by the model</param>
     /// <returns></returns>
-    LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
+    LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

     /// <summary>
     /// Update the pipeline, with knowledge that a particular token was just accepted
```
```diff
@@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A list of tokens recently returned by the model</param>
     /// <returns></returns>
-    public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens)
+    public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, List<LLamaToken> lastTokens)
     {
 #if NET5_0_OR_GREATER
         var span = CollectionsMarshal.AsSpan(lastTokens);
```
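For readers unfamiliar with `CollectionsMarshal`, a hedged sketch of what the truncated `#if` block is doing (the `#else` branch is an assumption; the hunk cuts off before it):

```csharp
#if NET5_0_OR_GREATER
        // Zero-copy: view the List<LLamaToken>'s backing array as a span.
        // Safe only because the list is not mutated while sampling runs.
        var span = CollectionsMarshal.AsSpan(lastTokens);
        return pipeline.Sample(ctx, logits, span);
#else
        // Assumed fallback for older targets: one defensive copy.
        return pipeline.Sample(ctx, logits, lastTokens.ToArray());
#endif
```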
```diff
@@ -0,0 +1,71 @@
+using System;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// A sampling pipeline which uses mirostat (v2) to select tokens
+/// </summary>
+public class Mirostate2SamplingPipeline
+    : BaseSamplingPipeline
+{
+    private const float DEFAULT_TAU = 5;
+
+    private float _mu = DEFAULT_TAU * 2;
+    /// <summary>
+    /// Currently learned mu value
+    /// </summary>
+    public float Mu => _mu;
+
+    private float _tau = DEFAULT_TAU;
+    /// <summary>
+    /// Target entropy
+    /// </summary>
+    public float Tau
+    {
+        get => _tau;
+        set
+        {
+            _tau = value;
+            _mu = value * 2;
+        }
+    }
+
+    /// <summary>
+    /// Learning rate
+    /// </summary>
+    public float Eta { get; set; } = 0.1f;
+
+    /// <inheritdoc />
+    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    {
+        return logits;
+    }
+
+    /// <inheritdoc />
+    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
+    {
+        return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
+    }
+
+    /// <inheritdoc />
+    public override void Reset()
+    {
+        base.Reset();
+
+        _mu = Tau * 2;
+    }
+
+    /// <inheritdoc />
+    public override ISamplingPipeline Clone()
+    {
+        return new Mirostate2SamplingPipeline
+        {
+            Grammar = Grammar?.Clone(),
+            _mu = _mu,
+            _tau = _tau,
+            Eta = Eta
+        };
+    }
+}
```
```diff
@@ -0,0 +1,72 @@
+using System;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// A sampling pipeline which uses mirostat (v1) to select tokens
+/// </summary>
+public class MirostateSamplingPipeline
+    : BaseSamplingPipeline
+{
+    private const int MIROSTAT_M = 100;
+    private const float DEFAULT_TAU = 5;
+
+    private float _mu = DEFAULT_TAU * 2;
+    /// <summary>
+    /// Currently learned mu value
+    /// </summary>
+    public float Mu => _mu;
+
+    private float _tau = DEFAULT_TAU;
+    /// <summary>
+    /// Target entropy
+    /// </summary>
+    public float Tau
+    {
+        get => _tau;
+        set
+        {
+            _tau = value;
+            _mu = value * 2;
+        }
+    }
+
+    /// <summary>
+    /// Learning rate
+    /// </summary>
+    public float Eta { get; set; } = 0.1f;
+
+    /// <inheritdoc />
+    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    {
+        return logits;
+    }
+
+    /// <inheritdoc />
+    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
+    {
+        return candidates.SampleTokenMirostat(ctx, Tau, Eta, MIROSTAT_M, ref _mu);
+    }
+
+    /// <inheritdoc />
+    public override void Reset()
+    {
+        base.Reset();
+
+        _mu = Tau * 2;
+    }
+
+    /// <inheritdoc />
+    public override ISamplingPipeline Clone()
+    {
+        return new MirostateSamplingPipeline
+        {
+            Grammar = Grammar?.Clone(),
+            _mu = _mu,
+            _tau = _tau,
+            Eta = Eta
+        };
+    }
+}
```