@@ -6,6 +6,10 @@ using LLama.Native;
 namespace LLama.Examples.NewVersion;
+/// <summary>
+/// This demonstrates generating multiple replies to the same prompt, with a shared cache
+/// </summary>
+/// <remarks>Note that this is currently using the low level API directly; future work will provide a safer C# wrapper over this!</remarks>
 public class BatchedDecoding
 {
     private const int n_parallel = 8;
@@ -116,12 +120,11 @@ public class BatchedDecoding
            {
                candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
            }
-           using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native);
-           candidates_native.TopK(context.NativeHandle, top_k);
-           candidates_native.TopP(context.NativeHandle, top_p);
-           candidates_native.Temperature(context.NativeHandle, temp);
-           var new_token_id = candidates_native.SampleToken(context.NativeHandle);
+           candidates.TopK(context.NativeHandle, top_k);
+           candidates.TopP(context.NativeHandle, top_p);
+           candidates.Temperature(context.NativeHandle, temp);
+           var new_token_id = candidates.SampleToken(context.NativeHandle);
            if (new_token_id == eos || new_token_id == nl)
            {
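With this change the example no longer creates and pins a `LLamaTokenDataArrayNative` itself; the managed array handles that inside each call. A minimal sketch of the per-sequence sampling step as it now reads (the locals `context`, `i`, `i_batch`, `n_vocab`, `top_k`, `top_p` and `temp` are the example's own and are assumed here):

```csharp
// Sketch of the new per-sequence sampling step in BatchedDecoding.
// All locals (context, i, i_batch, n_vocab, top_k, top_p, temp) come from the example.
LLamaTokenDataArray candidates;
unsafe
{
    // Wrap the logits for sequence i in a managed token-data array
    var logits = new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab);
    candidates = LLamaTokenDataArray.Create(logits);
}

// Filter and sample directly on the managed array; pinning happens inside each call
candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
candidates.Temperature(context.NativeHandle, temp);
var new_token_id = candidates.SampleToken(context.NativeHandle);
```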
@@ -22,7 +22,7 @@
        Console.WriteLine("12: Semantic Kernel Chat.");
        Console.WriteLine("13: Semantic Kernel Memory.");
        Console.WriteLine("14: Coding Assistant.");
-       Console.WriteLine("15: Batch Decoding Benchmark.");
+       Console.WriteLine("15: Batch Decoding.");
        while (true)
        {
@@ -235,13 +235,13 @@ namespace LLama
            if (grammar != null)
            {
-               SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
+               candidates.ApplyGrammar(NativeHandle, grammar);
            }
            if (temperature <= 0)
            {
                // Greedy sampling
-               id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
+               id = candidates.SampleTokenGreedy(NativeHandle);
            }
            else
            {
@@ -250,32 +250,28 @@ namespace LLama
                    if (mirostat == MirostatType.Mirostat)
                    {
                        const int mirostat_m = 100;
-                       SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
-                       id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+                       candidates.Temperature(NativeHandle, temperature);
+                       id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu);
                    }
                    else if (mirostat == MirostatType.Mirostat2)
                    {
-                       SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
-                       id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
+                       candidates.Temperature(NativeHandle, temperature);
+                       id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu);
                    }
                    else
                    {
-                       // Temperature sampling
-                       SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
-                       SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
-                       SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
-                       SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
-                       SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
-                       id = SamplingApi.llama_sample_token(NativeHandle, candidates);
+                       candidates.TopK(NativeHandle, topK);
+                       candidates.TailFree(NativeHandle, tfsZ);
+                       candidates.LocallyTypical(NativeHandle, typicalP);
+                       candidates.TopP(NativeHandle, topP);
+                       candidates.Temperature(NativeHandle, temperature);
+                       id = candidates.SampleToken(NativeHandle);
                    }
                }
                mirostat_mu = mu;
            }
-           if (grammar != null)
-           {
-               NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
-           }
+           grammar?.AcceptToken(NativeHandle, id);
            return id;
        }
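Grammar handling in this method also moves from static `SamplingApi`/`NativeApi` calls to methods on the managed types. A rough sketch of the grammar path as it now reads (`candidates`, `grammar` and `NativeHandle` are the members and locals of the surrounding method; the selection step in the middle is whichever sampler applies):

```csharp
// Sketch: grammar-constrained sampling with the refactored API.
if (grammar != null)
{
    // Suppress tokens the grammar does not currently allow
    candidates.ApplyGrammar(NativeHandle, grammar);
}

// ...pick a token via greedy / mirostat / top-k, top-p, temperature sampling...
var id = candidates.SampleTokenGreedy(NativeHandle);

// Advance the grammar's parse state with the token that was actually chosen
grammar?.AcceptToken(NativeHandle, id);
```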
@@ -305,7 +301,7 @@ namespace LLama
            }
            // Save the newline logit value
-           var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
+           var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
            var nl_logit = logits[nl_token];
            // Convert logits into token candidates
@@ -316,8 +312,7 @@ namespace LLama
            var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
            // Apply penalties to candidates
-           SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
-           SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
+           candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);
            // Restore newline token logit value if necessary
            if (!penalizeNL)
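The two legacy penalty calls collapse into one because upstream llama.cpp now applies the repetition, frequency and presence penalties in a single function. A minimal sketch of the merged call, using the locals from the method above:

```csharp
// One call now applies all three penalties to the candidate list.
candidates_p.RepetitionPenalty(
    NativeHandle,
    last_n_array,    // recent token history to penalise
    repeatPenalty,   // multiplicative repetition penalty (CTRL-style)
    alphaFrequency,  // OpenAI-style frequency penalty
    alphaPresence);  // OpenAI-style presence penalty
```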
@@ -45,6 +45,199 @@ namespace LLama.Native
             return new LLamaTokenDataArray(candidates);
         }
+        #region sampling
+        /// <summary>
+        /// Apply grammar rules to candidate tokens
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="grammar"></param>
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_grammar(ctx, ref st, grammar);
+                sorted = st.sorted;
+            }
+        }
+        /// <summary>
+        /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="k">Number of tokens to keep</param>
+        /// <param name="minKeep">Minimum number to keep</param>
+        public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
+                sorted = st.sorted;
+            }
+        }
+        /// <summary>
+        /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="p"></param>
+        /// <param name="minKeep"></param>
+        public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
+                sorted = st.sorted;
+            }
+        }
+        /// <summary>
+        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="z"></param>
+        /// <param name="min_keep"></param>
+        public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
+                sorted = st.sorted;
+            }
+        }
+        /// <summary>
+        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="p"></param>
+        /// <param name="min_keep"></param>
+        public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_typical(context, ref st, p, min_keep);
+                sorted = st.sorted;
+            }
+        }
+        /// <summary>
+        /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
+        /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="last_tokens"></param>
+        /// <param name="penalty_repeat"></param>
+        /// <param name="penalty_freq"></param>
+        /// <param name="penalty_present"></param>
+        public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+        {
+            unsafe
+            {
+                using (LLamaTokenDataArrayNative.Create(this, out var st))
+                using (var last_tokens_handle = last_tokens.Pin())
+                {
+                    NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+                    sorted = st.sorted;
+                }
+            }
+        }
+        /// <summary>
+        /// Sample with temperature.
+        /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="temp"></param>
+        public void Temperature(SafeLLamaContextHandle context, float temp)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_temperature(context, ref st, temp);
+                sorted = st.sorted;
+            }
+        }
+        /// <summary>
+        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
+        /// </summary>
+        /// <param name="context"></param>
+        public void Softmax(SafeLLamaContextHandle context)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_softmax(context, ref st);
+                sorted = st.sorted;
+            }
+        }
+        /// <summary>
+        /// Randomly selects a token from the candidates based on their probabilities.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <returns></returns>
+        public int SampleToken(SafeLLamaContextHandle context)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                var token = NativeApi.llama_sample_token(context, ref st);
+                sorted = st.sorted;
+                return token;
+            }
+        }
+        /// <summary>
+        /// Selects the token with the highest probability.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <returns></returns>
+        public int SampleTokenGreedy(SafeLLamaContextHandle context)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                var token = NativeApi.llama_sample_token_greedy(context, ref st);
+                sorted = st.sorted;
+                return token;
+            }
+        }
+        /// <summary>
+        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
+        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
+        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
+        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
+        /// <returns></returns>
+        public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
+                sorted = st.sorted;
+                return token;
+            }
+        }
+        /// <summary>
+        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
+        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
+        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
+        /// <returns></returns>
+        public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
+                sorted = st.sorted;
+                return token;
+            }
+        }
+        #endregion
     }
     /// <summary>
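These new instance methods let callers drive llama.cpp sampling directly from a `LLamaTokenDataArray`, with native struct creation and pinning handled inside each call. A purely illustrative sketch of a typical pipeline (the `ctx`, `logits` and `lastTokens` values are assumed to come from an already-loaded model):

```csharp
// Illustrative sketch only. Assumes:
//   ctx        - a SafeLLamaContextHandle for a loaded model
//   logits     - Span<float> of logits for the position being sampled
//   lastTokens - Memory<int> of recently generated tokens
var candidates = LLamaTokenDataArray.Create(logits);

candidates.RepetitionPenalty(ctx, lastTokens, 1.1f, 0.0f, 0.0f); // discourage recent repeats
candidates.TopK(ctx, 40);            // keep the 40 most likely tokens
candidates.TopP(ctx, 0.95f);         // then keep the top 0.95 probability mass
candidates.Temperature(ctx, 0.8f);   // rescale the remaining logits
var token = candidates.SampleToken(ctx);
```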
@@ -96,48 +289,5 @@ namespace LLama.Native
             return handle;
         }
-        /// <summary>
-        /// Perform TopK sampling, sorting the data and reducing the size to k
-        /// </summary>
-        /// <param name="context"></param>
-        /// <param name="k">Number of tokens to keep</param>
-        /// <param name="minKeep">Minimum number to keep</param>
-        public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
-        {
-            NativeApi.llama_sample_top_k(context, ref this, k, minKeep);
-        }
-        /// <summary>
-        /// Perform top p sampling, sorting the data and keeping only logits more likely than p
-        /// </summary>
-        /// <param name="context"></param>
-        /// <param name="p"></param>
-        /// <param name="minKeep"></param>
-        public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
-        {
-            NativeApi.llama_sample_top_p(context, ref this, p, minKeep);
-        }
-        /// <summary>
-        /// Apply temperature to logits
-        /// </summary>
-        /// <param name="context"></param>
-        /// <param name="temp"></param>
-        public void Temperature(SafeLLamaContextHandle context, float temp)
-        {
-            NativeApi.llama_sample_temperature(context, ref this, temp);
-        }
-        /// <summary>
-        /// Sample a token from the set of possible tokens
-        /// </summary>
-        /// <param name="context"></param>
-        /// <returns></returns>
-        /// <exception cref="NotImplementedException"></exception>
-        public int SampleToken(SafeLLamaContextHandle context)
-        {
-            return NativeApi.llama_sample_token(context, ref this);
-        }
     }
 }
@@ -9,26 +9,22 @@ namespace LLama.Native
    {
        /// <summary>
        /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-       /// </summary>
-       /// <param name="ctx"></param>
-       /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
-       /// <param name="last_tokens"></param>
-       /// <param name="last_tokens_size"></param>
-       /// <param name="penalty"></param>
-       [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);
-       /// <summary>
        /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens"></param>
        /// <param name="last_tokens_size"></param>
-       /// <param name="alpha_frequency"></param>
-       /// <param name="alpha_presence"></param>
+       /// <param name="penalty_repeat">Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.</param>
+       /// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
+       /// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
+       public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
+                                                                   ref LLamaTokenDataArrayNative candidates,
+                                                                   llama_token* last_tokens, ulong last_tokens_size,
+                                                                   float penalty_repeat,
+                                                                   float penalty_freq,
+                                                                   float penalty_present);
        /// <summary>
        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
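For code that still P/Invokes the native layer directly, the two removed entry points are replaced by this single function. A rough sketch of the unsafe call site, mirroring what `LLamaTokenDataArray.RepetitionPenalty` does internally (all locals here are assumed):

```csharp
// Sketch: calling the merged native penalty function directly.
// Assumes `ctx` (SafeLLamaContextHandle), `candidates` (LLamaTokenDataArray)
// and `lastTokens` (Memory<int>) already exist.
unsafe
{
    using var pin = LLamaTokenDataArrayNative.Create(candidates, out var st);
    using var lastTokensPin = lastTokens.Pin();

    NativeApi.llama_sample_repetition_penalties(
        ctx, ref st,
        (int*)lastTokensPin.Pointer, (ulong)lastTokens.Length,
        penalty_repeat: 1.1f,
        penalty_freq: 0.0f,
        penalty_present: 0.0f);
}
```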
@@ -102,5 +102,15 @@ namespace LLama.Native
            return new(grammar_ptr);
        }
        #endregion
+       /// <summary>
+       /// Accepts the sampled token into the grammar
+       /// </summary>
+       /// <param name="ctx"></param>
+       /// <param name="token"></param>
+       public void AcceptToken(SafeLLamaContextHandle ctx, int token)
+       {
+           NativeApi.llama_grammar_accept_token(ctx, this, token);
+       }
    }
}
@@ -9,7 +9,7 @@ namespace LLama.Native
    /// <summary>
    /// Direct translation of the llama.cpp sampling API
    /// </summary>
-   public unsafe class SamplingApi
+   public class SamplingApi
    {
        /// <summary>
        /// Apply grammar rules to candidate tokens
@@ -17,70 +17,10 @@ namespace LLama.Native
        /// <param name="ctx"></param>
        /// <param name="candidates"></param>
        /// <param name="grammar"></param>
+       [Obsolete("use LLamaTokenDataArray ApplyGrammar method")]
        public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           NativeApi.llama_sample_grammar(ctx, ref st, grammar);
-       }
-       /// <summary>
-       /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-       /// </summary>
-       /// <param name="ctx"></param>
-       /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
-       /// <param name="last_tokens"></param>
-       /// <param name="last_tokens_size"></param>
-       /// <param name="penalty"></param>
-       [Obsolete("last_tokens_size parameter is no longer needed")]
-       public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
-       {
-           llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
-       }
-       /// <summary>
-       /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-       /// </summary>
-       /// <param name="ctx"></param>
-       /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
-       /// <param name="last_tokens"></param>
-       /// <param name="penalty"></param>
-       public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
-       {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           using var last_tokens_handle = last_tokens.Pin();
-           NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
-       }
-       /// <summary>
-       /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
-       /// </summary>
-       /// <param name="ctx"></param>
-       /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
-       /// <param name="last_tokens"></param>
-       /// <param name="last_tokens_size"></param>
-       /// <param name="alpha_frequency"></param>
-       /// <param name="alpha_presence"></param>
-       [Obsolete("last_tokens_size parameter is no longer needed")]
-       public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
-       {
-           llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
-       }
-       /// <summary>
-       /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
-       /// </summary>
-       /// <param name="ctx"></param>
-       /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
-       /// <param name="last_tokens"></param>
-       /// <param name="alpha_frequency"></param>
-       /// <param name="alpha_presence"></param>
-       public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
-       {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           using var last_tokens_handle = last_tokens.Pin();
-           NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
+           candidates.ApplyGrammar(ctx, grammar);
        }
        /// <summary>
@@ -88,10 +28,10 @@ namespace LLama.Native
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
+       [Obsolete("use LLamaTokenDataArray Softmax method")]
        public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           NativeApi.llama_sample_softmax(ctx, ref st);
+           candidates.Softmax(ctx);
        }
        /// <summary>
@@ -101,10 +41,10 @@ namespace LLama.Native
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="k"></param>
        /// <param name="min_keep"></param>
+       [Obsolete("use LLamaTokenDataArray TopK method")]
        public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
+           candidates.TopK(ctx, k, min_keep);
        }
        /// <summary>
@@ -114,10 +54,10 @@ namespace LLama.Native
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p"></param>
        /// <param name="min_keep"></param>
+       [Obsolete("use LLamaTokenDataArray TopP method")]
        public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
+           candidates.TopP(ctx, p, min_keep);
        }
        /// <summary>
@@ -127,10 +67,10 @@ namespace LLama.Native
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="z"></param>
        /// <param name="min_keep"></param>
+       [Obsolete("use LLamaTokenDataArray TailFree method")]
        public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
+           candidates.TailFree(ctx, z, min_keep);
        }
        /// <summary>
@@ -140,10 +80,10 @@ namespace LLama.Native
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p"></param>
        /// <param name="min_keep"></param>
+       [Obsolete("use LLamaTokenDataArray LocallyTypical method")]
        public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
+           candidates.LocallyTypical(ctx, p, min_keep);
        }
        /// <summary>
@@ -153,10 +93,10 @@ namespace LLama.Native
        /// <param name="ctx"></param>
        /// <param name="candidates"></param>
        /// <param name="temp"></param>
+       [Obsolete("use LLamaTokenDataArray Temperature() method")]
        public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           NativeApi.llama_sample_temperature(ctx, ref st, temp);
+           candidates.Temperature(ctx, temp);
        }
        /// <summary>
@@ -169,10 +109,10 @@ namespace LLama.Native
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
+       [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
        public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
+           return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);
        }
        /// <summary>
@@ -184,10 +124,10 @@ namespace LLama.Native
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
+       [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
        public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
+           return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
        }
        /// <summary>
@@ -196,10 +136,10 @@ namespace LLama.Native
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
+       [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
        public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           return NativeApi.llama_sample_token_greedy(ctx, ref st);
+           return candidates.SampleTokenGreedy(ctx);
        }
        /// <summary>
@@ -208,10 +148,10 @@ namespace LLama.Native
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
+       [Obsolete("use LLamaTokenDataArray SampleToken() method")]
        public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
-           using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
-           return NativeApi.llama_sample_token(ctx, ref st);
+           return candidates.SampleToken(ctx);
        }
    }
}
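In short, the static `SamplingApi` wrappers survive only as `[Obsolete]` shims that forward to the new instance methods. A rough migration sketch (`ctx` and `candidates` are assumed to exist; the parameter values are placeholders):

```csharp
// Before (now [Obsolete]): static wrapper calls.
SamplingApi.llama_sample_top_k(ctx, candidates, 40, 1);
SamplingApi.llama_sample_temperature(ctx, candidates, 0.8f);
var token = SamplingApi.llama_sample_token(ctx, candidates);

// After: equivalent instance methods on LLamaTokenDataArray.
candidates.TopK(ctx, 40);
candidates.Temperature(ctx, 0.8f);
token = candidates.SampleToken(ctx);
```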