@@ -6,6 +6,10 @@ using LLama.Native;
namespace LLama.Examples.NewVersion;
/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this currently uses the low-level API directly; future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
private const int n_parallel = 8;
@@ -116,12 +120,11 @@ public class BatchedDecoding
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}
using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native);
candidates_native.TopK(context.NativeHandle, top_k);
candidates_native.TopP(context.NativeHandle, top_p);
candidates_native.Temperature(context.NativeHandle, temp);
var new_token_id = candidates_native.SampleToken(context.NativeHandle);
candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
candidates.Temperature(context.NativeHandle, temp);
var new_token_id = candidates.SampleToken(context.NativeHandle);
if (new_token_id == eos || new_token_id == nl)
{
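// The stop check above compares the sampled id against the model's end-of-sequence and
// newline token ids. A minimal sketch of fetching them, assuming llama_token_eos mirrors
// the llama_token_nl(model) signature used later in this diff, with `model` being the
// loaded model from the surrounding example:
var eos = NativeApi.llama_token_eos(model.NativeHandle);
var nl = NativeApi.llama_token_nl(model.NativeHandle);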
@@ -22,7 +22,7 @@
Console.WriteLine("12: Semantic Kernel Chat.");
Console.WriteLine("13: Semantic Kernel Memory.");
Console.WriteLine("14: Coding Assistant.");
Console.WriteLine("15: Batch Decoding Benchmark.");
Console.WriteLine("15: Batch Decoding.");
while (true)
{
@@ -235,13 +235,13 @@ namespace LLama
if (grammar != null)
{
SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
candidates.ApplyGrammar(NativeHandle, grammar);
}
if (temperature <= 0)
{
// Greedy sampling
id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
id = candidates.SampleTokenGreedy(NativeHandle);
}
else
{
@@ -250,32 +250,28 @@ namespace LLama
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token(NativeHandle, candidates);
candidates.TopK(NativeHandle, topK);
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
}
}
mirostat_mu = mu;
}
if (grammar != null)
{
NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
}
grammar?.AcceptToken(NativeHandle, id);
return id;
}
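// The default branch above applies the samplers strictly in sequence; each call prunes
// or reweights the candidate list before the next one runs, so the order matters. A
// standalone sketch with illustrative values (not taken from this diff); z = 1.0 and
// p = 1.0 leave tail-free and locally-typical filtering effectively disabled:
candidates.TopK(NativeHandle, 40);              // keep only the 40 most likely tokens
candidates.TailFree(NativeHandle, 1.0f);        // tail-free filtering, disabled at z = 1.0
candidates.LocallyTypical(NativeHandle, 1.0f);  // typical filtering, disabled at p = 1.0
candidates.TopP(NativeHandle, 0.95f);           // smallest set covering 95% of the probability mass
candidates.Temperature(NativeHandle, 0.8f);     // sharpen the remaining distribution
var id = candidates.SampleToken(NativeHandle);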
@@ -305,7 +301,7 @@ namespace LLama
}
// Save the newline logit value
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
@@ -316,8 +312,7 @@ namespace LLama
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
// Apply penalties to candidates
SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);
// Restore newline token logit value if necessary
if (!penalizeNL)
@@ -45,6 +45,199 @@ namespace LLama.Native
return new LLamaTokenDataArray(candidates);
}
#region sampling
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx"></param>
/// <param name="grammar"></param>
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
sorted = st.sorted;
}
}
/// <summary>
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
/// <param name="context"></param>
/// <param name="k">Number of tokens to keep</param>
/// <param name="minKeep">Minimum number to keep</param>
public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
sorted = st.sorted;
}
}
/// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="minKeep"></param>
public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
sorted = st.sorted;
}
}
/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
/// <param name="context"></param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
sorted = st.sorted;
}
}
/// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_typical(context, ref st, p, min_keep);
sorted = st.sorted;
}
}
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="context"></param>
/// <param name="last_tokens"></param>
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
using (var last_tokens_handle = last_tokens.Pin())
{
NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
}
}
}
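// A minimal usage sketch of the method above, with hypothetical values: one call now
// applies the CTRL repetition penalty and the OpenAI-style frequency/presence penalties
// together, where the old API needed two calls. `context` (a SafeLLamaContextHandle)
// and `recentTokens` (a Memory<llama_token> of the most recently generated ids) are
// assumed:
candidates.RepetitionPenalty(context, recentTokens,
    penalty_repeat: 1.1f,    // > 1.0 discourages repeating any recent token
    penalty_freq: 0.1f,      // grows with how often a token has already occurred
    penalty_present: 0.1f);  // flat penalty for having occurred at all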
/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
/// </summary>
/// <param name="context"></param>
/// <param name="temp"></param>
public void Temperature(SafeLLamaContextHandle context, float temp)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_temperature(context, ref st, temp);
sorted = st.sorted;
}
}
/// <summary>
/// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
/// </summary>
/// <param name="context"></param>
public void Softmax(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_softmax(context, ref st);
sorted = st.sorted;
}
}
/// <summary>
/// Randomly selects a token from the candidates based on their probabilities.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleToken(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token(context, ref st);
sorted = st.sorted;
return token;
}
}
/// <summary>
/// Selects the token with the highest probability.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleTokenGreedy(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_greedy(context, ref st);
sorted = st.sorted;
return token;
}
}
/// <summary>
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// </summary>
/// <param name="context"></param>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
sorted = st.sorted;
return token;
}
}
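// A minimal sketch of driving Mirostat across a generation loop. `mu` must persist
// between calls: per the doc comment above it starts at twice the target cross-entropy
// (2 * tau) and is then updated in place through the `ref` parameter. `GetLogits()` is
// a hypothetical helper returning the current Span<float> of logits:
const float tau = 5.0f;
const float eta = 0.1f;
var mu = 2.0f * tau;    // initial maximum cross-entropy
for (var i = 0; i < 32; i++)
{
    var c = LLamaTokenDataArray.Create(GetLogits());
    var token = c.SampleTokenMirostat(context.NativeHandle, tau, eta, 100, ref mu);
    // ...feed `token` back into the context before the next iteration...
}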
/// <summary>
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// </summary>
/// <param name="context"></param>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
sorted = st.sorted;
return token;
}
}
#endregion
}
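// As the pattern above shows, every wrapper creates a temporary
// LLamaTokenDataArrayNative over the managed buffer, invokes the native sampler, then
// copies `st.sorted` back, since some native samplers sort the candidate buffer as a
// side effect and later calls can skip re-sorting. A sketch of the equivalent manual
// pattern these methods replace, assuming `candidates` is a LLamaTokenDataArray and
// `context` a SafeLLamaContextHandle:
using (LLamaTokenDataArrayNative.Create(candidates, out var st))
{
    NativeApi.llama_sample_top_k(context, ref st, 40, 1);
    // `candidates.sorted` must now be synchronised with `st.sorted` by hand;
    // the instance methods do this bookkeeping for you.
}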
/// <summary>
@@ -96,48 +289,5 @@ namespace LLama.Native
return handle;
}
/// <summary>
/// Perform TopK sampling, sorting the data and reducing the size to k
/// </summary>
/// <param name="context"></param>
/// <param name="k">Number of tokens to keep</param>
/// <param name="minKeep">Minimum number to keep</param>
public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
{
NativeApi.llama_sample_top_k(context, ref this, k, minKeep);
}
/// <summary>
/// Perform top p sampling, sorting the data and keeping only logits more likely than p
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="minKeep"></param>
public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
NativeApi.llama_sample_top_p(context, ref this, p, minKeep);
}
/// <summary>
/// Apply temperature to logits
/// </summary>
/// <param name="context"></param>
/// <param name="temp"></param>
public void Temperature(SafeLLamaContextHandle context, float temp)
{
NativeApi.llama_sample_temperature(context, ref this, temp);
}
/// <summary>
/// Sample a token from the set of possible tokens
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public int SampleToken(SafeLLamaContextHandle context)
{
return NativeApi.llama_sample_token(context, ref this);
}
}
}
@@ -9,26 +9,22 @@ namespace LLama.Native
{
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);
/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
/// <param name="penalty_repeat">Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.</param>
/// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
/// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
    ref LLamaTokenDataArrayNative candidates,
    llama_token* last_tokens, ulong last_tokens_size,
    float penalty_repeat,
    float penalty_freq,
    float penalty_present);
/// <summary>
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
@@ -102,5 +102,15 @@ namespace LLama.Native
return new(grammar_ptr);
}
#endregion
/// <summary>
/// Accepts the sampled token into the grammar
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
NativeApi.llama_grammar_accept_token(ctx, this, token);
}
}
}
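// A minimal sketch of one grammar-constrained decode step using the new AcceptToken
// helper, assuming `ctx` is a SafeLLamaContextHandle, `grammar` a SafeLLamaGrammarHandle
// and `candidates` a LLamaTokenDataArray built from the current logits:
candidates.ApplyGrammar(ctx, grammar);     // suppress tokens the grammar cannot accept
var token = candidates.SampleToken(ctx);   // sample from the remaining candidates
grammar.AcceptToken(ctx, token);           // advance the grammar's internal parse state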
@@ -9,7 +9,7 @@ namespace LLama.Native
/// <summary>
/// Direct translation of the llama.cpp sampling API
/// </summary>
public unsafe class SamplingApi
public class SamplingApi
{
/// <summary>
/// Apply grammar rules to candidate tokens
@@ -17,70 +17,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="grammar"></param>
[Obsolete("use LLamaTokenDataArray ApplyGrammar method")]
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
}
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
{
llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
}
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
using var last_tokens_handle = last_tokens.Pin();
NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
}
/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{
llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
}
/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
using var last_tokens_handle = last_tokens.Pin();
NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
candidates.ApplyGrammar(ctx, grammar);
}
/// <summary>
@@ -88,10 +28,10 @@ namespace LLama.Native
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
[Obsolete("use LLamaTokenDataArray Softmax method")]
public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_softmax(ctx, ref st);
candidates.Softmax(ctx);
}
/// <summary>
@@ -101,10 +41,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="k"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TopK method")]
public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
candidates.TopK(ctx, k, min_keep);
}
/// <summary>
@@ -114,10 +54,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TopP method")]
public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
candidates.TopP(ctx, p, min_keep);
}
/// <summary>
@@ -127,10 +67,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TailFree method")]
public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
candidates.TailFree(ctx, z, min_keep);
}
/// <summary>
@@ -140,10 +80,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray LocallyTypical method")]
public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
candidates.LocallyTypical(ctx, p, min_keep);
}
/// <summary>
@@ -153,10 +93,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="temp"></param>
[Obsolete("use LLamaTokenDataArray Temperature() method")]
public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_temperature(ctx, ref st, temp);
candidates.Temperature(ctx, temp);
}
/// <summary>
@@ -169,10 +109,10 @@ namespace LLama.Native
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);
}
/// <summary>
@@ -184,10 +124,10 @@ namespace LLama.Native
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
}
/// <summary>
@@ -196,10 +136,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_greedy(ctx, ref st);
return candidates.SampleTokenGreedy(ctx);
}
/// <summary>
@@ -208,10 +148,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleToken() method")]
public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token(ctx, ref st);
return candidates.SampleToken(ctx);
}
}
}
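// Migration sketch: each [Obsolete] static above now simply forwards to the matching
// LLamaTokenDataArray instance method, so call sites can be converted mechanically.
// For example, with `ctx` a SafeLLamaContextHandle and `candidates` a LLamaTokenDataArray:
SamplingApi.llama_sample_top_k(ctx, candidates, 40, 1);   // old, now [Obsolete]
candidates.TopK(ctx, 40, 1);                              // new, equivalent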