using System;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// A sampling pipeline which uses mirostat (v2) to select tokens
/// </summary>
public class Mirostate2SamplingPipeline
    : BaseSamplingPipeline
{
    private const float DEFAULT_TAU = 5;

    // Starts at twice the target entropy; the same 2 * tau rule is re-applied
    // whenever Tau is assigned and whenever the pipeline is reset.
    private float _mu = DEFAULT_TAU * 2;

    /// <summary>
    /// Currently learned mu value
    /// </summary>
    /// <remarks>
    /// This value is passed by reference to the sampler on every call, so it may change as tokens are sampled.
    /// </remarks>
    public float Mu => _mu;

    private float _tau = DEFAULT_TAU;

    /// <summary>
    /// Target entropy
    /// </summary>
    /// <remarks>
    /// Assigning this property also resets <see cref="Mu"/> to twice the new value.
    /// </remarks>
    public float Tau
    {
        get => _tau;
        set
        {
            _tau = value;
            _mu = value * 2;
        }
    }

    /// <summary>
    /// Learning rate (defaults to 0.1)
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    // NOTE(review): the generic arguments on Span/ReadOnlySpan below were restored as
    // <float> and <LLamaToken>; confirm they match the signatures declared on BaseSamplingPipeline.

    /// <inheritdoc />
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Intentionally empty: this pipeline does not modify the raw logits.
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // _mu is handed over by reference so the sampler can write back an updated value.
        return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
    }

    /// <inheritdoc />
    public override void Reset()
    {
        base.Reset();

        // Discard the learned value and return to the initial 2 * tau.
        _mu = Tau * 2;
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        return new Mirostate2SamplingPipeline
        {
            Grammar = Grammar?.Clone(),

            // Write the backing fields directly: going through the Tau setter
            // would overwrite the copied mu with 2 * tau.
            _mu = _mu,
            _tau = _tau,
            Eta = Eta
        };
    }
}