
NativeApi.Sampling.cs

using System;
using System.Runtime.InteropServices;

namespace LLama.Native
{
    public static partial class NativeApi
    {
        /// <summary>
        /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
        /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens">Pointer to the most recent tokens, used to compute the penalties</param>
        /// <param name="last_tokens_size">Number of tokens in last_tokens</param>
        /// <param name="penalty_repeat">Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.</param>
        /// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
        /// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
                                                                           ref LLamaTokenDataArrayNative candidates,
                                                                           LLamaToken* last_tokens, ulong last_tokens_size,
                                                                           float penalty_repeat,
                                                                           float penalty_freq,
                                                                           float penalty_present);
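
        // Example sketch (illustrative helper, not part of the API above):
        // applying the repetition/frequency/presence penalties before sampling.
        // Assumes the caller has already built `candidates` from the current
        // logits and collected the recent tokens; the penalty values are
        // illustrative, not recommended defaults.
        private static unsafe void ExampleApplyRepetitionPenalties(SafeLLamaContextHandle ctx,
                                                                   ref LLamaTokenDataArrayNative candidates,
                                                                   LLamaToken[] lastTokens)
        {
            fixed (LLamaToken* lastTokensPtr = lastTokens)
            {
                // 1.1f is a common mild repetition penalty; 0 disables the
                // OpenAI-style frequency/presence terms.
                llama_sample_repetition_penalties(ctx, ref candidates,
                                                  lastTokensPtr, (ulong)lastTokens.Length,
                                                  1.1f, 0.0f, 0.0f);
            }
        }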

        /// <summary>
        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="logits">Logits extracted from the original generation context.</param>
        /// <param name="logits_guidance">Logits extracted from a separate context from the same model.
        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
        /// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param>
        public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<float> logits_guidance, float scale)
        {
            if (logits == null)
                throw new ArgumentNullException(nameof(logits));
            if (logits_guidance == null)
                throw new ArgumentNullException(nameof(logits_guidance));
            if (logits.Length != ctx.VocabCount)
                throw new ArgumentException("Logits count must equal the context vocab size", nameof(logits));
            if (logits_guidance.Length != ctx.VocabCount)
                throw new ArgumentException("Guidance logits count must equal the context vocab size", nameof(logits_guidance));

            unsafe
            {
                fixed (float* logitsPtr = logits)
                fixed (float* logitsGuidancePtr = logits_guidance)
                {
                    llama_sample_apply_guidance(ctx, logitsPtr, logitsGuidancePtr, scale);
                }
            }
        }
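
        // Example sketch (illustrative helper): blending a guidance context
        // into the main context's logits via the managed wrapper above.
        // `mainLogits` and `guidanceLogits` are assumed to be the vocab-sized
        // logit arrays read back from the two contexts; the scale of 2.0f is
        // illustrative.
        private static void ExampleApplyGuidance(SafeLLamaContextHandle ctx, float[] mainLogits, float[] guidanceLogits)
        {
            llama_sample_apply_guidance(ctx, mainLogits, guidanceLogits, 2.0f);
        }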

        /// <summary>
        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="logits">Logits extracted from the original generation context.</param>
        /// <param name="logits_guidance">Logits extracted from a separate context from the same model.
        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
        /// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern unsafe void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, float* logits, float* logits_guidance, float scale);

        /// <summary>
        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);

        /// <summary>
        /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="k">Number of top tokens to keep</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, int k, ulong min_keep);

        /// <summary>
        /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p">Cumulative probability threshold</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);

        /// <summary>
        /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p">Minimum probability, relative to the probability of the most likely token</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_min_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);

        /// <summary>
        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="z">Tail-free cutoff parameter; 1.0 disables tail-free sampling</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float z, ulong min_keep);

        /// <summary>
        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p">Typicality threshold; 1.0 disables typical sampling</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);
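
        // Example sketch (illustrative helper): the "filter chain" the
        // truncation samplers above form. Each call prunes `candidates` in
        // place, so they compose by simple sequencing; the thresholds are
        // illustrative, and a value of 1.0 leaves that stage inactive.
        private static void ExampleFilterCandidates(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates)
        {
            llama_sample_top_k(ctx, ref candidates, 40, 1);       // keep the 40 most likely tokens
            llama_sample_tail_free(ctx, ref candidates, 1.0f, 1); // z = 1.0: tail-free stage disabled
            llama_sample_typical(ctx, ref candidates, 1.0f, 1);   // p = 1.0: typical stage disabled
            llama_sample_top_p(ctx, ref candidates, 0.95f, 1);    // nucleus: keep 95% of the probability mass
            llama_sample_min_p(ctx, ref candidates, 0.05f, 1);    // drop tokens below 5% of the top probability
        }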

        /// <summary>
        /// Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="min_temp">Lower bound of the effective temperature</param>
        /// <param name="max_temp">Upper bound of the effective temperature</param>
        /// <param name="exponent_val">Exponent shaping the mapping from entropy to temperature</param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_entropy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float min_temp, float max_temp, float exponent_val);
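
        // Example sketch (illustrative helper): dynamic temperature. The
        // effective temperature is scaled between min_temp and max_temp
        // according to the entropy of the candidate distribution; the values
        // below are illustrative.
        private static void ExampleDynamicTemperature(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates)
        {
            llama_sample_entropy(ctx, ref candidates, 0.5f, 1.5f, 1.0f);
        }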

        /// <summary>
        /// Modify logits by temperature
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates"></param>
        /// <param name="temp"></param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern void llama_sample_temp(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float temp);

        /// <summary>
        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern LLamaToken llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu);
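
        // Example sketch (illustrative helper): one Mirostat 1.0 step. `mu`
        // is caller-owned state: initialise it to 2 * tau once per generation
        // and pass the same variable to every call so the controller can keep
        // adjusting it. tau = 5.0, eta = 0.1 and m = 100 are the values
        // suggested in the paper.
        private static LLamaToken ExampleMirostatStep(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, ref float mu)
        {
            return llama_sample_token_mirostat(ctx, ref candidates, 5.0f, 0.1f, 100, ref mu);
        }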

        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern LLamaToken llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu);

        /// <summary>
        /// Selects the token with the highest probability.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);

        /// <summary>
        /// Randomly selects a token from the candidates based on their probabilities.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern LLamaToken llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);
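
        // Example sketch (illustrative helper): a complete sampling step
        // sequencing the calls above. `candidates` is assumed to be freshly
        // initialised from the current logits; the settings mirror common
        // defaults and are illustrative.
        private static LLamaToken ExampleSampleToken(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float temperature)
        {
            if (temperature <= 0)
            {
                // Temperature <= 0 is conventionally treated as greedy decoding.
                return llama_sample_token_greedy(ctx, ref candidates);
            }

            llama_sample_top_k(ctx, ref candidates, 40, 1);
            llama_sample_top_p(ctx, ref candidates, 0.95f, 1);
            llama_sample_temp(ctx, ref candidates, temperature);
            return llama_sample_token(ctx, ref candidates);
        }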
    }
}