using System;
using LLama.Native;
namespace LLama.Sampling;
/// <summary>
/// A sampling pipeline which uses mirostat (v2) to select tokens.
/// </summary>
/// <remarks>
/// The adaptive value <see cref="Mu"/> is seeded at twice the target entropy <see cref="Tau"/>.
/// It is re-seeded whenever <see cref="Tau"/> is assigned or <see cref="Reset"/> is called,
/// discarding whatever value had been learned up to that point.
/// </remarks>
public class Mirostate2SamplingPipeline
    : BaseSamplingPipeline
{
    private const float DefaultTau = 5;

    private float _mu = InitialMu(DefaultTau);

    /// <summary>
    /// Currently learned mu value. It is passed by reference to the mirostat v2 sampler,
    /// which updates it as tokens are sampled.
    /// </summary>
    public float Mu => _mu;

    private float _tau = DefaultTau;

    /// <summary>
    /// Target entropy (surprise) the sampler aims for. Defaults to 5.
    /// </summary>
    /// <remarks>
    /// Assigning this property also re-seeds <see cref="Mu"/> to <c>2 * value</c>,
    /// discarding any learned value.
    /// </remarks>
    public float Tau
    {
        get => _tau;
        set
        {
            _tau = value;
            _mu = InitialMu(value);
        }
    }

    /// <summary>
    /// Learning rate passed to the mirostat v2 sampler, controlling how quickly
    /// <see cref="Mu"/> adapts. Defaults to 0.1.
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    /// <inheritdoc />
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Intentionally empty: this pipeline makes no logit adjustments.
        // All of the mirostat v2 work happens in ProcessTokenDataArray.
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // _mu is passed by reference so the sampler can update the learned value in place.
        return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
    }

    /// <inheritdoc />
    /// <remarks>Also re-seeds <see cref="Mu"/> from the current <see cref="Tau"/>.</remarks>
    public override void Reset()
    {
        base.Reset();
        _mu = InitialMu(_tau);
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        // Assign the backing fields directly: going through the Tau setter would
        // re-seed mu and throw away the value learned so far. The grammar is cloned
        // rather than shared between the two pipelines.
        return new Mirostate2SamplingPipeline
        {
            Grammar = Grammar?.Clone(),
            _mu = _mu,
            _tau = _tau,
            Eta = Eta
        };
    }

    /// <summary>
    /// Initial value of <see cref="Mu"/> for a given target entropy: twice <paramref name="tau"/>.
    /// Single source of truth for the seeding rule used by the field initializer,
    /// the <see cref="Tau"/> setter and <see cref="Reset"/>.
    /// </summary>
    private static float InitialMu(float tau) => tau * 2;
}