using System;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// A sampling pipeline which uses mirostat (v1) to select tokens
/// </summary>
public class MirostateSamplingPipeline
    : BaseSamplingPipeline
{
    // Passed as the "m" argument of SampleTokenMirostat.
    private const int MIROSTAT_M = 100;

    // Initial target entropy; mu is always initialised to twice the target entropy.
    private const float DEFAULT_TAU = 5;

    private float _mu = DEFAULT_TAU * 2;

    /// <summary>
    /// Currently learned mu value
    /// </summary>
    public float Mu => _mu;

    private float _tau = DEFAULT_TAU;

    /// <summary>
    /// Target entropy. Assigning a new value also resets <see cref="Mu"/>
    /// to twice that value.
    /// </summary>
    public float Tau
    {
        get => _tau;
        set
        {
            _tau = value;

            // A new target invalidates the learned state, so restart mu from 2 * tau.
            _mu = value * 2;
        }
    }

    /// <summary>
    /// Learning rate
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    /// <inheritdoc />
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Intentionally empty: this pipeline does not modify the raw logits.
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // _mu is passed by reference so the sampler can update the learned value in place.
        return candidates.SampleTokenMirostat(ctx, Tau, Eta, MIROSTAT_M, ref _mu);
    }

    /// <summary>
    /// Resets the base pipeline state and restores <see cref="Mu"/> to twice
    /// the current <see cref="Tau"/>.
    /// </summary>
    public override void Reset()
    {
        base.Reset();
        _mu = Tau * 2;
    }

    /// <summary>
    /// Creates a new pipeline carrying over the grammar (cloned when present),
    /// the learned mu, the target entropy and the learning rate.
    /// </summary>
    /// <returns>The newly created pipeline.</returns>
    public override ISamplingPipeline Clone()
    {
        return new MirostateSamplingPipeline
        {
            Grammar = Grammar?.Clone(),

            // The backing fields are assigned directly (not via the Tau setter)
            // so that the copied mu is not overwritten with 2 * tau.
            _mu = _mu,
            _tau = _tau,
            Eta = Eta
        };
    }
}