Custom Sampling Pipelines (tags/0.9.1)
@@ -1,5 +1,4 @@
-using System.Text;
 using LLama.Exceptions;
 using LLama.Native;
 using LLama.Grammars;
@@ -1,5 +1,6 @@
 using System.Diagnostics;
 using LLama.Common;
+using LLama.Sampling;
 using Xunit.Abstractions;

 namespace LLama.Unittest
@@ -30,10 +31,13 @@ namespace LLama.Unittest
        [Fact]
        public async Task Stateless()
        {
+            // Create a custom pipeline that mimics the default pipeline
+            var pipeline = new DefaultSamplingPipeline();
+
            var executor = new StatelessExecutor(_weights, _params);

            const string question = "Question. what is a cat?\nAnswer: ";
-            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
+            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

            var timer = new Stopwatch();
            timer.Start();
@@ -1,6 +1,9 @@
+#nullable enable
+
 using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
        /// <summary>
        /// A grammar to constrain possible tokens
        /// </summary>
-        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
    }
 }
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
        /// Grammar to constrain possible tokens
        /// </summary>
        SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <summary>
+        /// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
+        /// </summary>
+        ISamplingPipeline? SamplingPipeline { get; set; }
    }
 }
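The diff only shows the library-side wiring for this new property, so here is a rough sketch of what opting in looks like from caller code. The model path is a placeholder and the surrounding setup follows the existing LLamaSharp examples (and the unit test above) rather than anything added by this change:

```csharp
using System;
using LLama;
using LLama.Common;
using LLama.Sampling;

// "model.gguf" is a placeholder path; substitute any local GGUF model.
var modelParams = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(modelParams);
var executor = new StatelessExecutor(weights, modelParams);

var inferenceParams = new InferenceParams
{
    MaxTokens = 64,
    AntiPrompts = new[] { "." },
    // When SamplingPipeline is set, the other sampling parameters on
    // InferenceParams (Temperature, TopK, Mirostat, ...) are ignored.
    SamplingPipeline = new DefaultSamplingPipeline()
};

await foreach (var text in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
    Console.Write(text);
```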
@@ -2,6 +2,7 @@
 using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ namespace LLama.Common
        /// <inheritdoc />
        public SafeLLamaGrammarHandle? Grammar { get; set; }

+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
    }

    /// <summary>
@@ -10,6 +10,7 @@ using LLama.Common;
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;

 namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
            }
        }

+        /// <summary>
+        /// Sample a single token from this context, using the given sampling pipeline
+        /// </summary>
+        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+        /// <param name="lastTokens">The tokens recently returned from the model</param>
+        /// <returns>The selected token</returns>
+        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        {
+            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        }
+
        /// <summary>
        /// Perform the sampling. Please don't use it unless you fully know what it does.
        /// </summary>
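The new overload simply forwards the context's native handle and current logits to the pipeline. A minimal sketch of calling it directly, where `context` and `lastTokens` are placeholders supplied by the caller's own evaluation loop:

```csharp
using LLama;
using LLama.Sampling;

// Placeholder usage: `context` is a LLamaContext whose prompt has already been
// evaluated, and `lastTokens` is an int[] of recently generated token ids.
ISamplingPipeline pipeline = new DefaultSamplingPipeline();
int nextToken = context.Sample(pipeline, lastTokens);
```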
@@ -210,16 +210,24 @@ namespace LLama
                    SaveSessionFile(_pathSession);
                }

-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }

                _last_n_tokens.Enqueue(id);
@@ -189,16 +189,24 @@ namespace LLama
                    SaveSessionFile(_pathSession);
                }

-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }

                _last_n_tokens.Enqueue(id);
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;

 namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
            var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
            {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }

                // Decode this token into text
                decoder.Add(id);
@@ -46,14 +46,41 @@ namespace LLama.Native
            return new LLamaTokenDataArray(candidates);
        }

+        /// <summary>
+        /// Overwrite the logit values for all given tokens
+        /// </summary>
+        /// <param name="values">tuples of token and logit value to overwrite</param>
+        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+        {
+            if (values.Length == 0)
+                return;
+
+            var dataSpan = data.Span;
+            foreach (var (token, value) in values)
+            {
+                for (var i = 0; i < data.Length; i++)
+                {
+                    if (dataSpan[i].id == token)
+                    {
+                        dataSpan[i].logit = value;
+                        break;
+                    }
+                }
+            }
+            sorted = false;
+        }
+
        #region sampling
        /// <summary>
        /// Apply grammar rules to candidate tokens
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="grammar"></param>
-        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
        {
+            if (grammar == null)
+                return;
+
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@ namespace LLama.Native
        /// <param name="penalty_repeat"></param>
        /// <param name="penalty_freq"></param>
        /// <param name="penalty_present"></param>
-        public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+        public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
        {
            unsafe
            {
                using (LLamaTokenDataArrayNative.Create(this, out var st))
-                using (var last_tokens_handle = last_tokens.Pin())
                {
-                    NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
-                    sorted = st.sorted;
+                    fixed (int* last_tokens_handle = last_tokens)
+                    {
+                        NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+                        sorted = st.sorted;
+                    }
                }
            }
        }
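`OverwriteLogits` is what `BaseSamplingPipeline` (below) uses to restore protected logits after penalties have been applied, but any custom pipeline can call it directly. A rough sketch, where `ctx` is a placeholder for an existing context handle and token id 42 is an arbitrary example:

```csharp
using System;
using LLama.Native;

// Sketch only: `ctx` is an existing SafeLLamaContextHandle with logits available.
Span<float> logits = ctx.GetLogits();
var candidates = LLamaTokenDataArray.Create(logits);

// Force token id 42 (an arbitrary example) to be unselectable by later sampling stages.
candidates.OverwriteLogits(new[] { (42, float.NegativeInfinity) });
```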
@@ -0,0 +1,128 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
+/// </summary>
+public abstract class BaseSamplingPipeline
+    : ISamplingPipeline
+{
+    private int _savedLogitsCount;
+    private (int index, float logit)[]? _savedLogits;
+
+    /// <inheritdoc/>
+    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        var protectedLogits = GetProtectedTokens(ctx);
+        _savedLogitsCount = protectedLogits.Count;
+        _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
+        try
+        {
+            // Save the values of protected logits
+            for (var i = 0; i < protectedLogits.Count; i++)
+            {
+                var index = protectedLogits[i];
+                var value = logits[index];
+                _savedLogits[i] = (index, value);
+            }
+
+            // Process raw logits
+            ProcessLogits(ctx, logits, lastTokens);
+
+            // Automatically restore saved logit values after processing
+            RestoreProtectedTokens(logits);
+
+            // Convert logits into token candidates
+            var candidates = LLamaTokenDataArray.Create(logits);
+
+            // Process token data array
+            ProcessTokenDataArray(ctx, candidates, lastTokens);
+
+            // Choose the final value
+            return ChooseToken(ctx, candidates);
+        }
+        finally
+        {
+            ArrayPool<(int, float)>.Shared.Return(_savedLogits);
+            _savedLogits = null;
+            _savedLogitsCount = 0;
+        }
+    }
+
+    #region protected tokens
+    /// <summary>
+    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
+    /// </summary>
+    /// <returns></returns>
+    protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="logits"></param>
+    protected void RestoreProtectedTokens(Span<float> logits)
+    {
+        if (_savedLogits == null)
+            return;
+
+        // The rented array may be bigger than necessary, get a span over just the valid part
+        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
+
+        // Restore the values of protected logits
+        for (var i = 0; i < saved.Length; i++)
+            logits[saved[i].index] = saved[i].logit;
+    }
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="candidates"></param>
+    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
+    {
+        if (_savedLogits == null || _savedLogits.Length == 0)
+            return;
+
+        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
+    }
+    #endregion
+
+    /// <summary>
+    /// Process the raw logit values
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Process the LLamaTokenDataArray and select a single token
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Choose the final token from the candidates
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <param name="candidates"></param>
+    /// <returns></returns>
+    protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
+
+    /// <inheritdoc/>
+    public virtual void Reset()
+    {
+    }
+
+    /// <inheritdoc/>
+    public virtual void Dispose()
+    {
+        GC.SuppressFinalize(this);
+    }
+}
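The split between the abstract members is easier to see with a concrete subclass. Below is an illustrative greedy pipeline that is not part of this change: it protects no tokens, applies no penalties, and assumes the existing `SampleTokenGreedy` helper on `LLamaTokenDataArray`:

```csharp
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

// Illustrative only: always pick the highest-probability token.
public sealed class GreedySamplingPipeline
    : BaseSamplingPipeline
{
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        => Array.Empty<int>();

    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // No logit modification.
    }

    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        // Nothing to do before selection; the base class ignores this return value.
        return 0;
    }

    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        => candidates.SampleTokenGreedy(ctx);
}
```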
@@ -0,0 +1,149 @@
+using System;
+using System.Collections.Generic;
+using LLama.Extensions;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
+/// </summary>
+public sealed class DefaultSamplingPipeline
+    : BaseSamplingPipeline
+{
+    /// <summary>
+    /// Bias values to add to certain logits
+    /// </summary>
+    public Dictionary<int, float> LogitBias { get; } = new();
+
+    /// <summary>
+    /// Grammar to constrain valid tokens
+    /// </summary>
+    public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+    /// <summary>
+    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
+    /// </summary>
+    public float RepeatPenalty { get; set; } = 1.1f;
+
+    /// <summary>
+    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+    /// </summary>
+    public float AlphaFrequency
+    {
+        get => _alphaFreq;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
+            _alphaFreq = value;
+        }
+    }
+    private float _alphaFreq = 0.1f;
+
+    /// <summary>
+    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+    /// text so far, increasing the model's likelihood to talk about new topics.
+    /// </summary>
+    public float AlphaPresence
+    {
+        get => _alphaPresence;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
+            _alphaPresence = value;
+        }
+    }
+    private float _alphaPresence = 0.1f;
+
+    /// <summary>
+    /// Temperature to apply (higher temperature is more "creative")
+    /// </summary>
+    public float Temperature { get; set; } = 0.75f;
+
+    /// <summary>
+    /// Number of tokens to keep in TopK sampling
+    /// </summary>
+    public int TopK { get; set; }
+
+    /// <summary>
+    /// Z value for tail free sampling
+    /// </summary>
+    public float TailFreeZ { get; set; }
+
+    /// <summary>
+    /// P value for locally typical sampling
+    /// </summary>
+    public float TypicalP { get; set; }
+
+    /// <summary>
+    /// P value for TopP sampling
+    /// </summary>
+    public float TopP { get; set; } = 1f;
+
+    /// <summary>
+    /// P value for MinP sampling
+    /// </summary>
+    public float MinP { get; set; }
+
+    /// <summary>
+    /// Whether the newline token should be penalized. If false it is "protected" and is not modified by logit bias or repeat penalties.
+    /// </summary>
+    public bool PenalizeNewline { get; set; } = false;
+
+    private readonly int[] _newlineToken = new int[1];
+
+    /// <inheritdoc />
+    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
+    {
+        if (PenalizeNewline)
+            return Array.Empty<int>();
+
+        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
+        return _newlineToken;
+    }
+
+    /// <inheritdoc />
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        foreach (var (key, value) in LogitBias)
+            logits[key] += value;
+    }
+
+    /// <inheritdoc />
+    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
+    {
+        // Apply penalties to candidates
+        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+
+        // Restore protected tokens, so they are not affected by repetition penalties
+        RestoreProtectedTokens(candidates);
+
+        // Apply the normal llama.cpp pipeline
+        candidates.ApplyGrammar(ctx, Grammar);
+        candidates.TopK(ctx, TopK);
+        candidates.TailFree(ctx, TailFreeZ);
+        candidates.LocallyTypical(ctx, TypicalP);
+        candidates.TopP(ctx, TopP);
+        candidates.MinP(ctx, MinP);
+        candidates.Temperature(ctx, Temperature);
+        var id = candidates.SampleToken(ctx);
+
+        Grammar?.AcceptToken(ctx, id);
+        return id;
+    }
+
+    /// <inheritdoc />
+    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+    {
+        return candidates.SampleToken(ctx);
+    }
+}
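A short sketch of configuring this pipeline in place of the classic InferenceParams sampling settings. The values are illustrative only, not recommended defaults, and the token id used for the logit bias is arbitrary:

```csharp
using System;
using LLama.Sampling;

var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.6f,
    TopK = 40,
    TopP = 0.95f,
    TailFreeZ = 1f,      // 1 disables tail free sampling in llama.cpp
    TypicalP = 1f,       // 1 disables locally typical sampling
    RepeatPenalty = 1.1f,
};

// Forbid a specific token id entirely (42 is an arbitrary example).
pipeline.LogitBias[42] = float.NegativeInfinity;
```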
@@ -0,0 +1,61 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
+/// </summary>
+public interface ISamplingPipeline
+    : IDisposable
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
+    /// <returns></returns>
+    int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Reset all internal state of the sampling pipeline
+    /// </summary>
+    void Reset();
+}
+
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="pipeline"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+    {
+#if NET5_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(lastTokens);
+        return pipeline.Sample(ctx, logits, span);
+#else
+        var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+        try
+        {
+            lastTokens.CopyTo(copy);
+            return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+        }
+        finally
+        {
+            ArrayPool<int>.Shared.Return(copy);
+        }
+#endif
+    }
+}
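Finally, a small sketch of the List&lt;int&gt; extension overload in use. `ctx` and `logits` are placeholders supplied by the caller's own evaluation loop, and the token ids are arbitrary examples:

```csharp
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

// Placeholder inputs: `ctx` is a SafeLLamaContextHandle and `logits` its current logits.
var lastTokens = new List<int> { 1, 2, 3 };   // arbitrary example token ids

using ISamplingPipeline pipeline = new DefaultSamplingPipeline();
int next = pipeline.Sample(ctx, logits, lastTokens);

// Clear any internal pipeline state before starting a new, unrelated sequence.
pipeline.Reset();
```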