diff --git a/LLama.Examples/NewVersion/BatchedDecoding.cs b/LLama.Examples/NewVersion/BatchedDecoding.cs
index 66929310..702e8799 100644
--- a/LLama.Examples/NewVersion/BatchedDecoding.cs
+++ b/LLama.Examples/NewVersion/BatchedDecoding.cs
@@ -6,6 +6,10 @@ using LLama.Native;
namespace LLama.Examples.NewVersion;
+/// <summary>
+/// This demonstrates generating multiple replies to the same prompt, with a shared cache
+/// </summary>
+/// <remarks>Note that this currently uses the low level API directly; future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
private const int n_parallel = 8;
@@ -116,12 +120,11 @@ public class BatchedDecoding
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}
- using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native);
- candidates_native.TopK(context.NativeHandle, top_k);
- candidates_native.TopP(context.NativeHandle, top_p);
- candidates_native.Temperature(context.NativeHandle, temp);
- var new_token_id = candidates_native.SampleToken(context.NativeHandle);
+ candidates.TopK(context.NativeHandle, top_k);
+ candidates.TopP(context.NativeHandle, top_p);
+ candidates.Temperature(context.NativeHandle, temp);
+ var new_token_id = candidates.SampleToken(context.NativeHandle);
if (new_token_id == eos || new_token_id == nl)
{
diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs
index 231a67ca..2f698f80 100644
--- a/LLama.Examples/NewVersion/TestRunner.cs
+++ b/LLama.Examples/NewVersion/TestRunner.cs
@@ -22,7 +22,7 @@
Console.WriteLine("12: Semantic Kernel Chat.");
Console.WriteLine("13: Semantic Kernel Memory.");
Console.WriteLine("14: Coding Assistant.");
- Console.WriteLine("15: Batch Decoding Benchmark.");
+ Console.WriteLine("15: Batch Decoding.");
while (true)
{
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 2962cb69..1d0704bf 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -235,13 +235,13 @@ namespace LLama
if (grammar != null)
{
- SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
+ candidates.ApplyGrammar(NativeHandle, grammar);
}
if (temperature <= 0)
{
// Greedy sampling
- id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
+ id = candidates.SampleTokenGreedy(NativeHandle);
}
else
{
@@ -250,32 +250,28 @@ namespace LLama
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
- SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+ candidates.Temperature(NativeHandle, temperature);
+ id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
- SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
+ candidates.Temperature(NativeHandle, temperature);
+ id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu);
}
else
{
- // Temperature sampling
- SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
- SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
- SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
- SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
- SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
- id = SamplingApi.llama_sample_token(NativeHandle, candidates);
+ candidates.TopK(NativeHandle, topK);
+ candidates.TailFree(NativeHandle, tfsZ);
+ candidates.LocallyTypical(NativeHandle, typicalP);
+ candidates.TopP(NativeHandle, topP);
+ candidates.Temperature(NativeHandle, temperature);
+ id = candidates.SampleToken(NativeHandle);
}
}
mirostat_mu = mu;
}
- if (grammar != null)
- {
- NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
- }
+ grammar?.AcceptToken(NativeHandle, id);
return id;
}
@@ -305,7 +301,7 @@ namespace LLama
}
// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
+ var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
@@ -316,8 +312,7 @@ namespace LLama
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
// Apply penalties to candidates
- SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
- SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
+ candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);
// Restore newline token logit value if necessary
if (!penalizeNL)
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 22235a79..8f20a73a 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -45,6 +45,199 @@ namespace LLama.Native
return new LLamaTokenDataArray(candidates);
}
+
+ #region sampling
+ /// <summary>
+ /// Apply grammar rules to candidate tokens
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="grammar"></param>
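+ /// <example>
+ /// A minimal usage sketch (variable names are illustrative), mirroring how LLamaContext applies the grammar before sampling and then accepts the chosen token:
+ /// <code>
+ /// candidates.ApplyGrammar(ctx, grammar);
+ /// var token = candidates.SampleTokenGreedy(ctx);
+ /// grammar.AcceptToken(ctx, token);
+ /// </code>
+ /// </example>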
+ public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_grammar(ctx, ref st, grammar);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="k">Number of tokens to keep</param>
+ /// <param name="minKeep">Minimum number to keep</param>
+ public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="p"></param>
+ /// <param name="minKeep"></param>
+ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="z"></param>
+ /// <param name="min_keep"></param>
+ public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="p"></param>
+ /// <param name="min_keep"></param>
+ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_typical(context, ref st, p, min_keep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
+ /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="last_tokens"></param>
+ /// <param name="penalty_repeat"></param>
+ /// <param name="penalty_freq"></param>
+ /// <param name="penalty_present"></param>
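+ /// <example>
+ /// A rough sketch of applying the combined penalties before sampling, as LLamaContext does; the token history source, history length and penalty values here are purely illustrative:
+ /// <code>
+ /// var lastTokens = recentTokens.TakeLast(64).ToArray();
+ /// candidates.RepetitionPenalty(ctx, lastTokens, 1.1f, 0.0f, 0.0f);
+ /// </code>
+ /// </example>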
+ public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+ {
+ unsafe
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ using (var last_tokens_handle = last_tokens.Pin())
+ {
+ NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+ sorted = st.sorted;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Sample with temperature.
+ /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="temp"></param>
+ public void Temperature(SafeLLamaContextHandle context, float temp)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_temperature(context, ref st, temp);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
+ /// </summary>
+ /// <param name="context"></param>
+ public void Softmax(SafeLLamaContextHandle context)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_softmax(context, ref st);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Randomly selects a token from the candidates based on their probabilities.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <returns></returns>
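+ /// <example>
+ /// A minimal sketch of the filter-then-sample pipeline used by the BatchedDecoding example (the parameter values are illustrative):
+ /// <code>
+ /// candidates.TopK(ctx, 40);
+ /// candidates.TopP(ctx, 0.9f);
+ /// candidates.Temperature(ctx, 0.8f);
+ /// var token = candidates.SampleToken(ctx);
+ /// </code>
+ /// </example>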
+ public int SampleToken(SafeLLamaContextHandle context)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token(context, ref st);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+
+ /// <summary>
+ /// Selects the token with the highest probability.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <returns></returns>
+ public int SampleTokenGreedy(SafeLLamaContextHandle context)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token_greedy(context, ref st);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+
+ /// <summary>
+ /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
+ /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
+ /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
+ /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
+ /// <returns></returns>
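+ /// <example>
+ /// A rough sketch; per the `mu` description above it starts at twice the target cross-entropy and is carried across calls (the tau/eta values are illustrative):
+ /// <code>
+ /// var mu = 2.0f * 5.0f; // 2 * tau
+ /// candidates.Temperature(ctx, 0.8f);
+ /// var token = candidates.SampleTokenMirostat(ctx, 5.0f, 0.1f, 100, ref mu);
+ /// </code>
+ /// </example>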
+ public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+
+ /// <summary>
+ /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
+ /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
+ /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
+ /// <returns></returns>
+ public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+ #endregion
}
/// <summary>
@@ -96,48 +289,5 @@ namespace LLama.Native
return handle;
}
-
- /// <summary>
- /// Perform TopK sampling, sorting the data and reducing the size to k
- /// </summary>
- /// <param name="context"></param>
- /// <param name="k">Number of tokens to keep</param>
- /// <param name="minKeep">Minimum number to keep</param>
- public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
- {
- NativeApi.llama_sample_top_k(context, ref this, k, minKeep);
- }
-
- /// <summary>
- /// Perform top p sampling, sorting the data and keeping only logits more likely than p
- /// </summary>
- /// <param name="context"></param>
- /// <param name="p"></param>
- /// <param name="minKeep"></param>
- public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
- {
- NativeApi.llama_sample_top_p(context, ref this, p, minKeep);
- }
-
- /// <summary>
- /// Apply temperature to logits
- /// </summary>
- /// <param name="context"></param>
- /// <param name="temp"></param>
- public void Temperature(SafeLLamaContextHandle context, float temp)
- {
- NativeApi.llama_sample_temperature(context, ref this, temp);
- }
-
- /// <summary>
- /// Sample a token from the set of possible tokens
- /// </summary>
- /// <param name="context"></param>
- /// <returns></returns>
- public int SampleToken(SafeLLamaContextHandle context)
- {
- return NativeApi.llama_sample_token(context, ref this);
- }
}
}
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index 80e682cf..e7ee32ba 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -9,26 +9,22 @@ namespace LLama.Native
{
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
- /// </summary>
- /// <param name="ctx"></param>
- /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
- /// <param name="last_tokens"></param>
- /// <param name="last_tokens_size"></param>
- /// <param name="penalty"></param>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);
-
- /// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
- /// <param name="alpha_frequency"></param>
- /// <param name="alpha_presence"></param>
+ /// <param name="penalty_repeat">Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.</param>
+ /// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
+ /// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
+ public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
+ ref LLamaTokenDataArrayNative candidates,
+ llama_token* last_tokens, ulong last_tokens_size,
+ float penalty_repeat,
+ float penalty_freq,
+ float penalty_present);
/// <summary>
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs
index ed1c15c8..f430b7c3 100644
--- a/LLama/Native/SafeLLamaGrammarHandle.cs
+++ b/LLama/Native/SafeLLamaGrammarHandle.cs
@@ -102,5 +102,15 @@ namespace LLama.Native
return new(grammar_ptr);
}
#endregion
+
+ /// <summary>
+ /// Accepts the sampled token into the grammar
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="token"></param>
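+ /// <example>
+ /// A minimal sketch of the sample-then-accept step this enables (it mirrors the LLamaContext change in this diff):
+ /// <code>
+ /// var id = candidates.SampleToken(ctx);
+ /// grammar?.AcceptToken(ctx, id);
+ /// </code>
+ /// </example>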
+ public void AcceptToken(SafeLLamaContextHandle ctx, int token)
+ {
+ NativeApi.llama_grammar_accept_token(ctx, this, token);
+ }
}
}
diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs
index e26bf971..41709def 100644
--- a/LLama/Native/SamplingApi.cs
+++ b/LLama/Native/SamplingApi.cs
@@ -9,7 +9,7 @@ namespace LLama.Native
/// <summary>
/// Direct translation of the llama.cpp sampling API
/// </summary>
- public unsafe class SamplingApi
+ public class SamplingApi
{
/// <summary>
/// Apply grammar rules to candidate tokens
@@ -17,70 +17,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="grammar"></param>
+ [Obsolete("use LLamaTokenDataArray ApplyGrammar method")]
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_grammar(ctx, ref st, grammar);
- }
-
- /// <summary>
- /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
- /// </summary>
- /// <param name="ctx"></param>
- /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
- /// <param name="last_tokens"></param>
- /// <param name="last_tokens_size"></param>
- /// <param name="penalty"></param>
- [Obsolete("last_tokens_size parameter is no longer needed")]
- public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
- {
- llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
- }
-
- /// <summary>
- /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
- /// </summary>
- /// <param name="ctx"></param>
- /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
- /// <param name="last_tokens"></param>
- /// <param name="penalty"></param>
- public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
- {
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- using var last_tokens_handle = last_tokens.Pin();
-
- NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
- }
-
- /// <summary>
- /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
- /// </summary>
- /// <param name="ctx"></param>
- /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
- /// <param name="last_tokens"></param>
- /// <param name="last_tokens_size"></param>
- /// <param name="alpha_frequency"></param>
- /// <param name="alpha_presence"></param>
- [Obsolete("last_tokens_size parameter is no longer needed")]
- public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
- {
- llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
- }
-
- /// <summary>
- /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
- /// </summary>
- /// <param name="ctx"></param>
- /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
- /// <param name="last_tokens"></param>
- /// <param name="alpha_frequency"></param>
- /// <param name="alpha_presence"></param>
- public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
- {
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- using var last_tokens_handle = last_tokens.Pin();
-
- NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
+ candidates.ApplyGrammar(ctx, grammar);
}
/// <summary>
@@ -88,10 +28,10 @@ namespace LLama.Native
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
+ [Obsolete("use LLamaTokenDataArray Softmax method")]
public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_softmax(ctx, ref st);
+ candidates.Softmax(ctx);
}
/// <summary>
@@ -101,10 +41,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="k"></param>
/// <param name="min_keep"></param>
+ [Obsolete("use LLamaTokenDataArray TopK method")]
public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
+ candidates.TopK(ctx, k, min_keep);
}
/// <summary>
@@ -114,10 +54,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
+ [Obsolete("use LLamaTokenDataArray TopP method")]
public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
+ candidates.TopP(ctx, p, min_keep);
}
/// <summary>
@@ -127,10 +67,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
+ [Obsolete("use LLamaTokenDataArray TailFree method")]
public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
+ candidates.TailFree(ctx, z, min_keep);
}
/// <summary>
@@ -140,10 +80,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
+ [Obsolete("use LLamaTokenDataArray LocallyTypical method")]
public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
+ candidates.LocallyTypical(ctx, p, min_keep);
}
/// <summary>
@@ -153,10 +93,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="temp"></param>
+ [Obsolete("use LLamaTokenDataArray Temperature() method")]
public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_temperature(ctx, ref st, temp);
+ candidates.Temperature(ctx, temp);
}
/// <summary>
@@ -169,10 +109,10 @@ namespace LLama.Native
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
+ [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
+ return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);
}
/// <summary>
@@ -184,10 +124,10 @@ namespace LLama.Native
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
+ [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
+ return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
}
/// <summary>
@@ -196,10 +136,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
+ [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token_greedy(ctx, ref st);
+ return candidates.SampleTokenGreedy(ctx);
}
/// <summary>
@@ -208,10 +148,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
+ [Obsolete("use LLamaTokenDataArray SampleToken() method")]
public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token(ctx, ref st);
+ return candidates.SampleToken(ctx);
}
}
}