using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
/// </summary>
public abstract class BaseSamplingPipeline
    : ISamplingPipeline
{
    // Number of valid entries at the start of `_savedLogits`; the pooled array may be longer than this.
    private int _savedLogitsCount;

    // Pooled buffer of (token, logit) pairs captured before `ProcessLogits` runs.
    // Only non-null while a call to `Sample` is in progress.
    private (LLamaToken index, float logit)[]? _savedLogits;

    /// <inheritdoc/>
    /// <remarks>
    /// Runs the following steps in order: saves the logits of every token returned by
    /// <see cref="GetProtectedTokens"/>, calls <see cref="ProcessLogits"/>, restores the saved
    /// logits, builds a <see cref="LLamaTokenDataArray"/> from the logits and finally returns
    /// the result of <see cref="ProcessTokenDataArray"/>.
    /// </remarks>
    public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        var protectedTokens = GetProtectedTokens(ctx);
        var count = protectedTokens.Count;

        // Hold the rented array in a local so that exactly this array is handed back to the pool,
        // independent of the state of the nullable field.
        var buffer = ArrayPool<(LLamaToken, float)>.Shared.Rent(count);
        _savedLogits = buffer;
        _savedLogitsCount = count;

        try
        {
            // Save the values of protected logits
            for (var i = 0; i < count; i++)
            {
                var token = protectedTokens[i];
                buffer[i] = (token, logits[(int)token]);
            }

            // Process raw logits
            ProcessLogits(ctx, logits, lastTokens);

            // Automatically restore saved logit values after processing
            RestoreProtectedTokens(logits);

            // Convert logits into token candidates
            var candidates = LLamaTokenDataArray.Create(logits);

            // Process token data array
            return ProcessTokenDataArray(ctx, candidates, lastTokens);
        }
        finally
        {
            // Always give the buffer back and clear the saved state, even if a processing step threw
            ArrayPool<(LLamaToken, float)>.Shared.Return(buffer);
            _savedLogits = null;
            _savedLogitsCount = 0;
        }
    }

    /// <inheritdoc/>
    public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token);

    #region protected tokens
    /// <summary>
    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <returns>The tokens whose logits are saved before ProcessLogits and restored after it</returns>
    protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx);

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
    /// </summary>
    /// <param name="logits">The logits to write the saved values back into. Left untouched if nothing is currently saved.</param>
    protected void RestoreProtectedTokens(Span<float> logits)
    {
        if (_savedLogits is null)
        {
            return;
        }

        // The array may be bigger than necessary, get a span of the valid bit
        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

        // Restore the values of protected logits
        foreach (var (token, logit) in saved)
        {
            logits[(int)token] = logit;
        }
    }

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
    /// </summary>
    /// <param name="candidates">The candidates to overwrite the logits of. Left untouched if nothing is currently saved.</param>
    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
    {
        // Check the number of valid entries rather than the array length: the pooled array may be
        // longer than the number of values which were actually saved.
        if (_savedLogits is null || _savedLogitsCount == 0)
        {
            return;
        }

        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
    }
    #endregion

    /// <summary>
    /// Process the raw logit values
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The logits produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

    /// <summary>
    /// Process the LLamaTokenDataArray and select a single token
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    /// <returns>The selected token</returns>
    protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);

    /// <inheritdoc/>
    public virtual void Reset()
    {
    }

    /// <inheritdoc/>
    public abstract ISamplingPipeline Clone();

    /// <inheritdoc/>
    public virtual void Dispose()
    {
        GC.SuppressFinalize(this);
    }
}