using System;
using System.Buffers;
using System.Runtime.InteropServices;

using llama_token = System.Int32;

namespace LLama.Native
{
    /// <summary>
    /// Contains an array of LLamaTokenData, potentially sorted.
    /// </summary>
    public struct LLamaTokenDataArray
    {
        /// <summary>
        /// The LLamaTokenData
        /// </summary>
        public readonly Memory<LLamaTokenData> data;

        /// <summary>
        /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
        /// </summary>
        public bool sorted;

        /// <summary>
        /// Create a new LLamaTokenDataArray
        /// </summary>
        /// <param name="tokens">The token data</param>
        /// <param name="isSorted">Whether the data is already sorted by logits in descending order</param>
        public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
        {
            data = tokens;
            sorted = isSorted;
        }

        /// <summary>
        /// Create a new LLamaTokenDataArray, copying the data from the given logits
        /// </summary>
        /// <param name="logits">The logit values, one per token</param>
        /// <returns>A new, unsorted, array of candidate tokens</returns>
        public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
        {
            var candidates = new LLamaTokenData[logits.Length];
            for (var token_id = 0; token_id < logits.Length; token_id++)
                candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);

            return new LLamaTokenDataArray(candidates);
        }

        /// <summary>
        /// Overwrite the logit values for all given tokens
        /// </summary>
        /// <param name="values">tuples of token and logit value to overwrite</param>
        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
        {
            if (values.Length == 0)
                return;

            var dataSpan = data.Span;
            foreach (var (token, value) in values)
            {
                for (var i = 0; i < dataSpan.Length; i++)
                {
                    if (dataSpan[i].id == token)
                    {
                        dataSpan[i].logit = value;
                        break;
                    }
                }
            }

            // Logits have changed, so any previous sort order is no longer valid
            sorted = false;
        }

        #region sampling
        /// <summary>
        /// Apply grammar rules to candidate tokens
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="grammar"></param>
        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
        {
            if (grammar == null)
                return;

            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_grammar(ctx, ref st, grammar);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="context"></param>
        /// <param name="k">Number of tokens to keep</param>
        /// <param name="minKeep">Minimum number to keep</param>
        public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="context"></param>
        /// <param name="p">Cumulative probability of the tokens to keep</param>
        /// <param name="minKeep">Minimum number to keep</param>
        public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        /// </summary>
        /// <param name="context"></param>
        /// <param name="p">All tokens with probability greater than this will be kept</param>
        /// <param name="minKeep">Minimum number to keep</param>
        public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="z"></param>
        /// <param name="min_keep">Minimum number to keep</param>
        public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
                sorted = st.sorted;
            }
        }
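        // Example (illustrative sketch, not part of the API): building a candidate
        // array from raw logits, banning one token, then applying the filters above.
        // `ctx`, `logits` and `eosToken` are hypothetical names; `logits` would
        // typically be the per-token logit span produced for the last position.
        //
        //     var candidates = LLamaTokenDataArray.Create(logits);
        //     candidates.OverwriteLogits(new[] { (eosToken, float.NegativeInfinity) });
        //     candidates.TopK(ctx, 40);    // keep only the 40 most likely tokens
        //     candidates.TopP(ctx, 0.95f); // then nucleus-filter to cumulative p = 0.95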
        /// <summary>
        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="p"></param>
        /// <param name="min_keep">Minimum number to keep</param>
        public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_typical(context, ref st, p, min_keep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
        /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="last_tokens">The most recently generated tokens</param>
        /// <param name="penalty_repeat"></param>
        /// <param name="penalty_freq"></param>
        /// <param name="penalty_present"></param>
        public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
        {
            unsafe
            {
                using (LLamaTokenDataArrayNative.Create(this, out var st))
                {
                    fixed (int* last_tokens_handle = last_tokens)
                    {
                        NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
                        sorted = st.sorted;
                    }
                }
            }
        }

        /// <summary>
        /// Sample with temperature.
        /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="temp"></param>
        public void Temperature(SafeLLamaContextHandle context, float temp)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_temp(context, ref st, temp);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
        /// </summary>
        /// <param name="context"></param>
        public void Softmax(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_softmax(context, ref st);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Randomly selects a token from the candidates based on their probabilities.
        /// </summary>
        /// <param name="context"></param>
        /// <returns>The selected token</returns>
        public int SampleToken(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token(context, ref st);
                sorted = st.sorted;
                return token;
            }
        }

        /// <summary>
        /// Selects the token with the highest probability.
        /// </summary>
        /// <param name="context"></param>
        /// <returns>The selected token</returns>
        public int SampleTokenGreedy(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_greedy(context, ref st);
                sorted = st.sorted;
                return token;
            }
        }
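        // Example (illustrative sketch, not part of the API): a complete sampling
        // step combining the methods above. `ctx`, `candidates` and `lastTokens`
        // are hypothetical; the penalty values mirror common llama.cpp defaults.
        //
        //     candidates.RepetitionPenalty(ctx, lastTokens, 1.1f, 0.0f, 0.0f);
        //     candidates.Temperature(ctx, 0.8f);
        //     var token = candidates.SampleToken(ctx); // or SampleTokenGreedy(ctx) for argmax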
        /// <summary>
        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns>The selected token</returns>
        public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
                sorted = st.sorted;
                return token;
            }
        }

        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns>The selected token</returns>
        public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
                sorted = st.sorted;
                return token;
            }
        }
        #endregion
    }

    /// <summary>
    /// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct LLamaTokenDataArrayNative
    {
        /// <summary>
        /// A pointer to an array of LlamaTokenData
        /// </summary>
        /// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use</remarks>
        public IntPtr data;

        /// <summary>
        /// Number of LLamaTokenData in the array
        /// </summary>
        public ulong size;

        /// <summary>
        /// Indicates if the items in the array are sorted
        /// </summary>
        public bool sorted
        {
            get => Convert.ToBoolean(_sorted);
            set => _sorted = Convert.ToSByte(value);
        }
        private sbyte _sorted;

        /// <summary>
        /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
        /// </summary>
        /// <param name="array">Data source</param>
        /// <param name="native">Created native array</param>
        /// <returns>A memory handle, pinning the data in place until disposed</returns>
        public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native)
        {
            var handle = array.data.Pin();

            unsafe
            {
                native = new LLamaTokenDataArrayNative
                {
                    data = new IntPtr(handle.Pointer),
                    size = (ulong)array.data.Length,
                    sorted = array.sorted
                };
            }

            return handle;
        }
    }
}
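// Example (illustrative sketch, not part of the API): how the managed/native pair
// is used together, following the same pattern as the sampling methods above.
// LLamaTokenDataArrayNative.Create pins the managed buffer and returns the
// MemoryHandle keeping it pinned, so the native struct must not be used after
// that handle is disposed. `ctx` and `candidates` are hypothetical locals.
//
//     using (LLamaTokenDataArrayNative.Create(candidates, out var native))
//     {
//         NativeApi.llama_sample_softmax(ctx, ref native);
//         candidates.sorted = native.sorted; // copy the sorted flag back to the managed side
//     }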