using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native;
namespace LLama.Sampling;
/// <summary>
/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing <see cref="ISamplingPipeline"/>.
/// </summary>
/// <remarks>
/// <see cref="Sample"/> runs the stages in a fixed order:
/// <list type="number">
/// <item>The logits of the tokens returned by <see cref="GetProtectedTokens"/> are saved.</item>
/// <item><see cref="ProcessLogits"/> is called on the raw logits.</item>
/// <item>The saved protected logits are written back, undoing any change made in the previous step.</item>
/// <item>The logits are converted into a <see cref="LLamaTokenDataArray"/> and passed to <see cref="ProcessTokenDataArray"/>.</item>
/// <item><see cref="ChooseToken"/> picks the token that <see cref="Sample"/> returns.</item>
/// </list>
/// </remarks>
public abstract class BaseSamplingPipeline
    : ISamplingPipeline
{
    // Protected logits captured at the start of Sample. The buffer is rented from ArrayPool and may be
    // longer than needed, so only the first _savedLogitsCount entries are valid. Both fields are reset
    // when Sample finishes, so the RestoreProtectedTokens overloads only act while Sample is running.
    private int _savedLogitsCount;
    private (int index, float logit)[]? _savedLogits;

    /// <inheritdoc />
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        var protectedLogits = GetProtectedTokens(ctx);
        var count = protectedLogits.Count;

        // Keep the rented buffer in a non-null local so exactly this array is handed back to the pool,
        // independent of anything that happens to the field in between.
        var saved = ArrayPool<(int, float)>.Shared.Rent(count);
        _savedLogits = saved;
        _savedLogitsCount = count;
        try
        {
            // Save the values of protected logits
            for (var i = 0; i < count; i++)
            {
                var index = protectedLogits[i];
                saved[i] = (index, logits[index]);
            }

            // Process raw logits
            ProcessLogits(ctx, logits, lastTokens);

            // Automatically restore saved logit values after processing
            RestoreProtectedTokens(logits);

            // Convert logits into token candidates
            var candidates = LLamaTokenDataArray.Create(logits);

            // Process token data array. Its int result is intentionally not used; the token comes from ChooseToken.
            ProcessTokenDataArray(ctx, candidates, lastTokens);

            // Choose the final value
            return ChooseToken(ctx, candidates);
        }
        finally
        {
            ArrayPool<(int, float)>.Shared.Return(saved);
            _savedLogits = null;
            _savedLogitsCount = 0;
        }
    }

    #region protected tokens
    /// <summary>
    /// Get all of the "protected" tokens that cannot be changed by <see cref="ProcessLogits"/>.
    /// </summary>
    /// <param name="ctx">The context being sampled from.</param>
    /// <returns>
    /// Token ids whose logits are saved before <see cref="ProcessLogits"/> and restored after it.
    /// Each id is used directly as an index into the logits span, so it must be in range.
    /// </returns>
    protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to <see cref="ProcessLogits"/>.
    /// Does nothing when called outside of <see cref="Sample"/>.
    /// </summary>
    /// <param name="logits">The logits to write the saved values back into.</param>
    protected void RestoreProtectedTokens(Span<float> logits)
    {
        if (_savedLogits is null || _savedLogitsCount == 0)
            return;

        // The array may be bigger than necessary, get a span of the valid bit
        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

        // Restore the values of protected logits
        for (var i = 0; i < saved.Length; i++)
            logits[saved[i].index] = saved[i].logit;
    }

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to <see cref="ProcessLogits"/>.
    /// Does nothing when called outside of <see cref="Sample"/>.
    /// </summary>
    /// <param name="candidates">The candidates whose logits are overwritten with the saved values.</param>
    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
    {
        // Test the valid count rather than the array length: pooled arrays can be longer than requested.
        if (_savedLogits is null || _savedLogitsCount == 0)
            return;

        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
    }
    #endregion

    /// <summary>
    /// Process the raw logit values. Any change made to protected tokens is undone after this returns.
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The logits produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Process the <see cref="LLamaTokenDataArray"/> built from the processed logits.
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    /// <returns>
    /// An int result. Note that <see cref="Sample"/> ignores this value; the token it returns is
    /// chosen by <see cref="ChooseToken"/>.
    /// </returns>
    protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Choose the final token from the candidates
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="candidates">The candidates produced by <see cref="ProcessTokenDataArray"/></param>
    /// <returns>The token returned from <see cref="Sample"/></returns>
    protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

    /// <inheritdoc />
    public virtual void Reset()
    {
    }

    /// <inheritdoc />
    public virtual void Dispose()
    {
        GC.SuppressFinalize(this);
    }
}