- Moved repeated code to convert `LLamaTokenDataArray` into a `LLamaTokenDataArrayNative` into a helper method. - Modified all call sites to dispose the `MemoryHandle` - Saved one copy of the `List<LLamaTokenData>` into a `LLamaTokenData[]` in `LlamaModel`tags/v0.4.2-preview
| @@ -294,14 +294,10 @@ namespace LLama | |||||
| } | } | ||||
| } | } | ||||
| var candidates = new List<LLamaTokenData>(); | |||||
| candidates.Capacity = n_vocab; | |||||
| var candidates = new LLamaTokenData[n_vocab]; | |||||
| for (llama_token token_id = 0; token_id < n_vocab; token_id++) | for (llama_token token_id = 0; token_id < n_vocab; token_id++) | ||||
| { | |||||
| candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f)); | |||||
| } | |||||
| LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false); | |||||
| candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); | |||||
| LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); | |||||
| // Apply penalties | // Apply penalties | ||||
| float nl_logit = logits[NativeApi.llama_token_nl()]; | float nl_logit = logits[NativeApi.llama_token_nl()]; | ||||
| @@ -1,32 +1,80 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | using System.Buffers; | ||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| [StructLayout(LayoutKind.Sequential)] | |||||
| /// <summary> | |||||
| /// Contains an array of LLamaTokenData, potentially sorted. | |||||
| /// </summary> | |||||
| public struct LLamaTokenDataArray | public struct LLamaTokenDataArray | ||||
| { | { | ||||
| public Memory<LLamaTokenData> data; | |||||
| public ulong size; | |||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool sorted; | |||||
| /// <summary> | |||||
| /// The LLamaTokenData | |||||
| /// </summary> | |||||
| public readonly Memory<LLamaTokenData> data; | |||||
| public LLamaTokenDataArray(LLamaTokenData[] data, ulong size, bool sorted) | |||||
| /// <summary> | |||||
| /// Indicates if `data` is sorted | |||||
| /// </summary> | |||||
| public readonly bool sorted; | |||||
| /// <summary> | |||||
| /// Create a new LLamaTokenDataArray | |||||
| /// </summary> | |||||
| /// <param name="tokens"></param> | |||||
| /// <param name="isSorted"></param> | |||||
| public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false) | |||||
| { | { | ||||
| this.data = data; | |||||
| this.size = size; | |||||
| this.sorted = sorted; | |||||
| data = tokens; | |||||
| sorted = isSorted; | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Contains a pointer to an array of LLamaTokenData which is pinned in memory. | |||||
| /// </summary> | |||||
| [StructLayout(LayoutKind.Sequential)] | [StructLayout(LayoutKind.Sequential)] | ||||
| public struct LLamaTokenDataArrayNative | public struct LLamaTokenDataArrayNative | ||||
| { | { | ||||
| /// <summary> | |||||
| /// A pointer to an array of LlamaTokenData | |||||
| /// </summary> | |||||
| /// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use</remarks> | |||||
| public IntPtr data; | public IntPtr data; | ||||
| /// <summary> | |||||
| /// Number of LLamaTokenData in the array | |||||
| /// </summary> | |||||
| public ulong size; | public ulong size; | ||||
| /// <summary> | |||||
| /// Indicates if the items in the array are sorted | |||||
| /// </summary> | |||||
| [MarshalAs(UnmanagedType.I1)] | |||||
| public bool sorted; | public bool sorted; | ||||
| /// <summary> | |||||
| /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray | |||||
| /// </summary> | |||||
| /// <param name="array">Data source</param> | |||||
| /// <param name="native">Created native array</param> | |||||
| /// <returns>A memory handle, pinning the data in place until disposed</returns> | |||||
| public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native) | |||||
| { | |||||
| var handle = array.data.Pin(); | |||||
| unsafe | |||||
| { | |||||
| native = new LLamaTokenDataArrayNative | |||||
| { | |||||
| data = new IntPtr(handle.Pointer), | |||||
| size = (ulong)array.data.Length, | |||||
| sorted = array.sorted | |||||
| }; | |||||
| } | |||||
| return handle; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,11 +1,10 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| public unsafe partial class NativeApi | public unsafe partial class NativeApi | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| @@ -17,7 +16,7 @@ namespace LLama.Native | |||||
| /// <param name="last_tokens_size"></param> | /// <param name="last_tokens_size"></param> | ||||
| /// <param name="penalty"></param> | /// <param name="penalty"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty); | |||||
| public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty); | |||||
| /// <summary> | /// <summary> | ||||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | ||||
| @@ -29,7 +28,7 @@ namespace LLama.Native | |||||
| /// <param name="alpha_frequency"></param> | /// <param name="alpha_frequency"></param> | ||||
| /// <param name="alpha_presence"></param> | /// <param name="alpha_presence"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); | |||||
| public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. | /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. | ||||
| @@ -37,7 +36,7 @@ namespace LLama.Native | |||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, IntPtr candidates); | |||||
| public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); | |||||
| /// <summary> | /// <summary> | ||||
| /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | ||||
| @@ -47,7 +46,7 @@ namespace LLama.Native | |||||
| /// <param name="k"></param> | /// <param name="k"></param> | ||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, IntPtr candidates, int k, ulong min_keep); | |||||
| public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, int k, ulong min_keep); | |||||
| /// <summary> | /// <summary> | ||||
| /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | ||||
| @@ -57,7 +56,7 @@ namespace LLama.Native | |||||
| /// <param name="p"></param> | /// <param name="p"></param> | ||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep); | |||||
| public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep); | |||||
| /// <summary> | /// <summary> | ||||
| /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. | /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. | ||||
| @@ -67,7 +66,7 @@ namespace LLama.Native | |||||
| /// <param name="z"></param> | /// <param name="z"></param> | ||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, IntPtr candidates, float z, ulong min_keep); | |||||
| public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float z, ulong min_keep); | |||||
| /// <summary> | /// <summary> | ||||
| /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. | /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. | ||||
| @@ -77,10 +76,16 @@ namespace LLama.Native | |||||
| /// <param name="p"></param> | /// <param name="p"></param> | ||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep); | |||||
| public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep); | |||||
| /// <summary> | |||||
| /// Modify logits by temperature | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="candidates"></param> | |||||
| /// <param name="temp"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_sample_temperature(SafeLLamaContextHandle ctx, IntPtr candidates, float temp); | |||||
| public static extern void llama_sample_temperature(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float temp); | |||||
| /// <summary> | /// <summary> | ||||
| /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | ||||
| @@ -93,7 +98,7 @@ namespace LLama.Native | |||||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float* mu); | |||||
| public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, float* mu); | |||||
| /// <summary> | /// <summary> | ||||
| /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | ||||
| @@ -105,7 +110,7 @@ namespace LLama.Native | |||||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float* mu); | |||||
| public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, float* mu); | |||||
| /// <summary> | /// <summary> | ||||
| /// Selects the token with the highest probability. | /// Selects the token with the highest probability. | ||||
| @@ -114,7 +119,7 @@ namespace LLama.Native | |||||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, IntPtr candidates); | |||||
| public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); | |||||
| /// <summary> | /// <summary> | ||||
| /// Randomly selects a token from the candidates based on their probabilities. | /// Randomly selects a token from the candidates based on their probabilities. | ||||
| @@ -123,6 +128,6 @@ namespace LLama.Native | |||||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, IntPtr candidates); | |||||
| public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,4 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| @@ -18,12 +15,8 @@ namespace LLama.Native | |||||
| /// <param name="penalty"></param> | /// <param name="penalty"></param> | ||||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty) | public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_repetition_penalty(ctx, new IntPtr(&st), last_tokens, last_tokens_size, penalty); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -37,12 +30,8 @@ namespace LLama.Native | |||||
| /// <param name="alpha_presence"></param> | /// <param name="alpha_presence"></param> | ||||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) | public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_frequency_and_presence_penalties(ctx, new IntPtr(&st), last_tokens, last_tokens_size, alpha_frequency, alpha_presence); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -52,12 +41,8 @@ namespace LLama.Native | |||||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | ||||
| public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_softmax(ctx, new IntPtr(&st)); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_softmax(ctx, ref st); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -69,12 +54,8 @@ namespace LLama.Native | |||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep) | public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_top_k(ctx, new IntPtr(&st), k, min_keep); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -86,12 +67,8 @@ namespace LLama.Native | |||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) | public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_top_p(ctx, new IntPtr(&st), p, min_keep); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -103,12 +80,8 @@ namespace LLama.Native | |||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep) | public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_tail_free(ctx, new IntPtr(&st), z, min_keep); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -120,22 +93,14 @@ namespace LLama.Native | |||||
| /// <param name="min_keep"></param> | /// <param name="min_keep"></param> | ||||
| public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) | public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_typical(ctx, new IntPtr(&st), p, min_keep); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); | |||||
| } | } | ||||
| public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) | public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| NativeApi.llama_sample_temperature(ctx, new IntPtr(&st), temp); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| NativeApi.llama_sample_temperature(ctx, ref st, temp); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -150,17 +115,11 @@ namespace LLama.Native | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| llama_token res; | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| fixed(float* pmu = &mu) | fixed(float* pmu = &mu) | ||||
| { | { | ||||
| res = NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, pmu); | |||||
| return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu); | |||||
| } | } | ||||
| return res; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -174,17 +133,11 @@ namespace LLama.Native | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| llama_token res; | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| fixed (float* pmu = &mu) | fixed (float* pmu = &mu) | ||||
| { | { | ||||
| res = NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, pmu); | |||||
| return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu); | |||||
| } | } | ||||
| return res; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -195,12 +148,8 @@ namespace LLama.Native | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| return NativeApi.llama_sample_token_greedy(ctx, new IntPtr(&st)); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| return NativeApi.llama_sample_token_greedy(ctx, ref st); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -211,12 +160,8 @@ namespace LLama.Native | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | ||||
| { | { | ||||
| var handle = candidates.data.Pin(); | |||||
| var st = new LLamaTokenDataArrayNative(); | |||||
| st.data = new IntPtr(handle.Pointer); | |||||
| st.size = candidates.size; | |||||
| st.sorted = candidates.sorted; | |||||
| return NativeApi.llama_sample_token(ctx, new IntPtr(&st)); | |||||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||||
| return NativeApi.llama_sample_token(ctx, ref st); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -632,14 +632,10 @@ namespace LLama.OldVersion | |||||
| logits[key] += value; | logits[key] += value; | ||||
| } | } | ||||
| var candidates = new List<LLamaTokenData>(); | |||||
| candidates.Capacity = n_vocab; | |||||
| var candidates = new LLamaTokenData[n_vocab]; | |||||
| for (llama_token token_id = 0; token_id < n_vocab; token_id++) | for (llama_token token_id = 0; token_id < n_vocab; token_id++) | ||||
| { | |||||
| candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f)); | |||||
| } | |||||
| LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false); | |||||
| candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); | |||||
| LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); | |||||
| // Apply penalties | // Apply penalties | ||||
| float nl_logit = logits[NativeApi.llama_token_nl()]; | float nl_logit = logits[NativeApi.llama_token_nl()]; | ||||