- Added `BaseSamplingPipeline`, which provides a base implementation of `ISamplingPipeline`
- Added `DefaultSamplingPipeline`, which mimics the normal llama.cpp sampling
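For context, a minimal sketch of how the new pipeline plugs into inference (assuming a loaded model `weights` and its `ModelParams`, following the updated unit test in the diff below):

```csharp
using System;
using LLama;
using LLama.Common;
using LLama.Sampling;

// Sketch only, not part of this diff: pipelines attach to inference via
// the InferenceParams.SamplingPipeline property used in the test below.
var pipeline = new DefaultSamplingPipeline();
var executor = new StatelessExecutor(weights, modelParams);
var inferenceParams = new InferenceParams
{
    MaxTokens = 32,
    SamplingPipeline = pipeline,
};

await foreach (var text in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
    Console.Write(text);
```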
@@ -1,5 +1,4 @@
 using System.Text;
-using LLama.Exceptions;
 using LLama.Exceptions;
 using LLama.Native;
 using LLama.Grammars;

@@ -1,9 +1,6 @@
 using System.Diagnostics;
 using LLama.Common;
 using LLama.Sampling;
-using LLama.Sampling.Logits;
-using LLama.Sampling.Selection;
-using LLama.Sampling.Tokens;
 using Xunit.Abstractions;

 namespace LLama.Unittest
@@ -35,40 +32,12 @@ namespace LLama.Unittest
         public async Task Stateless()
         {
             // Create a custom pipeline that mimics the default pipeline
-            var pipeline = new ConfigurableSamplingPipeline()
-            {
-                ProtectedLogits =
-                {
-                    _weights.NewlineToken,
-                    _weights.BeginningOfSentenceToken,
-                    _weights.EndOfSentenceToken
-                },
-                LogitProcessors =
-                {
-                    new LogitBias
-                    {
-                        Biases =
-                        {
-                            { _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing!
-                            { 42, 0f },
-                        }
-                    }
-                },
-                TokenDataProcessors =
-                {
-                    new TailFreeSampling { Z = 1 },
-                    new LocallyTypicalSampling { P = 1 },
-                    new TopPSampling { P = 0.95f },
-                    new MinPSampling { P = 0.05f },
-                    new TemperatureSampling { Temperature = 0.8f },
-                },
-                Selector = new StandardSelection(),
-            };
+            var pipeline = new DefaultSamplingPipeline();

             var executor = new StatelessExecutor(_weights, _params);

             const string question = "Question. what is a cat?\nAnswer: ";
-            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline};
+            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

             var timer = new Stopwatch();
             timer.Start();
@@ -46,14 +46,41 @@ namespace LLama.Native
             return new LLamaTokenDataArray(candidates);
         }

+        /// <summary>
+        /// Overwrite the logit values for all given tokens
+        /// </summary>
+        /// <param name="values">tuples of token and logit value to overwrite</param>
+        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+        {
+            if (values.Length == 0)
+                return;
+
+            var dataSpan = data.Span;
+            foreach (var (token, value) in values)
+            {
+                for (var i = 0; i < data.Length; i++)
+                {
+                    if (dataSpan[i].id == token)
+                    {
+                        dataSpan[i].logit = value;
+                        break;
+                    }
+                }
+            }
+            sorted = false;
+        }
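The intended use of `OverwriteLogits` is the save/restore round trip performed by `BaseSamplingPipeline` further down. A short sketch (the `logits` span and `newlineToken` id are assumed to be in scope):

```csharp
// Save one logit, let destructive processing mutate the candidates,
// then write the saved value back. OverwriteLogits also clears the
// `sorted` flag so later samplers will re-sort the array.
var candidates = LLamaTokenDataArray.Create(logits);
var saved = new (int token, float logit)[] { (newlineToken, logits[newlineToken]) };

// ... penalties, biases, etc. applied to `candidates` here ...

candidates.OverwriteLogits(saved);
```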
         #region sampling
         /// <summary>
         /// Apply grammar rules to candidate tokens
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="grammar"></param>
-        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
         {
+            if (grammar == null)
+                return;
+
             using (LLamaTokenDataArrayNative.Create(this, out var st))
             {
                 NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -0,0 +1,128 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
/// </summary>
public abstract class BaseSamplingPipeline
    : ISamplingPipeline
{
    private int _savedLogitsCount;
    private (int index, float logit)[]? _savedLogits;

    /// <inheritdoc/>
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        var protectedLogits = GetProtectedTokens(ctx);
        _savedLogitsCount = protectedLogits.Count;
        _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
        try
        {
            // Save the values of protected logits
            for (var i = 0; i < protectedLogits.Count; i++)
            {
                var index = protectedLogits[i];
                var value = logits[index];
                _savedLogits[i] = (index, value);
            }

            // Process raw logits
            ProcessLogits(ctx, logits, lastTokens);

            // Automatically restore saved logit values after processing
            RestoreProtectedTokens(logits);

            // Convert logits into token candidates
            var candidates = LLamaTokenDataArray.Create(logits);

            // Process token data array
            ProcessTokenDataArray(ctx, candidates, lastTokens);

            // Choose the final value
            return ChooseToken(ctx, candidates);
        }
        finally
        {
            ArrayPool<(int, float)>.Shared.Return(_savedLogits);
            _savedLogits = null;
            _savedLogitsCount = 0;
        }
    }

    #region protected tokens
    /// <summary>
    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
    /// </summary>
    /// <returns></returns>
    protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
    /// </summary>
    /// <param name="logits"></param>
    protected void RestoreProtectedTokens(Span<float> logits)
    {
        if (_savedLogits == null)
            return;
        // The array may be bigger than necessary, get a span of the valid part
        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

        // Restore the values of protected logits
        for (var i = 0; i < saved.Length; i++)
            logits[saved[i].index] = saved[i].logit;
    }

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
    /// </summary>
    /// <param name="candidates"></param>
    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
    {
        if (_savedLogits == null || _savedLogits.Length == 0)
            return;

        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
    }
    #endregion

    /// <summary>
    /// Process the raw logit values
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The logits produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Process the LLamaTokenDataArray and select a single token
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    /// <returns></returns>
    protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Choose the final token from the candidates
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="candidates"></param>
    /// <returns></returns>
    protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

    /// <inheritdoc/>
    public virtual void Reset()
    {
    }

    /// <inheritdoc/>
    public virtual void Dispose()
    {
        GC.SuppressFinalize(this);
    }
}
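To make the abstract surface concrete, here is a hypothetical minimal subclass (no protected tokens, no logit processing, greedy selection); it illustrates which members a custom pipeline must supply, not a recommended sampler:

```csharp
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

public sealed class GreedyPipeline : BaseSamplingPipeline
{
    // Nothing needs protecting, since ProcessLogits changes nothing
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        => Array.Empty<int>();

    // Leave the raw logits untouched
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
    }

    // No candidate filtering before selection
    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
        => 0;

    // Always take the most likely token
    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        => candidates.SampleTokenGreedy(ctx);
}
```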
@@ -0,0 +1,149 @@
using System;
using System.Collections.Generic;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
/// </summary>
public sealed class DefaultSamplingPipeline
    : BaseSamplingPipeline
{
    /// <summary>
    /// Bias values to add to certain logits
    /// </summary>
    public Dictionary<int, float> LogitBias { get; } = new();

    /// <summary>
    /// Grammar to constrain valid tokens
    /// </summary>
    public SafeLLamaGrammarHandle? Grammar { get; set; }

    /// <summary>
    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
    /// </summary>
    public float RepeatPenalty { get; set; } = 1.1f;

    /// <summary>
    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
    /// </summary>
    public float AlphaFrequency
    {
        get => _alphaFreq;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
            _alphaFreq = value;
        }
    }
    private float _alphaFreq = 0.1f;

    /// <summary>
    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
    /// text so far, increasing the model's likelihood to talk about new topics.
    /// </summary>
    public float AlphaPresence
    {
        get => _alphaPresence;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
            _alphaPresence = value;
        }
    }
    private float _alphaPresence = 0.1f;

    /// <summary>
    /// Temperature to apply (higher temperature is more "creative")
    /// </summary>
    public float Temperature { get; set; } = 0.75f;

    /// <summary>
    /// Number of tokens to keep in TopK sampling
    /// </summary>
    public int TopK { get; set; }

    /// <summary>
    /// Z value for tail free sampling
    /// </summary>
    public float TailFreeZ { get; set; }

    /// <summary>
    /// P value for locally typical sampling
    /// </summary>
    public float TypicalP { get; set; }

    /// <summary>
    /// P value for TopP sampling
    /// </summary>
    public float TopP { get; set; } = 1f;

    /// <summary>
    /// P value for MinP sampling
    /// </summary>
    public float MinP { get; set; }

    /// <summary>
    /// Whether the newline token may be modified by logit bias and repeat penalty; when false its logit is protected from those changes
    /// </summary>
    public bool PenalizeNewline { get; set; } = false;

    private readonly int[] _newlineToken = new int[1];

    /// <inheritdoc />
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
    {
        if (PenalizeNewline)
            return Array.Empty<int>();

        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
        return _newlineToken;
    }

    /// <inheritdoc />
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        foreach (var (key, value) in LogitBias)
            logits[key] += value;
    }

    /// <inheritdoc />
    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        // Apply penalties to candidates
        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);

        // Restore protected tokens, so they are not affected by repetition penalties
        RestoreProtectedTokens(candidates);

        // Apply the normal llama.cpp pipeline
        candidates.ApplyGrammar(ctx, Grammar);
        candidates.TopK(ctx, TopK);
        candidates.TailFree(ctx, TailFreeZ);
        candidates.LocallyTypical(ctx, TypicalP);
        candidates.TopP(ctx, TopP);
        candidates.MinP(ctx, MinP);
        candidates.Temperature(ctx, Temperature);
        var id = candidates.SampleToken(ctx);
        Grammar?.AcceptToken(ctx, id);
        return id;
    }

    /// <inheritdoc />
    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
    {
        return candidates.SampleToken(ctx);
    }
}
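Configured with the values from the removed test pipeline, the equivalent setup becomes plain property assignments (mapping inferred from the test diff above):

```csharp
using LLama.Sampling;

var pipeline = new DefaultSamplingPipeline
{
    TailFreeZ = 1f,          // was: new TailFreeSampling { Z = 1 }
    TypicalP = 1f,           // was: new LocallyTypicalSampling { P = 1 }
    TopP = 0.95f,            // was: new TopPSampling { P = 0.95f }
    MinP = 0.05f,            // was: new MinPSampling { P = 0.05f }
    Temperature = 0.8f,      // was: new TemperatureSampling { Temperature = 0.8f }
    PenalizeNewline = false, // newline stays protected, like the old ProtectedLogits list
};
pipeline.LogitBias.Add(42, 0f); // was: the LogitBias processor's Biases entry
```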
@@ -3,14 +3,11 @@ using System.Buffers;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using LLama.Native;
-using LLama.Sampling.Logits;
-using LLama.Sampling.Selection;
-using LLama.Sampling.Tokens;

 namespace LLama.Sampling;

 /// <summary>
-/// Convert a span of logits into a single sampled token
+/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
 /// </summary>
 public interface ISamplingPipeline
     : IDisposable
@@ -61,101 +58,4 @@ public static class ISamplingPipelineExtensions
         }
 #endif
     }
 }
-/// <summary>
-/// Simple implementation of `ISamplingPipeline`, applies processors in order every time
-/// </summary>
-public sealed class ConfigurableSamplingPipeline
-    : ISamplingPipeline
-{
-    /// <summary>
-    /// Logit processors to apply in this pipeline
-    /// </summary>
-    public IList<ILogitProcessor> LogitProcessors { get; } = new List<ILogitProcessor>();
-
-    /// <summary>
-    /// Logits values which will not be changed by the logit processors
-    /// </summary>
-    public IList<int> ProtectedLogits { get; } = new List<int>();
-
-    /// <summary>
-    /// Token data processors to apply in this pipeline
-    /// </summary>
-    public IList<ITokenDataProcessor> TokenDataProcessors { get; } = new List<ITokenDataProcessor>();
-
-    /// <summary>
-    /// The selector to choose the final token
-    /// </summary>
-    public ITokenSelector Selector { get; set; } = new StandardSelection();
-
-    /// <inheritdoc />
-    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
-    {
-        var savedLogitsCount = ProtectedLogits.Count;
-        var savedLogitValues = ArrayPool<float>.Shared.Rent(savedLogitsCount);
-        var savedLogitIndices = ArrayPool<int>.Shared.Rent(savedLogitsCount);
-        try
-        {
-            // Save the values of protected logits
-            for (var i = 0; i < ProtectedLogits.Count; i++)
-            {
-                savedLogitValues[i] = logits[ProtectedLogits[i]];
-                savedLogitIndices[i] = ProtectedLogits[i];
-            }
-
-            // Modify raw logits
-            foreach (var logitProcessor in LogitProcessors)
-                logitProcessor.ProcessLogits(ctx, logits, lastTokens);
-
-            // Restore the values of protected logits
-            for (var i = 0; i < savedLogitsCount; i++)
-                logits[savedLogitIndices[i]] = savedLogitValues[i];
-        }
-        finally
-        {
-            ArrayPool<float>.Shared.Return(savedLogitValues);
-            ArrayPool<int>.Shared.Return(savedLogitIndices);
-        }
-
-        // Convert logits into token candidates
-        var candidates_p = LLamaTokenDataArray.Create(logits);
-
-        // Process token candidates
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.ProcessTokens(ctx, candidates_p, lastTokens);
-
-        // Select a token
-        var token = Selector.Select(ctx, candidates_p, lastTokens);
-
-        // Tell processors what was selected
-        foreach (var logitProcessor in LogitProcessors)
-            logitProcessor.AcceptToken(ctx, token);
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.AcceptToken(ctx, token);
-
-        return token;
-    }
-
-    /// <inheritdoc />
-    public void Reset()
-    {
-        foreach (var logitProcessor in LogitProcessors)
-            logitProcessor.Reset();
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.Reset();
-        Selector.Reset();
-    }
-
-    /// <inheritdoc />
-    public void Dispose()
-    {
-        foreach (var logitProcessor in LogitProcessors)
-            logitProcessor.Dispose();
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.Dispose();
-        Selector.Dispose();
-    }
-}
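With `ConfigurableSamplingPipeline` removed, anything `DefaultSamplingPipeline`'s properties cannot express is done by implementing `ISamplingPipeline` directly (or deriving from `BaseSamplingPipeline`). A minimal hypothetical implementation using only members visible in this diff:

```csharp
using System;
using LLama.Native;
using LLama.Sampling;

// Bare-bones custom pipeline: no penalties, no filtering, greedy choice.
public sealed class ArgMaxPipeline : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        var candidates = LLamaTokenDataArray.Create(logits);
        return candidates.SampleTokenGreedy(ctx);
    }

    public void Reset() { }

    public void Dispose() { }
}
```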
@@ -1,34 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Logits;

using llama_token = Int32;

/// <summary>
/// Processes raw logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ILogitProcessor
    : IDisposable
{
    /// <summary>
    /// Process raw logits, indexed by llama_token
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
| /// <param name="logits">The token data array to process</param> | |||
| /// <param name="lastTokens">The most recent tokens output</param> | |||
| /// <returns>LLamaTokenDataArray, created from logits</returns> | |||
    void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<llama_token> lastTokens);

    /// <summary>
    /// Inform this process when a token is accepted by the model
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="token"></param>
    void AcceptToken(SafeLLamaContextHandle ctx, int token);

    /// <summary>
    /// Reset all internal sampling state
    /// </summary>
    void Reset();
}

@@ -1,39 +0,0 @@
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling.Logits;

/// <summary>
/// Add a bias directly to logit values
/// </summary>
public sealed class LogitBias
    : ILogitProcessor
{
    /// <summary>
    /// Biases to apply, token -> bias
    /// </summary>
    public IDictionary<int, float> Biases { get; } = new Dictionary<int, float>();

    /// <inheritdoc />
    public void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        foreach (var kvp in Biases)
            logits[kvp.Key] += kvp.Value;
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}

@@ -1,27 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select the most likely token
/// </summary>
public sealed class GreedySelection
    : ITokenSelector
{
    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenGreedy(ctx);
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}

@@ -1,25 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select a single token from a set of possibilities
/// </summary>
public interface ITokenSelector
    : IDisposable
{
    /// <summary>
    /// Select a single token
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="candidates"></param>
    /// <param name="lastTokens"></param>
    /// <returns></returns>
    int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Reset the state
    /// </summary>
    void Reset();
}

@@ -1,65 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class Mirostat2Selection
    : ITokenSelector
{
    private float _mu;

    /// <summary>
    /// Current value of Mu, updated based on the difference between target surprise and actual surprise
    /// </summary>
    public float Mu
    {
        get => _mu;
        set => _mu = value;
    }

    /// <summary>
    /// The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
    /// </summary>
    public float Tau { get; set; }

    /// <summary>
    /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    /// </summary>
    public float Eta { get; set; }

    /// <summary>
    /// Create a new Mirostat 2.0 sampler
    /// </summary>
    /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
    /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
    public Mirostat2Selection(float tau, float eta)
    {
        Tau = tau;
        Eta = eta;
    }

    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
    }

    /// <inheritdoc />
    public void Reset()
    {
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
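For context on the state these selectors carry: `Mu` is a single running value, nudged after every sample so that observed surprise tracks `Tau`. A conceptual sketch of the update performed inside the native call (reconstructed from the paper, not LLamaSharp API):

```csharp
// p = probability of the token that was just sampled (assumed in scope)
var observedSurprise = -MathF.Log2(p);
// Move mu against the surprise error, scaled by the learning rate
mu -= Eta * (observedSurprise - Tau);
```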
@@ -1,76 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class MirostatSelection
    : ITokenSelector
{
    private float _mu;

    /// <summary>
    /// Current value of Mu, updated based on the difference between target surprise and actual surprise
    /// </summary>
    public float Mu
    {
        get => _mu;
        set => _mu = value;
    }

    /// <summary>
    /// The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
    /// </summary>
    public float Tau { get; set; }

    /// <summary>
    /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    /// </summary>
    public float Eta { get; set; }

    /// <summary>
    /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
    /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
    /// the performance of the algorithm.
    /// </summary>
    public int M { get; set; }

    /// <summary>
    /// Create a new Mirostat 1.0 sampler
    /// </summary>
    /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
    /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
    /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
    /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
    /// the performance of the algorithm.</param>
    public MirostatSelection(float tau, float eta, int m = 100)
    {
        Tau = tau;
        Eta = eta;
        M = m;
    }

    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu);
    }

    /// <inheritdoc />
    public void Reset()
    {
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}

@@ -1,27 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select from all possible tokens according to their probability
/// </summary>
public sealed class StandardSelection
    : ITokenSelector
{
    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleToken(ctx);
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,59 +0,0 @@
using System;
using LLama.Grammars;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Apply a grammar to prevent sampling tokens which do not match the grammar
/// </summary>
public sealed class GrammarSampling
    : ITokenDataProcessor
{
    private SafeLLamaGrammarHandle? _handle;

    /// <summary>
    /// Grammar to use for sampling
    /// </summary>
    public Grammar? Grammar { get; set; }

    /// <summary>
    /// Create a new grammar sampler
    /// </summary>
    /// <param name="grammar"></param>
    public GrammarSampling(Grammar grammar)
    {
        Grammar = grammar;
    }

    /// <inheritdoc />
    public void Reset()
    {
        _handle?.Dispose();
        _handle = null;
    }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        // Create a new grammar instance if necessary
        _handle ??= Grammar?.CreateInstance();

        // Apply it
        if (_handle != null)
            tokens.ApplyGrammar(ctx, _handle);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        _handle?.AcceptToken(ctx, token);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        _handle?.Dispose();
        _handle = null;
    }
}

@@ -1,34 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

using llama_token = Int32;

/// <summary>
/// Processes token logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ITokenDataProcessor
    : IDisposable
{
    /// <summary>
    /// Process token logits in a LLamaTokenDataArray
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
    /// <param name="tokens">The token data array to process</param>
    /// <param name="lastTokens">The most recent tokens output</param>
    void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<llama_token> lastTokens);

    /// <summary>
    /// Inform this process when a token is accepted by the model
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="token"></param>
    void AcceptToken(SafeLLamaContextHandle ctx, int token);

    /// <summary>
    /// Reset all internal sampling state
    /// </summary>
    void Reset();
}

@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
/// </summary>
public sealed class LocallyTypicalSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// P value for locally typical sampling
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.LocallyTypical(ctx, P, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}

@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
public sealed class MinPSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// All tokens with a probability of at least P times that of the most likely token will be kept
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.MinP(ctx, P, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
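Note that the cutoff is relative to the strongest candidate. Conceptually (per the linked llama.cpp PR, not this class's actual native call; `probs` is an assumed array of candidate probabilities):

```csharp
using System.Linq;

// Keep token i iff probs[i] >= P * max(probs); the real implementation
// additionally guarantees at least MinKeep tokens survive.
var threshold = P * probs.Max();
var kept = probs.Where(x => x >= threshold).ToArray();
```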
@@ -1,77 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
public sealed class RepetitionPenalty
    : ITokenDataProcessor
{
    private float _alphaFreq;
    private float _alphaPresence;

    /// <summary>
    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
    /// </summary>
    public float RepeatPenalty { get; set; } = 1.1f;

    /// <summary>
    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
    /// </summary>
    public float AlphaFrequency
    {
        get => _alphaFreq;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
            _alphaFreq = value;
        }
    }

    /// <summary>
    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
    /// text so far, increasing the model's likelihood to talk about new topics.
    /// </summary>
    public float AlphaPresence
    {
        get => _alphaPresence;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
            _alphaPresence = value;
        }
    }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
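The "negative logit fix" mentioned in the summary refers to penalising correctly on both sides of zero. A conceptual sketch of the penalty (my reading of the llama.cpp behaviour, not this class's code):

```csharp
foreach (var id in lastTokens)
{
    // Dividing a negative logit would make the token *more* likely, so
    // negative logits are multiplied instead; both branches push
    // recently seen tokens down.
    if (logits[id] <= 0)
        logits[id] *= RepeatPenalty;
    else
        logits[id] /= RepeatPenalty;
}
```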
@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
public sealed class TailFreeSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Z value for tail free sampling
    /// </summary>
    public float Z { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.TailFree(ctx, Z, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}

@@ -1,38 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
/// </summary>
public sealed class TemperatureSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Temperature value to apply
    /// </summary>
    public float Temperature { get; set; } = 0.5f;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.Temperature(ctx, Temperature);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
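The underlying operation is a uniform rescale of the logits before softmax; conceptually:

```csharp
// T < 1 sharpens the distribution (more deterministic), T > 1 flattens it
for (var i = 0; i < logits.Length; i++)
    logits[i] /= Temperature;
```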
@@ -1,38 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Sample with TopK, removing all but the K most likely tokens.
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopKSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Number of tokens to keep
    /// </summary>
    public int Count { get; set; }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.TopK(ctx, Count);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
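Conceptually, TopK is a truncation of the sorted candidate list. A sketch over a hypothetical `(id, logit)` array, not the native implementation:

```csharp
using System.Linq;

var topK = candidates
    .OrderByDescending(c => c.logit) // most likely first
    .Take(Count)                     // keep only the K best
    .ToArray();
```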
@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopPSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// P value for TopP
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.TopP(ctx, P, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
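TopP keeps the smallest high-probability prefix whose cumulative mass reaches P. A conceptual sketch (assumes `probs` is already sorted descending; the real implementation also honours MinKeep):

```csharp
var cumulative = 0f;
var keep = 0;
while (keep < probs.Length && cumulative < P)
    cumulative += probs[keep++]; // include the token that crosses the threshold
// Candidates beyond `keep` are discarded before sampling
```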