using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using LLama.Native;
using LLama.Sampling.Logits;
using LLama.Sampling.Selection;
using LLama.Sampling.Tokens;

namespace LLama.Sampling;

// NOTE(review): the generic type arguments in this file (Span<float>, ReadOnlySpan<int>, List<int>,
// IList<ILogitProcessor>, IList<ITokenDataProcessor>) were reconstructed from usage - confirm that
// they match the declarations in LLama.Sampling.Logits / LLama.Sampling.Tokens.

/// <summary>
/// Convert a span of logits into a single sampled token
/// </summary>
public interface ISamplingPipeline
    : IDisposable
{
    /// <summary>
    /// Sample a single token from the given logits
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The logits produced by the model</param>
    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
    /// <returns>The token that was selected</returns>
    int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Reset all internal state of the sampling pipeline
    /// </summary>
    void Reset();
}

/// <summary>
/// Extensions methods for ISamplingPipeline
/// </summary>
public static class ISamplingPipelineExtensions
{
    /// <summary>
    /// Sample a single token from the given logits
    /// </summary>
    /// <param name="pipeline">The pipeline to sample with</param>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The logits produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    /// <returns>The token returned by <paramref name="pipeline"/></returns>
    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
    {
#if NET5_0_OR_GREATER
        // View the list's backing storage directly, no copy required.
        var span = CollectionsMarshal.AsSpan(lastTokens);
        return pipeline.Sample(ctx, logits, span);
#else
        // No span view over a List is available here, so copy the tokens into a pooled buffer.
        var count = lastTokens.Count;
        var copy = ArrayPool<int>.Shared.Rent(count);
        try
        {
            lastTokens.CopyTo(copy);

            // Rent may hand back an array longer than requested. Only the first `count`
            // entries hold real tokens, so the span must be cut to that length - using
            // copy.Length would expose stale pooled data to the pipeline.
            return pipeline.Sample(ctx, logits, copy.AsSpan(0, count));
        }
        finally
        {
            ArrayPool<int>.Shared.Return(copy);
        }
#endif
    }
}

/// <summary>
/// Simple implementation of `ISamplingPipeline`, applies processors in order every time
/// </summary>
public sealed class ConfigurableSamplingPipeline
    : ISamplingPipeline
{
    /// <summary>
    /// Logit processors to apply in this pipeline
    /// </summary>
    public IList<ILogitProcessor> LogitProcessors { get; } = new List<ILogitProcessor>();

    /// <summary>
    /// Token data processors to apply in this pipeline
    /// </summary>
    public IList<ITokenDataProcessor> TokenDataProcessors { get; } = new List<ITokenDataProcessor>();

    /// <summary>
    /// The selector to choose the final token
    /// </summary>
    public ITokenSelector Selector { get; set; } = new StandardSelection();

    /// <inheritdoc />
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Modify raw logits
        foreach (var logitProcessor in LogitProcessors)
        {
            logitProcessor.ProcessLogits(ctx, logits, lastTokens);
        }

        // Convert logits into token candidates
        var candidates = LLamaTokenDataArray.Create(logits);

        // Process token candidates
        foreach (var tokenDataProcessor in TokenDataProcessors)
        {
            tokenDataProcessor.ProcessTokens(ctx, candidates, lastTokens);
        }

        // Select a token
        var token = Selector.Select(ctx, candidates, lastTokens);

        // Tell processors what was selected (logit processors first, then token data processors)
        foreach (var logitProcessor in LogitProcessors)
        {
            logitProcessor.AcceptToken(ctx, token);
        }

        foreach (var tokenDataProcessor in TokenDataProcessors)
        {
            tokenDataProcessor.AcceptToken(ctx, token);
        }

        return token;
    }

    /// <inheritdoc />
    public void Reset()
    {
        foreach (var logitProcessor in LogitProcessors)
        {
            logitProcessor.Reset();
        }

        foreach (var tokenDataProcessor in TokenDataProcessors)
        {
            tokenDataProcessor.Reset();
        }

        Selector.Reset();
    }

    /// <inheritdoc />
    public void Dispose()
    {
        foreach (var logitProcessor in LogitProcessors)
        {
            logitProcessor.Dispose();
        }

        foreach (var tokenDataProcessor in TokenDataProcessors)
        {
            tokenDataProcessor.Dispose();
        }

        Selector.Dispose();
    }
}