diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs
index 9ad77531..389563aa 100644
--- a/LLama.Unittest/GrammarParserTest.cs
+++ b/LLama.Unittest/GrammarParserTest.cs
@@ -1,5 +1,4 @@
-using System.Text;
-using LLama.Exceptions;
+using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 195cc4a2..72e9acf8 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -1,5 +1,6 @@
using System.Diagnostics;
using LLama.Common;
+using LLama.Sampling;
using Xunit.Abstractions;
namespace LLama.Unittest
@@ -30,10 +31,13 @@ namespace LLama.Unittest
[Fact]
public async Task Stateless()
{
+ // Create a custom pipeline that mimics the default pipeline
+ var pipeline = new DefaultSamplingPipeline();
+
var executor = new StatelessExecutor(_weights, _params);
const string question = "Question. what is a cat?\nAnswer: ";
- var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
+ var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
var timer = new Stopwatch();
timer.Start();
diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index 89d94ade..c604dc0d 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
using LLama.Abstractions;
using LLama.Native;
+using LLama.Sampling;
namespace LLama.Web.Common
{
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
/// <summary>
/// A grammar to constrain possible tokens
/// </summary>
- public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+ public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+ /// <inheritdoc />
+ public ISamplingPipeline? SamplingPipeline { get; set; }
}
}
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index d87faf0e..e1e89414 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;
+using LLama.Sampling;
namespace LLama.Abstractions
{
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
/// Grammar to constrain possible tokens
/// </summary>
SafeLLamaGrammarHandle? Grammar { get; set; }
+
+ /// <summary>
+ /// Set a custom sampling pipeline to use. If this is set, all other sampling parameters are ignored!
+ /// </summary>
+ ISamplingPipeline? SamplingPipeline { get; set; }
}
}
\ No newline at end of file
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index d7bd19d9..c1f39550 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using LLama.Native;
+using LLama.Sampling;
namespace LLama.Common
{
@@ -76,6 +77,9 @@ namespace LLama.Common
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+ /// <inheritdoc />
+ public ISamplingPipeline? SamplingPipeline { get; set; }
}
/// <summary>
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 3a3e51af..2902dc8f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -10,6 +10,7 @@ using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using LLama.Abstractions;
+using LLama.Sampling;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
}
}
+ /// <summary>
+ /// Sample a single token from this context, using the given sampling pipeline
+ /// </summary>
+ /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+ /// <param name="lastTokens">The tokens recently returned from the model</param>
+ /// <returns>The selected token</returns>
+ public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+ {
+ return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+ }
+
/// <summary>
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
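Usage sketch (not part of the diff): the new overload above is the low-level entry point for custom pipelines. The helper name and the way the caller keeps its token history are illustrative assumptions; only LLamaContext.Sample and the pipeline types come from this change set.

    using LLama;
    using LLama.Sampling;

    internal static class ContextSamplingSketch
    {
        // Pick the next token for an already-evaluated context using a custom pipeline
        // instead of the built-in parameter-heavy Sample overload.
        public static int SampleNext(LLamaContext context, ISamplingPipeline pipeline, int[] lastTokens)
        {
            return context.Sample(pipeline, lastTokens);
        }
    }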
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index d81630aa..3ed66890 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -210,16 +210,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}
- var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
- inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+ llama_token id;
+ if (inferenceParams.SamplingPipeline is not null)
+ {
+ id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+ }
+ else
+ {
+ var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+ inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var mu = MirostatMu;
- var id = Context.Sample(
- tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
- inferenceParams.MinP
- );
- MirostatMu = mu;
+ var mu = MirostatMu;
+ id = Context.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
+ );
+ MirostatMu = mu;
+ }
_last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 4d28274b..9cecf437 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -189,16 +189,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}
- var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
- inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
- var mu = MirostatMu;
- var id = Context.Sample(
- tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
- inferenceParams.MinP
- );
- MirostatMu = mu;
+ llama_token id;
+ if (inferenceParams.SamplingPipeline is not null)
+ {
+ id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+ }
+ else
+ {
+ var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+ inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+ var mu = MirostatMu;
+ id = Context.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
+ );
+ MirostatMu = mu;
+ }
_last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 9c41af7c..831aceb2 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
+using LLama.Sampling;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
- // Penalize the generated tokens by various penalties
- var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
- inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
- // Sample a single token
- var id = Context.Sample(
- tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
- inferenceParams.MinP
- );
+ llama_token id;
+ if (inferenceParams.SamplingPipeline is not null)
+ {
+ id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+ }
+ else
+ {
+ // Penalize the generated tokens by various penalties
+ var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+ inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+ // Sample a single token
+ id = Context.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
+ );
+ }
// Decode this token into text
decoder.Add(id);
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 4bc154f4..5059a5f3 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -46,14 +46,41 @@ namespace LLama.Native
return new LLamaTokenDataArray(candidates);
}
+ /// <summary>
+ /// Overwrite the logit values for all given tokens
+ /// </summary>
+ /// <param name="values">tuples of token and logit value to overwrite</param>
+ public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+ {
+ if (values.Length == 0)
+ return;
+
+ var dataSpan = data.Span;
+ foreach (var (token, value) in values)
+ {
+ for (var i = 0; i < data.Length; i++)
+ {
+ if (dataSpan[i].id == token)
+ {
+ dataSpan[i].logit = value;
+ break;
+ }
+ }
+ }
+ sorted = false;
+ }
+
#region sampling
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx"></param>
/// <param name="grammar"></param>
- public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+ public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
{
+ if (grammar == null)
+ return;
+
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@ namespace LLama.Native
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
- public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+ public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
- using (var last_tokens_handle = last_tokens.Pin())
{
- NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
- sorted = st.sorted;
+ fixed (int* last_tokens_handle = last_tokens)
+ {
+ NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+ sorted = st.sorted;
+ }
}
}
}
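Illustrative sketch (not part of the diff): OverwriteLogits is what the new pipeline base class uses to push saved "protected" logit values back into the candidate array after penalties have run. The helper below only exists to show the call shape; the logits array is assumed to come from the caller.

    using LLama.Native;

    internal static class OverwriteLogitsSketch
    {
        // Build a candidate array from raw logits, then force one token's logit back
        // to a previously saved value (e.g. to undo a repetition penalty applied to it).
        public static LLamaTokenDataArray WithRestoredLogit(float[] logits, int tokenId, float savedLogit)
        {
            var candidates = LLamaTokenDataArray.Create(logits);
            candidates.OverwriteLogits(new[] { (tokenId, savedLogit) });
            return candidates;
        }
    }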
diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs
new file mode 100644
index 00000000..4c0f7689
--- /dev/null
+++ b/LLama/Sampling/BaseSamplingPipeline.cs
@@ -0,0 +1,128 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
+/// </summary>
+public abstract class BaseSamplingPipeline
+ : ISamplingPipeline
+{
+ private int _savedLogitsCount;
+ private (int index, float logit)[]? _savedLogits;
+
+ /// <inheritdoc />
+ public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+ {
+ var protectedLogits = GetProtectedTokens(ctx);
+ _savedLogitsCount = protectedLogits.Count;
+ _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
+ try
+ {
+ // Save the values of protected logits
+ for (var i = 0; i < protectedLogits.Count; i++)
+ {
+ var index = protectedLogits[i];
+ var value = logits[index];
+ _savedLogits[i] = (index, value);
+ }
+
+ // Process raw logits
+ ProcessLogits(ctx, logits, lastTokens);
+
+ // Automatically restore saved logit values after processing
+ RestoreProtectedTokens(logits);
+
+ // Convert logits into token candidates
+ var candidates = LLamaTokenDataArray.Create(logits);
+
+ // Process token data array
+ ProcessTokenDataArray(ctx, candidates, lastTokens);
+
+ // Choose the final value
+ return ChooseToken(ctx, candidates);
+ }
+ finally
+ {
+ ArrayPool<(int, float)>.Shared.Return(_savedLogits);
+ _savedLogits = null;
+ _savedLogitsCount = 0;
+ }
+ }
+
+ #region protected tokens
+ /// <summary>
+ /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
+ /// </summary>
+ /// <returns></returns>
+ protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);
+
+ /// <summary>
+ /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+ /// </summary>
+ /// <param name="logits"></param>
+ protected void RestoreProtectedTokens(Span<float> logits)
+ {
+ if (_savedLogits == null)
+ return;
+
+ // The array may be bigger than necessary, get a span of the valid portion
+ var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
+
+ // Restore the values of protected logits
+ for (var i = 0; i < saved.Length; i++)
+ logits[saved[i].index] = saved[i].logit;
+ }
+
+ /// <summary>
+ /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+ /// </summary>
+ /// <param name="candidates"></param>
+ protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
+ {
+ if (_savedLogits == null || _savedLogits.Length == 0)
+ return;
+
+ candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
+ }
+ #endregion
+
+ /// <summary>
+ /// Process the raw logit values
+ /// </summary>
+ /// <param name="ctx">The context being sampled from</param>
+ /// <param name="logits">The logits produced by the model</param>
+ /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+ protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+ /// <summary>
+ /// Process the LLamaTokenDataArray and select a single token
+ /// </summary>
+ /// <param name="ctx">The context being sampled from</param>
+ /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
+ /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+ /// <returns></returns>
+ protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
+
+ /// <summary>
+ /// Choose the final token from the candidates
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="candidates"></param>
+ /// <returns></returns>
+ protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
+
+ /// <inheritdoc />
+ public virtual void Reset()
+ {
+ }
+
+ /// <inheritdoc />
+ public virtual void Dispose()
+ {
+ GC.SuppressFinalize(this);
+ }
+}
\ No newline at end of file
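A minimal sketch of a custom pipeline built on this base class (illustrative only, not part of the diff). It protects no tokens, does no raw-logit processing, and only applies top-k and temperature before sampling; every LLamaTokenDataArray method it calls appears elsewhere in this change set.

    using System;
    using System.Collections.Generic;
    using LLama.Native;
    using LLama.Sampling;

    public sealed class TopKTemperaturePipeline
        : BaseSamplingPipeline
    {
        public int TopK { get; set; } = 40;
        public float Temperature { get; set; } = 0.8f;

        protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
            => Array.Empty<int>();

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
        {
            // No logit bias or other raw-logit adjustments in this sketch
        }

        protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
        {
            // Shape the candidate distribution; the actual draw happens in ChooseToken
            candidates.TopK(ctx, TopK);
            candidates.Temperature(ctx, Temperature);
            return 0;
        }

        protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
            => candidates.SampleToken(ctx);
    }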
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
new file mode 100644
index 00000000..e6db2efe
--- /dev/null
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -0,0 +1,149 @@
+using System;
+using System.Collections.Generic;
+using LLama.Extensions;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
+/// </summary>
+public sealed class DefaultSamplingPipeline
+ : BaseSamplingPipeline
+{
+ /// <summary>
+ /// Bias values to add to certain logits
+ /// </summary>
+ public Dictionary<int, float> LogitBias { get; } = new();
+
+ /// <summary>
+ /// Grammar to constrain valid tokens
+ /// </summary>
+ public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+ /// <summary>
+ /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
+ /// </summary>
+ public float RepeatPenalty { get; set; } = 1.1f;
+
+ /// <summary>
+ /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+ /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+ /// </summary>
+ public float AlphaFrequency
+ {
+ get => _alphaFreq;
+ set
+ {
+ if (value < -2)
+ throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
+ if (value > 2)
+ throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
+ _alphaFreq = value;
+ }
+ }
+ private float _alphaFreq = 0.1f;
+
+ /// <summary>
+ /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+ /// text so far, increasing the model's likelihood to talk about new topics.
+ /// </summary>
+ public float AlphaPresence
+ {
+ get => _alphaPresence;
+ set
+ {
+ if (value < -2)
+ throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
+ if (value > 2)
+ throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
+ _alphaPresence = value;
+ }
+ }
+ private float _alphaPresence = 0.1f;
+
+ /// <summary>
+ /// Temperature to apply (higher temperature is more "creative")
+ /// </summary>
+ public float Temperature { get; set; } = 0.75f;
+
+ /// <summary>
+ /// Number of tokens to keep in TopK sampling
+ /// </summary>
+ public int TopK { get; set; }
+
+ /// <summary>
+ /// Z value for tail free sampling
+ /// </summary>
+ public float TailFreeZ { get; set; }
+
+ /// <summary>
+ /// P value for locally typical sampling
+ /// </summary>
+ public float TypicalP { get; set; }
+
+ /// <summary>
+ /// P value for TopP sampling
+ /// </summary>
+ public float TopP { get; set; } = 1f;
+
+ /// <summary>
+ /// P value for MinP sampling
+ /// </summary>
+ public float MinP { get; set; }
+
+ /// <summary>
+ /// Whether the newline token may be penalized; when false it is protected from logit bias and the repetition penalties
+ /// </summary>
+ public bool PenalizeNewline { get; set; } = false;
+
+ private readonly int[] _newlineToken = new int[1];
+
+ /// <inheritdoc />
+ protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
+ {
+ if (PenalizeNewline)
+ return Array.Empty<int>();
+
+ _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
+ return _newlineToken;
+ }
+
+ /// <inheritdoc />
+ protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+ {
+ foreach (var (key, value) in LogitBias)
+ logits[key] += value;
+ }
+
+ /// <inheritdoc />
+ protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
+ {
+ // Apply penalties to candidates
+ candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+
+ // Restore protected tokens, so they are not affected by repetition penalties
+ RestoreProtectedTokens(candidates);
+
+ // Apply the normal llama.cpp pipeline
+ candidates.ApplyGrammar(ctx, Grammar);
+ candidates.TopK(ctx, TopK);
+ candidates.TailFree(ctx, TailFreeZ);
+ candidates.LocallyTypical(ctx, TypicalP);
+ candidates.TopP(ctx, TopP);
+ candidates.MinP(ctx, MinP);
+ candidates.Temperature(ctx, Temperature);
+ var id = candidates.SampleToken(ctx);
+
+ Grammar?.AcceptToken(ctx, id);
+ return id;
+ }
+
+ /// <inheritdoc />
+ protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+ {
+ return candidates.SampleToken(ctx);
+ }
+}
\ No newline at end of file
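Usage sketch (not part of the diff), mirroring the unit test change above: configure a DefaultSamplingPipeline and hand it to an executor through InferenceParams.SamplingPipeline. The model path and prompt are placeholders.

    using System;
    using LLama;
    using LLama.Common;
    using LLama.Sampling;

    var parameters = new ModelParams("path/to/model.gguf");
    using var weights = LLamaWeights.LoadFromFile(parameters);
    var executor = new StatelessExecutor(weights, parameters);

    var inferenceParams = new InferenceParams
    {
        MaxTokens = 64,
        // When SamplingPipeline is set, the other sampling parameters on InferenceParams are ignored
        SamplingPipeline = new DefaultSamplingPipeline
        {
            Temperature = 0.7f,
            TopK = 50,
            RepeatPenalty = 1.15f,
        },
    };

    await foreach (var text in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
        Console.Write(text);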
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
new file mode 100644
index 00000000..f39bf996
--- /dev/null
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -0,0 +1,61 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
+/// </summary>
+public interface ISamplingPipeline
+ : IDisposable
+{
+ /// <summary>
+ /// Sample a single token from the given logits
+ /// </summary>
+ /// <param name="ctx">The context being sampled from</param>
+ /// <param name="logits">The logits produced by the model</param>
+ /// <param name="lastTokens">A span of tokens recently returned by the model</param>
+ /// <returns></returns>
+ int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+ /// <summary>
+ /// Reset all internal state of the sampling pipeline
+ /// </summary>
+ void Reset();
+}
+
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+ /// <summary>
+ /// Sample a single token from the given logits
+ /// </summary>
+ /// <param name="pipeline"></param>
+ /// <param name="ctx">The context being sampled from</param>
+ /// <param name="logits">The logits produced by the model</param>
+ /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+ /// <returns></returns>
+ public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+ {
+#if NET5_0_OR_GREATER
+ var span = CollectionsMarshal.AsSpan(lastTokens);
+ return pipeline.Sample(ctx, logits, span);
+#else
+ var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+ try
+ {
+ lastTokens.CopyTo(copy);
+ return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+ }
+ finally
+ {
+ ArrayPool<int>.Shared.Return(copy);
+ }
+#endif
+ }
+}
\ No newline at end of file
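Usage sketch for the extension overload (illustrative, not part of the diff): sample directly from the raw logits held by a context, passing the history as a List<int>. The helper name is an assumption; NativeHandle and GetLogits() are used the same way the executors above use them.

    using System.Collections.Generic;
    using LLama;
    using LLama.Sampling;

    internal static class PipelineExtensionSketch
    {
        // Run the pipeline over the context's current logits with a List<int> history.
        public static int SampleFromContext(LLamaContext context, ISamplingPipeline pipeline, List<int> history)
        {
            return pipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), history);
        }
    }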