- Modified `llama_sample_token_mirostat` and `llama_sample_token_mirostat_v2` to take `ref float` instead of `float*`. Fewer pointers is always good.
- Modified `llama_sample_repetition_penalty` and `llama_sample_frequency_and_presence_penalties` to take pointers instead of arrays. This allows the use of non-allocating types (e.g. `Span`) instead of arrays.
- Modified higher level API to accept `Memory<int>` instead of `int[]`, which can be used to reduce allocations at call sites.

tags/v0.5.1
| @@ -26,7 +26,7 @@ namespace LLama.Native | |||
| /// <param name="last_tokens_size"></param> | |||
| /// <param name="penalty"></param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty); | |||
| public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty); | |||
| /// <summary> | |||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | |||
| @@ -38,7 +38,7 @@ namespace LLama.Native | |||
| /// <param name="alpha_frequency"></param> | |||
| /// <param name="alpha_presence"></param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); | |||
| public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); | |||
| /// <summary> | |||
| /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 | |||
| @@ -118,7 +118,7 @@ namespace LLama.Native | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, float* mu); | |||
| public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu); | |||
| /// <summary> | |||
| /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | |||
| @@ -130,7 +130,7 @@ namespace LLama.Native | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, float* mu); | |||
| public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu); | |||
| /// <summary> | |||
| /// Selects the token with the highest probability. | |||
| @@ -25,10 +25,12 @@ namespace LLama.Native | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="last_tokens_size"></param> | |||
| /// <param name="penalty"></param> | |||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty) | |||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty); | |||
| using var last_tokens_handle = last_tokens.Pin(); | |||
| NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, penalty); | |||
| } | |||
| /// <summary> | |||
| @@ -40,10 +42,12 @@ namespace LLama.Native | |||
| /// <param name="last_tokens_size"></param> | |||
| /// <param name="alpha_frequency"></param> | |||
| /// <param name="alpha_presence"></param> | |||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) | |||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence); | |||
| using var last_tokens_handle = last_tokens.Pin(); | |||
| NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, alpha_frequency, alpha_presence); | |||
| } | |||
| /// <summary> | |||
| @@ -128,10 +132,7 @@ namespace LLama.Native | |||
| public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| fixed(float* pmu = &mu) | |||
| { | |||
| return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu); | |||
| } | |||
| return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu); | |||
| } | |||
| /// <summary> | |||
| @@ -146,10 +147,7 @@ namespace LLama.Native | |||
| public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| fixed (float* pmu = &mu) | |||
| { | |||
| return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu); | |||
| } | |||
| return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu); | |||
| } | |||
| /// <summary> | |||