using System;
using System.Collections.Generic;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// An implementation of ISamplePipeline which mimics the default llama.cpp sampling
/// </summary>
public sealed class DefaultSamplingPipeline
    : BaseSamplingPipeline
{
    /// <summary>
    /// Bias values to add to certain logits, keyed by token id
    /// </summary>
    public Dictionary<int, float> LogitBias { get; } = new();

    /// <summary>
    /// Grammar to constrain valid tokens
    /// </summary>
    public SafeLLamaGrammarHandle? Grammar { get; set; }

    /// <summary>
    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
    /// </summary>
    public float RepeatPenalty { get; set; } = 1.1f;

    /// <summary>
    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
    /// </summary>
    public float AlphaFrequency
    {
        get => _alphaFreq;
        set => _alphaFreq = CheckPenaltyRange(value, nameof(AlphaFrequency));
    }
    private float _alphaFreq = 0.1f;

    /// <summary>
    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
    /// text so far, increasing the model's likelihood to talk about new topics.
    /// </summary>
    public float AlphaPresence
    {
        get => _alphaPresence;
        set => _alphaPresence = CheckPenaltyRange(value, nameof(AlphaPresence));
    }
    private float _alphaPresence = 0.1f;

    // Shared validation for the OpenAI-style penalty properties ([-2, 2] inclusive).
    // Fixes a copy-paste bug where the AlphaPresence setter reported "AlphaFrequency" in its error messages.
    private static float CheckPenaltyRange(float value, string name)
    {
        if (value < -2)
            throw new ArgumentOutOfRangeException(name, $"{name} must be greater than -2");
        if (value > 2)
            throw new ArgumentOutOfRangeException(name, $"{name} must be less than 2");
        return value;
    }

    /// <summary>
    /// Temperature to apply (higher temperature is more "creative")
    /// </summary>
    public float Temperature { get; set; } = 0.75f;

    /// <summary>
    /// Number of tokens to keep in TopK sampling
    /// </summary>
    public int TopK { get; set; }

    /// <summary>
    /// Z value for tail free sampling
    /// </summary>
    public float TailFreeZ { get; set; }

    /// <summary>
    /// P value for locally typical sampling
    /// </summary>
    public float TypicalP { get; set; }

    /// <summary>
    /// P value for TopP sampling
    /// </summary>
    public float TopP { get; set; } = 1f;

    /// <summary>
    /// P value for MinP sampling
    /// </summary>
    public float MinP { get; set; }

    /// <summary>
    /// Whether the newline value should be protected from being modified by logit bias and repeat penalty
    /// </summary>
    public bool PenalizeNewline { get; set; } = false;

    // Single-element buffer reused across calls to GetProtectedTokens to avoid a per-call allocation.
    private readonly LLamaToken[] _newlineToken = new LLamaToken[1];

    /// <inheritdoc />
    protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
    {
        // When the newline token IS being penalized there is nothing to protect.
        if (PenalizeNewline)
            return Array.Empty<LLamaToken>();

        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
        return _newlineToken;
    }

    /// <inheritdoc />
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Apply the user-configured per-token biases directly to the raw logits.
        foreach (var (key, value) in LogitBias)
            logits[key] += value;
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Apply penalties to candidates
        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);

        // Restore protected tokens, so they are not affected by repetition penalties
        RestoreProtectedTokens(candidates);

        // Apply the normal llama.cpp pipeline
        candidates.ApplyGrammar(ctx, Grammar);
        candidates.TopK(ctx, TopK);
        candidates.TailFree(ctx, TailFreeZ);
        candidates.LocallyTypical(ctx, TypicalP);
        candidates.TopP(ctx, TopP);
        candidates.MinP(ctx, MinP);
        candidates.Temperature(ctx, Temperature);
        var id = candidates.SampleToken(ctx);

        Grammar?.AcceptToken(ctx, id);
        return id;
    }

    /// <inheritdoc />
    public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
    {
        // NOTE(review): ProcessTokenDataArray also accepts the sampled token into the grammar.
        // If the base pipeline invokes Accept for the token returned by ProcessTokenDataArray, the
        // grammar sees it twice — confirm against BaseSamplingPipeline before relying on this.
        Grammar?.AcceptToken(ctx, token);
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        var clone = new DefaultSamplingPipeline();

        // Deep-copy the bias table and clone the grammar handle so the copies do not share mutable state.
        foreach (var (k, v) in LogitBias)
            clone.LogitBias.Add(k, v);
        clone.Grammar = Grammar?.Clone();

        clone.RepeatPenalty = RepeatPenalty;
        clone.AlphaFrequency = AlphaFrequency;
        clone.AlphaPresence = AlphaPresence;
        clone.Temperature = Temperature;
        clone.TopK = TopK;
        clone.TailFreeZ = TailFreeZ;
        clone.TypicalP = TypicalP;
        clone.TopP = TopP;
        clone.MinP = MinP;
        clone.PenalizeNewline = PenalizeNewline;

        return clone;
    }
}