using System;
using System.Collections.Generic;
using LLama.Extensions;
using LLama.Native;
namespace LLama.Sampling;
/// <summary>
/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
/// </summary>
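/// <example>
/// A minimal usage sketch, assuming the <c>Sample(ctx, logits, lastTokens)</c> entry point
/// inherited from <see cref="BaseSamplingPipeline"/>; <c>ctx</c>, <c>logits</c> and
/// <c>lastTokens</c> are placeholders the caller must supply:
/// <code>
/// var pipeline = new DefaultSamplingPipeline
/// {
///     Temperature = 0.8f,
///     TopK = 40,
///     TopP = 0.95f,
///     RepeatPenalty = 1.1f,
/// };
/// var token = pipeline.Sample(ctx, logits, lastTokens);
/// </code>
/// </example>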
public sealed class DefaultSamplingPipeline
: BaseSamplingPipeline
{
/// <summary>
/// Bias values to add to certain logits
/// </summary>
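/// <remarks>
/// Keys are token ids; each value is added directly to that token's logit before sampling.
/// Illustrative only: <c>pipeline.LogitBias[tokenId] = -10f;</c> discourages a hypothetical
/// <c>tokenId</c> from being sampled.
/// </remarks>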
public Dictionary<int, float> LogitBias { get; } = new();
/// <summary>
/// Grammar to constrain valid tokens
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
/// <summary>
/// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
/// </summary>
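/// <remarks>
/// A sketch of how llama.cpp applies this: for each token present in <c>lastTokens</c>, a
/// positive logit is divided by this value and a negative logit is multiplied by it, so
/// values above 1 discourage repetition.
/// </remarks>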
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
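/// <remarks>
/// Together with <see cref="AlphaPresence"/> this follows the OpenAI formulation, roughly
/// <c>logit -= count * AlphaFrequency + (count > 0 ? 1 : 0) * AlphaPresence</c>, where
/// <c>count</c> is the number of times the token occurs in <c>lastTokens</c>.
/// </remarks>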
public float AlphaFrequency
{
get => _alphaFreq;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than or equal to -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than or equal to 2");
_alphaFreq = value;
}
}
private float _alphaFreq = 0.1f;
/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
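/// <remarks>
/// Applied in the same pass as <see cref="AlphaFrequency"/>; see the formula sketched there.
/// </remarks>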
public float AlphaPresence
{
get => _alphaPresence;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than or equal to -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than or equal to 2");
_alphaPresence = value;
}
}
private float _alphaPresence = 0.1f;
/// <summary>
/// Temperature to apply (higher temperature is more "creative")
/// </summary>
public float Temperature { get; set; } = 0.75f;
/// <summary>
/// Number of tokens to keep in TopK sampling
/// </summary>
public int TopK { get; set; }
/// <summary>
/// Z value for tail free sampling
/// </summary>
public float TailFreeZ { get; set; }
/// <summary>
/// P value for locally typical sampling
/// </summary>
public float TypicalP { get; set; }
/// <summary>
/// P value for TopP sampling
/// </summary>
public float TopP { get; set; } = 1f;
/// <summary>
/// P value for MinP sampling
/// </summary>
public float MinP { get; set; }
/// <summary>
/// Whether the newline token may be modified by logit bias and repeat penalty; when false it is protected from both
/// </summary>
public bool PenalizeNewline { get; set; } = false;
private readonly int[] _newlineToken = new int[1];
/// <inheritdoc />
protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
{
// Newline is being penalized like any other token, so there is nothing to protect
if (PenalizeNewline)
return Array.Empty<int>();
// Otherwise shield the newline token from logit bias and repetition penalties
_newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
return _newlineToken;
}
/// <inheritdoc />
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
foreach (var (key, value) in LogitBias)
logits[key] += value;
}
/// <inheritdoc />
protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
// Apply penalties to candidates
candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
// Restore protected tokens, so they are not affected by repetition penalties
RestoreProtectedTokens(candidates);
// Apply the normal llama.cpp pipeline
candidates.ApplyGrammar(ctx, Grammar);
candidates.TopK(ctx, TopK);
candidates.TailFree(ctx, TailFreeZ);
candidates.LocallyTypical(ctx, TypicalP);
candidates.TopP(ctx, TopP);
candidates.MinP(ctx, MinP);
candidates.Temperature(ctx, Temperature);
var id = candidates.SampleToken(ctx);
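// Inform the grammar of the accepted token so its parse state stays in sync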
Grammar?.AcceptToken(ctx, id);
return id;
}
/// <inheritdoc />
protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
return candidates.SampleToken(ctx);
}
}