diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs
index 9ad77531..389563aa 100644
--- a/LLama.Unittest/GrammarParserTest.cs
+++ b/LLama.Unittest/GrammarParserTest.cs
@@ -1,5 +1,4 @@
-using System.Text;
-using LLama.Exceptions;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Grammars;
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 195cc4a2..72e9acf8 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -1,5 +1,6 @@
 using System.Diagnostics;
 using LLama.Common;
+using LLama.Sampling;
 using Xunit.Abstractions;
 
 namespace LLama.Unittest
@@ -30,10 +31,13 @@ namespace LLama.Unittest
         [Fact]
         public async Task Stateless()
         {
+            // Create a custom pipeline that mimics the default pipeline
+            var pipeline = new DefaultSamplingPipeline();
+
             var executor = new StatelessExecutor(_weights, _params);
 
             const string question = "Question. what is a cat?\nAnswer: ";
-            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
+            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
 
             var timer = new Stopwatch();
             timer.Start();
diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index 89d94ade..c604dc0d 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
         /// <summary>
         /// A grammar to constrain possible tokens
         /// </summary>
-        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index d87faf0e..e1e89414 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
         /// Grammar to constrain possible tokens
         /// </summary>
         SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <summary>
+        /// Set a custom sampling pipeline to use. If this is set, all other sampling parameters are ignored.
+        /// </summary>
+        ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
\ No newline at end of file
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index d7bd19d9..c1f39550 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -2,6 +2,7 @@ using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ namespace LLama.Common
         /// </summary>
         public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 
     /// <summary>
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 3a3e51af..2902dc8f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -10,6 +10,7 @@ using LLama.Common;
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
             }
         }
 
+        /// <summary>
+        /// Sample a single token from this context, using the given sampling pipeline
+        /// </summary>
+        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+        /// <param name="lastTokens">The tokens recently returned from the model</param>
+        /// <returns>The selected token</returns>
+        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        {
+            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        }
+
         /// <summary>
         /// Perform the sampling. Please don't use it unless you fully know what it does.
         /// </summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index d81630aa..3ed66890 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -210,16 +210,24 @@ namespace LLama
                     SaveSessionFile(_pathSession);
                 }
 
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }
 
                 _last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 4d28274b..9cecf437 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -189,16 +189,24 @@ namespace LLama
                     SaveSessionFile(_pathSession);
                 }
 
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }
 
                 _last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 9c41af7c..831aceb2 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }
 
                 // Decode this token into text
                 decoder.Add(id);
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 4bc154f4..5059a5f3 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -46,14 +46,41 @@ namespace LLama.Native
             return new LLamaTokenDataArray(candidates);
         }
 
+        /// <summary>
+        /// Overwrite the logit values for all given tokens
+        /// </summary>
+        /// <param name="values">tuples of token and logit value to overwrite</param>
+        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+        {
+            if (values.Length == 0)
+                return;
+
+            var dataSpan = data.Span;
+            foreach (var (token, value) in values)
+            {
+                for (var i = 0; i < data.Length; i++)
+                {
+                    if (dataSpan[i].id == token)
+                    {
+                        dataSpan[i].logit = value;
+                        break;
+                    }
+                }
+            }
+            sorted = false;
+        }
+
         #region sampling
         /// <summary>
         /// Apply grammar rules to candidate tokens
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="grammar"></param>
-        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
         {
+            if (grammar == null)
+                return;
+
             using (LLamaTokenDataArrayNative.Create(this, out var st))
             {
                 NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@
         /// <param name="context"></param>
         /// <param name="last_tokens"></param>
         /// <param name="penalty_repeat"></param>
         /// <param name="penalty_freq"></param>
         /// <param name="penalty_present"></param>
-        public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+        public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
         {
             unsafe
             {
                 using (LLamaTokenDataArrayNative.Create(this, out var st))
-                using (var last_tokens_handle = last_tokens.Pin())
                 {
-                    NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
-                    sorted = st.sorted;
+                    fixed (int* last_tokens_handle = last_tokens)
+                    {
+                        NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+                        sorted = st.sorted;
+                    }
                 }
             }
         }
diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs
new file mode 100644
index 00000000..4c0f7689
--- /dev/null
+++ b/LLama/Sampling/BaseSamplingPipeline.cs
@@ -0,0 +1,128 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
+/// </summary>
+public abstract class BaseSamplingPipeline
+    : ISamplingPipeline
+{
+    private int _savedLogitsCount;
+    private (int index, float logit)[]? _savedLogits;
+
+    /// <inheritdoc />
+    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        var protectedLogits = GetProtectedTokens(ctx);
+        _savedLogitsCount = protectedLogits.Count;
+        _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
+        try
+        {
+            // Save the values of protected logits
+            for (var i = 0; i < protectedLogits.Count; i++)
+            {
+                var index = protectedLogits[i];
+                var value = logits[index];
+                _savedLogits[i] = (index, value);
+            }
+
+            // Process raw logits
+            ProcessLogits(ctx, logits, lastTokens);
+
+            // Automatically restore saved logit values after processing
+            RestoreProtectedTokens(logits);
+
+            // Convert logits into token candidates
+            var candidates = LLamaTokenDataArray.Create(logits);
+
+            // Process token data array
+            ProcessTokenDataArray(ctx, candidates, lastTokens);
+
+            // Choose the final value
+            return ChooseToken(ctx, candidates);
+        }
+        finally
+        {
+            ArrayPool<(int, float)>.Shared.Return(_savedLogits);
+            _savedLogits = null;
+            _savedLogitsCount = 0;
+        }
+    }
+
+    #region protected tokens
+    /// <summary>
+    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
+    /// </summary>
+    /// <returns></returns>
+    protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="logits"></param>
+    protected void RestoreProtectedTokens(Span<float> logits)
+    {
+        if (_savedLogits == null)
+            return;
+
+        // The array may be bigger than necessary, get a span of the valid part
+        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
+
+        // Restore the values of protected logits
+        for (var i = 0; i < saved.Length; i++)
+            logits[saved[i].index] = saved[i].logit;
+    }
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="candidates"></param>
+    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
+    {
+        if (_savedLogits == null || _savedLogits.Length == 0)
+            return;
+
+        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
+    }
+    #endregion
+
+    /// <summary>
+    /// Process the raw logit values
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Process the LLamaTokenDataArray and select a single token
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Choose the final token from the candidates
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <param name="candidates"></param>
+    /// <returns></returns>
+    protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
+
+    /// <inheritdoc />
+    public virtual void Reset()
+    {
+    }
+
+    /// <inheritdoc />
+    public virtual void Dispose()
+    {
+        GC.SuppressFinalize(this);
+    }
+}
\ No newline at end of file
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
new file mode 100644
index 00000000..e6db2efe
--- /dev/null
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -0,0 +1,149 @@
+using System;
+using System.Collections.Generic;
+using LLama.Extensions;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
+/// </summary>
+public sealed class DefaultSamplingPipeline
+    : BaseSamplingPipeline
+{
+    /// <summary>
+    /// Bias values to add to certain logits
+    /// </summary>
+    public Dictionary<int, float> LogitBias { get; } = new();
+
+    /// <summary>
+    /// Grammar to constrain valid tokens
+    /// </summary>
+    public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+    /// <summary>
+    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
+    /// </summary>
+    public float RepeatPenalty { get; set; } = 1.1f;
+
+    /// <summary>
+    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+    /// </summary>
+    public float AlphaFrequency
+    {
+        get => _alphaFreq;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
+            _alphaFreq = value;
+        }
+    }
+    private float _alphaFreq = 0.1f;
+
+    /// <summary>
+    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+    /// text so far, increasing the model's likelihood to talk about new topics.
+    /// </summary>
+    public float AlphaPresence
+    {
+        get => _alphaPresence;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
+            _alphaPresence = value;
+        }
+    }
+    private float _alphaPresence = 0.1f;
+
+    /// <summary>
+    /// Temperature to apply (higher temperature is more "creative")
+    /// </summary>
+    public float Temperature { get; set; } = 0.75f;
+
+    /// <summary>
+    /// Number of tokens to keep in TopK sampling
+    /// </summary>
+    public int TopK { get; set; }
+
+    /// <summary>
+    /// Z value for tail free sampling
+    /// </summary>
+    public float TailFreeZ { get; set; }
+
+    /// <summary>
+    /// P value for locally typical sampling
+    /// </summary>
+    public float TypicalP { get; set; }
+
+    /// <summary>
+    /// P value for TopP sampling
+    /// </summary>
+    public float TopP { get; set; } = 1f;
+
+    /// <summary>
+    /// P value for MinP sampling
+    /// </summary>
+    public float MinP { get; set; }
+
+    /// <summary>
+    /// Whether the newline token may be penalized. When this is false the newline logit is protected from logit bias and repeat penalty.
+    /// </summary>
+    public bool PenalizeNewline { get; set; } = false;
+
+    private readonly int[] _newlineToken = new int[1];
+
+    /// <inheritdoc />
+    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
+    {
+        if (PenalizeNewline)
+            return Array.Empty<int>();
+
+        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
+        return _newlineToken;
+    }
+
+    /// <inheritdoc />
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        foreach (var (key, value) in LogitBias)
+            logits[key] += value;
+    }
+
+    /// <inheritdoc />
+    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
+    {
+        // Apply penalties to candidates
+        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+
+        // Restore protected tokens, so they are not affected by repetition penalties
+        RestoreProtectedTokens(candidates);
+
+        // Apply the normal llama.cpp pipeline
+        candidates.ApplyGrammar(ctx, Grammar);
+        candidates.TopK(ctx, TopK);
+        candidates.TailFree(ctx, TailFreeZ);
+        candidates.LocallyTypical(ctx, TypicalP);
+        candidates.TopP(ctx, TopP);
+        candidates.MinP(ctx, MinP);
+        candidates.Temperature(ctx, Temperature);
+        var id = candidates.SampleToken(ctx);
+
+        Grammar?.AcceptToken(ctx, id);
+        return id;
+    }
+
+    /// <inheritdoc />
+    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+    {
+        return candidates.SampleToken(ctx);
+    }
+}
\ No newline at end of file
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
new file mode 100644
index 00000000..f39bf996
--- /dev/null
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -0,0 +1,61 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
+/// </summary>
+public interface ISamplingPipeline
+    : IDisposable
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
+    /// <returns></returns>
+    int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Reset all internal state of the sampling pipeline
+    /// </summary>
+    void Reset();
+}
+
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="pipeline"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+    {
+#if NET5_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(lastTokens);
+        return pipeline.Sample(ctx, logits, span);
+#else
+        var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+        try
+        {
+            lastTokens.CopyTo(copy);
+            return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+        }
+        finally
+        {
+            ArrayPool<int>.Shared.Return(copy);
+        }
+#endif
+    }
+}
\ No newline at end of file
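Usage sketch (not part of the patch above): a minimal example of plugging the new sampling API into inference, mirroring the updated StatelessExecutorTest. The model path, prompt, and sampling values are placeholders; the types and members used (DefaultSamplingPipeline, InferenceParams.SamplingPipeline, StatelessExecutor) come from the diff itself.

    using LLama;
    using LLama.Common;
    using LLama.Sampling;

    // Placeholder model path
    var modelParams = new ModelParams("path/to/model.gguf");
    using var weights = LLamaWeights.LoadFromFile(modelParams);

    // A custom pipeline replaces all of the individual sampling settings on InferenceParams
    var pipeline = new DefaultSamplingPipeline
    {
        Temperature = 0.6f,
        RepeatPenalty = 1.1f,
    };

    var executor = new StatelessExecutor(weights, modelParams);
    var inferenceParams = new InferenceParams
    {
        MaxTokens = 64,
        SamplingPipeline = pipeline,
    };

    await foreach (var text in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
        Console.Write(text);

A pipeline can also be customised by subclassing BaseSamplingPipeline. A rough sketch under the same assumptions, using only the candidate-array operations that appear in the diff (TopK, Temperature, SampleToken); the class name and the chosen values are illustrative, not part of the library:

    using System;
    using System.Collections.Generic;
    using LLama.Native;
    using LLama.Sampling;

    // Hypothetical custom pipeline: top-k then temperature, no logit bias and no protected tokens
    public sealed class TopKTemperaturePipeline : BaseSamplingPipeline
    {
        protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
            => Array.Empty<int>();

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
        {
            // No raw logit processing in this sketch
        }

        protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
        {
            candidates.TopK(ctx, 40);
            candidates.Temperature(ctx, 0.8f);
            return 0; // BaseSamplingPipeline.Sample ignores this value and calls ChooseToken below
        }

        protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
            => candidates.SampleToken(ctx);
    }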