From e81b3023d5e184983d20bd1ae72b26b3915bd0dc Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 26 Oct 2023 00:52:31 +0100 Subject: [PATCH] Rewritten sampling API to be accessed through the `LLamaTokenDataArray` object --- LLama.Examples/NewVersion/BatchedDecoding.cs | 13 +- LLama.Examples/NewVersion/TestRunner.cs | 2 +- LLama/LLamaContext.cs | 35 ++- LLama/Native/LLamaTokenDataArray.cs | 236 +++++++++++++++---- LLama/Native/NativeApi.Sampling.cs | 22 +- LLama/Native/SafeLLamaGrammarHandle.cs | 10 + LLama/Native/SamplingApi.cs | 106 ++------- 7 files changed, 259 insertions(+), 165 deletions(-) diff --git a/LLama.Examples/NewVersion/BatchedDecoding.cs b/LLama.Examples/NewVersion/BatchedDecoding.cs index 66929310..702e8799 100644 --- a/LLama.Examples/NewVersion/BatchedDecoding.cs +++ b/LLama.Examples/NewVersion/BatchedDecoding.cs @@ -6,6 +6,10 @@ using LLama.Native; namespace LLama.Examples.NewVersion; +/// +/// This demonstrates generating multiple replies to the same prompt, with a shared cache +/// +/// Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this! public class BatchedDecoding { private const int n_parallel = 8; @@ -116,12 +120,11 @@ public class BatchedDecoding { candidates = LLamaTokenDataArray.Create(new Span(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab)); } - using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native); - candidates_native.TopK(context.NativeHandle, top_k); - candidates_native.TopP(context.NativeHandle, top_p); - candidates_native.Temperature(context.NativeHandle, temp); - var new_token_id = candidates_native.SampleToken(context.NativeHandle); + candidates.TopK(context.NativeHandle, top_k); + candidates.TopP(context.NativeHandle, top_p); + candidates.Temperature(context.NativeHandle, temp); + var new_token_id = candidates.SampleToken(context.NativeHandle); if (new_token_id == eos || new_token_id == nl) { diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index 231a67ca..2f698f80 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -22,7 +22,7 @@ Console.WriteLine("12: Semantic Kernel Chat."); Console.WriteLine("13: Semantic Kernel Memory."); Console.WriteLine("14: Coding Assistant."); - Console.WriteLine("15: Batch Decoding Benchmark."); + Console.WriteLine("15: Batch Decoding."); while (true) { diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 2962cb69..1d0704bf 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -235,13 +235,13 @@ namespace LLama if (grammar != null) { - SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar); + candidates.ApplyGrammar(NativeHandle, grammar); } if (temperature <= 0) { // Greedy sampling - id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates); + id = candidates.SampleTokenGreedy(NativeHandle); } else { @@ -250,32 +250,28 @@ namespace LLama if (mirostat == MirostatType.Mirostat) { const int mirostat_m = 100; - SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu); + candidates.Temperature(NativeHandle, temperature); + id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu); } else if (mirostat == MirostatType.Mirostat2) { - SamplingApi.llama_sample_temperature(NativeHandle, 
candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu); + candidates.Temperature(NativeHandle, temperature); + id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu); } else { - // Temperature sampling - SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1); - SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1); - SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1); - SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1); - SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); - id = SamplingApi.llama_sample_token(NativeHandle, candidates); + candidates.TopK(NativeHandle, topK); + candidates.TailFree(NativeHandle, tfsZ); + candidates.LocallyTypical(NativeHandle, typicalP); + candidates.TopP(NativeHandle, topP); + candidates.Temperature(NativeHandle, temperature); + id = candidates.SampleToken(NativeHandle); } } mirostat_mu = mu; } - if (grammar != null) - { - NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id); - } + grammar?.AcceptToken(NativeHandle, id); return id; } @@ -305,7 +301,7 @@ namespace LLama } // Save the newline logit value - var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); + var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); var nl_logit = logits[nl_token]; // Convert logits into token candidates @@ -316,8 +312,7 @@ namespace LLama var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); // Apply penalties to candidates - SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty); - SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence); + candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence); // Restore newline token logit value if necessary if (!penalizeNL) diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 22235a79..8f20a73a 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -45,6 +45,199 @@ namespace LLama.Native return new LLamaTokenDataArray(candidates); } + + #region sampling + /// + /// Apply grammar rules to candidate tokens + /// + /// + /// + public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_grammar(ctx, ref st, grammar); + sorted = st.sorted; + } + } + + /// + /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + /// + /// + /// Number of tokens to keep + /// Minimum number to keep + public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_top_k(context, ref st, k, minKeep); + sorted = st.sorted; + } + } + + /// + /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + /// + /// + /// + /// + public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_top_p(context, ref st, p, minKeep); + sorted = st.sorted; + } + } + + /// + /// Tail Free Sampling described in 
https://www.trentonbricken.com/Tail-Free-Sampling/. + /// + /// + /// + /// + public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_tail_free(context, ref st, z, min_keep); + sorted = st.sorted; + } + } + + /// + /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + /// + /// + /// + /// + public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_typical(context, ref st, p, min_keep); + sorted = st.sorted; + } + } + + /// + /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + /// + /// + /// + /// + /// + /// + public void RepetitionPenalty(SafeLLamaContextHandle context, Memory last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) + { + unsafe + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + using (var last_tokens_handle = last_tokens.Pin()) + { + NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); + sorted = st.sorted; + } + } + } + + /// + /// Sample with temperature. + /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual + /// + /// + /// + public void Temperature(SafeLLamaContextHandle context, float temp) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_temperature(context, ref st, temp); + sorted = st.sorted; + } + } + + /// + /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. + /// + /// + public void Softmax(SafeLLamaContextHandle context) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_softmax(context, ref st); + sorted = st.sorted; + } + } + + /// + /// Randomly selects a token from the candidates based on their probabilities. + /// + /// + /// + public int SampleToken(SafeLLamaContextHandle context) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token(context, ref st); + sorted = st.sorted; + return token; + } + } + + /// + /// Selects the token with the highest probability. + /// + /// + /// + public int SampleTokenGreedy(SafeLLamaContextHandle context) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token_greedy(context, ref st); + sorted = st.sorted; + return token; + } + } + + /// + /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. 
A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + /// + public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu); + sorted = st.sorted; + return token; + } + } + + /// + /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + /// + public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu); + sorted = st.sorted; + return token; + } + } + #endregion } /// @@ -96,48 +289,5 @@ namespace LLama.Native return handle; } - - /// - /// Perform TopK sampling, sorting the data and reducing the size to k - /// - /// - /// Number of tokens to keep - /// Minimum number to keep - public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1) - { - NativeApi.llama_sample_top_k(context, ref this, k, minKeep); - } - - /// - /// Perform top p sampling, sorting the data and keeping only logits more likely than p - /// - /// - /// - /// - public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1) - { - NativeApi.llama_sample_top_p(context, ref this, p, minKeep); - } - - /// - /// Apply temperature to logits - /// - /// - /// - public void Temperature(SafeLLamaContextHandle context, float temp) - { - NativeApi.llama_sample_temperature(context, ref this, temp); - } - - /// - /// Sample a token from the set of possible tokens - /// - /// - /// - /// - public int SampleToken(SafeLLamaContextHandle context) - { - return NativeApi.llama_sample_token(context, ref this); - } } } diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs index 80e682cf..e7ee32ba 100644 --- a/LLama/Native/NativeApi.Sampling.cs +++ b/LLama/Native/NativeApi.Sampling.cs @@ -9,26 +9,22 @@ namespace LLama.Native { /// /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. 
- /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty); - - /// /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. /// /// /// Pointer to LLamaTokenDataArray /// /// - /// - /// + /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); + public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx, + ref LLamaTokenDataArrayNative candidates, + llama_token* last_tokens, ulong last_tokens_size, + float penalty_repeat, + float penalty_freq, + float penalty_present); /// /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs index ed1c15c8..f430b7c3 100644 --- a/LLama/Native/SafeLLamaGrammarHandle.cs +++ b/LLama/Native/SafeLLamaGrammarHandle.cs @@ -102,5 +102,15 @@ namespace LLama.Native return new(grammar_ptr); } #endregion + + /// + /// Accepts the sampled token into the grammar + /// + /// + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + NativeApi.llama_grammar_accept_token(ctx, this, token); + } } } diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs index e26bf971..41709def 100644 --- a/LLama/Native/SamplingApi.cs +++ b/LLama/Native/SamplingApi.cs @@ -9,7 +9,7 @@ namespace LLama.Native /// /// Direct translation of the llama.cpp sampling API /// - public unsafe class SamplingApi + public class SamplingApi { /// /// Apply grammar rules to candidate tokens @@ -17,70 +17,10 @@ namespace LLama.Native /// /// /// + [Obsolete("use LLamaTokenDataArray ApplyGrammar method")] public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_grammar(ctx, ref st, grammar); - } - - /// - /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - [Obsolete("last_tokens_size parameter is no longer needed")] - public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float penalty) - { - llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty); - } - - /// - /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. 
- /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float penalty) - { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - using var last_tokens_handle = last_tokens.Pin(); - - NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty); - } - - /// - /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - /// - [Obsolete("last_tokens_size parameter is no longer needed")] - public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) - { - llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence); - } - - /// - /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float alpha_frequency, float alpha_presence) - { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - using var last_tokens_handle = last_tokens.Pin(); - - NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence); + candidates.ApplyGrammar(ctx, grammar); } /// @@ -88,10 +28,10 @@ namespace LLama.Native /// /// /// Pointer to LLamaTokenDataArray + [Obsolete("use LLamaTokenDataArray Softmax method")] public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_softmax(ctx, ref st); + candidates.Softmax(ctx); } /// @@ -101,10 +41,10 @@ namespace LLama.Native /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray TopK method")] public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep); + candidates.TopK(ctx, k, min_keep); } /// @@ -114,10 +54,10 @@ namespace LLama.Native /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray TopP method")] public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep); + candidates.TopP(ctx, p, min_keep); } /// @@ -127,10 +67,10 @@ namespace LLama.Native /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray TailFree method")] public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep); + 
candidates.TailFree(ctx, z, min_keep); } /// @@ -140,10 +80,10 @@ namespace LLama.Native /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray LocallyTypical method")] public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); + candidates.LocallyTypical(ctx, p, min_keep); } /// @@ -153,10 +93,10 @@ namespace LLama.Native /// /// /// + [Obsolete("use LLamaTokenDataArray Temperature() method")] public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_temperature(ctx, ref st, temp); + candidates.Temperature(ctx, temp); } /// @@ -169,10 +109,10 @@ namespace LLama.Native /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// + [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")] public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu); + return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu); } /// @@ -184,10 +124,10 @@ namespace LLama.Native /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
/// + [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")] public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu); + return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu); } /// @@ -196,10 +136,10 @@ namespace LLama.Native /// /// Pointer to LLamaTokenDataArray /// + [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")] public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token_greedy(ctx, ref st); + return candidates.SampleTokenGreedy(ctx); } /// @@ -208,10 +148,10 @@ namespace LLama.Native /// /// Pointer to LLamaTokenDataArray /// + [Obsolete("use LLamaTokenDataArray SampleToken() method")] public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token(ctx, ref st); + return candidates.SampleToken(ctx); } } }
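
For reference, a minimal sketch of how the reworked sampling API is intended to be called after this patch (illustrative only, not part of the diff; it assumes a SafeLLamaContextHandle, a logits buffer and a recent-token history are already available, and that LLamaTokenDataArray.Create accepts a ReadOnlySpan<float> of logits as used in the BatchedDecoding example above):

    using System;
    using LLama.Native;

    public static class SamplingSketch
    {
        // Runs the standard sampling chain via the new LLamaTokenDataArray methods,
        // then tells the (optional) grammar which token was accepted.
        public static int SampleNext(
            SafeLLamaContextHandle ctx,
            ReadOnlySpan<float> logits,          // logits for the position being sampled
            Memory<int> lastTokens,              // recent token history used by the penalties
            SafeLLamaGrammarHandle? grammar = null)
        {
            // Convert raw logits into a candidate set.
            var candidates = LLamaTokenDataArray.Create(logits);

            // Repetition, frequency and presence penalties are now a single call,
            // mirroring the merged llama_sample_repetition_penalties native function.
            candidates.RepetitionPenalty(ctx, lastTokens,
                penalty_repeat: 1.1f, penalty_freq: 0.0f, penalty_present: 0.0f);

            if (grammar != null)
                candidates.ApplyGrammar(ctx, grammar);

            // The usual top-k / tail-free / locally-typical / top-p / temperature pipeline.
            candidates.TopK(ctx, 40);
            candidates.TailFree(ctx, 1.0f);
            candidates.LocallyTypical(ctx, 1.0f);
            candidates.TopP(ctx, 0.95f);
            candidates.Temperature(ctx, 0.8f);

            var id = candidates.SampleToken(ctx);

            // A grammar must be informed of the token that was ultimately sampled.
            grammar?.AcceptToken(ctx, id);

            return id;
        }
    }

The static SamplingApi methods remain available but are marked [Obsolete] and simply forward to the corresponding LLamaTokenDataArray methods, so existing callers keep working while new code uses the instance methods directly.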