using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using LLama.Native;
using LLama.Sampling.Logits;
using LLama.Sampling.Selection;
using LLama.Sampling.Tokens;
namespace LLama.Sampling;
/// <summary>
/// Convert a span of logits into a single sampled token
/// </summary>
public interface ISamplingPipeline
: IDisposable
{
/// <summary>
/// Sample a single token from the given logits
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A span of tokens recently returned by the model</param>
/// <returns>The sampled token</returns>
int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens);
/// <summary>
/// Reset all internal state of the sampling pipeline
/// </summary>
void Reset();
}
/// <summary>
/// Extension methods for <see cref="ISamplingPipeline"/>
/// </summary>
public static class ISamplingPipelineExtensions
{
    /// <summary>
    /// Sample a single token from the given logits
    /// </summary>
    /// <param name="pipeline">The pipeline used to sample the token</param>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The logits produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    /// <returns>The token returned by <paramref name="pipeline"/></returns>
    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span logits, List lastTokens)
    {
#if NET5_0_OR_GREATER
        // View the list's contents as a span directly, no copy required.
        var span = CollectionsMarshal.AsSpan(lastTokens);
        return pipeline.Sample(ctx, logits, span);
#else
        // CollectionsMarshal.AsSpan is unavailable here, so copy the list into a pooled array.
        var count = lastTokens.Count;
        var copy = ArrayPool.Shared.Rent(count);
        try
        {
            lastTokens.CopyTo(copy);

            // Rent may return an array LONGER than requested. Only the first `count`
            // elements were written by CopyTo; anything beyond that is stale pool data
            // and must not be exposed to the pipeline as "recent tokens".
            return pipeline.Sample(ctx, logits, copy.AsSpan(0, count));
        }
        finally
        {
            // Always hand the buffer back to the pool, even if sampling throws.
            ArrayPool.Shared.Return(copy);
        }
#endif
    }
}
/// <summary>
/// Straightforward <see cref="ISamplingPipeline"/> implementation: on every call it runs the
/// configured processors in list order and then asks the selector for a token
/// </summary>
public sealed class ConfigurableSamplingPipeline
    : ISamplingPipeline
{
    /// <summary>
    /// Processors applied to the raw logits, in list order
    /// </summary>
    public IList LogitProcessors { get; } = new List();

    /// <summary>
    /// Processors applied to the token candidates, in list order
    /// </summary>
    public IList TokenDataProcessors { get; } = new List();

    /// <summary>
    /// Selector which picks the final token from the processed candidates
    /// </summary>
    public ITokenSelector Selector { get; set; } = new StandardSelection();

    /// <inheritdoc />
    public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens)
    {
        // Stage 1: let every logit processor adjust the raw logits.
        foreach (var processor in LogitProcessors)
        {
            processor.ProcessLogits(ctx, logits, lastTokens);
        }

        // Stage 2: build the candidate array from the (now modified) logits.
        var candidates = LLamaTokenDataArray.Create(logits);

        // Stage 3: run every token data processor over the candidates.
        foreach (var processor in TokenDataProcessors)
        {
            processor.ProcessTokens(ctx, candidates, lastTokens);
        }

        // Stage 4: pick the final token.
        var selected = Selector.Select(ctx, candidates, lastTokens);

        // Stage 5: notify the processors of the choice (logit processors first,
        // then token data processors).
        foreach (var processor in LogitProcessors)
        {
            processor.AcceptToken(ctx, selected);
        }

        foreach (var processor in TokenDataProcessors)
        {
            processor.AcceptToken(ctx, selected);
        }

        return selected;
    }

    /// <inheritdoc />
    public void Reset()
    {
        // Reset order: logit processors, token data processors, then the selector.
        foreach (var processor in LogitProcessors)
        {
            processor.Reset();
        }

        foreach (var processor in TokenDataProcessors)
        {
            processor.Reset();
        }

        Selector.Reset();
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // Dispose order: logit processors, token data processors, then the selector.
        foreach (var processor in LogitProcessors)
        {
            processor.Dispose();
        }

        foreach (var processor in TokenDataProcessors)
        {
            processor.Dispose();
        }

        Selector.Dispose();
    }
}