You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

SamplingApi.cs 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. using System;
  2. #pragma warning disable IDE1006 // Naming Styles
  namespace LLama.Native
  {
      using llama_token = Int32;

      /// <summary>
      /// Thin managed wrappers over the llama.cpp sampling functions exposed by <c>NativeApi</c>.
      /// Method and parameter names intentionally follow the native C API (naming-style
      /// warnings are suppressed for this file).
      /// </summary>
      /// <remarks>
      /// Every wrapper obtains a native view of the candidates through
      /// <c>LLamaTokenDataArrayNative.Create</c>, keeps the returned handle alive for the
      /// duration of the native call, and disposes it as soon as the call returns.
      /// NOTE(review): the native view (the <c>out</c> value of <c>Create</c>) is passed by
      /// <c>ref</c> but never read back afterwards, so any change the native side makes to
      /// that view's own fields (for example a reduced size or a "sorted" flag) is not
      /// propagated to the managed <c>LLamaTokenDataArray</c> — confirm whether callers rely on it.
      /// </remarks>
      public unsafe class SamplingApi
      {
          /// <summary>
          /// Applies the rules of <paramref name="grammar"/> to the candidate tokens.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens the grammar is applied to.</param>
          /// <param name="grammar">Grammar handle forwarded to the native call.</param>
          public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  NativeApi.llama_sample_grammar(ctx, ref native, grammar);
              }
          }

          /// <summary>
          /// Legacy overload of the CTRL repetition penalty (https://arxiv.org/abs/1909.05858,
          /// with negative logit fix). Delegates to the overload without <paramref name="last_tokens_size"/>.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens the penalty is applied to.</param>
          /// <param name="last_tokens">Token history; its full length is used.</param>
          /// <param name="last_tokens_size">Ignored — the length of <paramref name="last_tokens"/> is used instead.</param>
          /// <param name="penalty">Penalty value forwarded to the native call.</param>
          [Obsolete("last_tokens_size parameter is no longer needed")]
          public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
              => llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);

          /// <summary>
          /// Repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858), with negative logit fix.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens the penalty is applied to.</param>
          /// <param name="last_tokens">Token history; the whole span is pinned and its <c>Length</c> is passed as the count.</param>
          /// <param name="penalty">Penalty value forwarded to the native call.</param>
          public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              using (var history = last_tokens.Pin())
              {
                  var historyCount = (ulong)last_tokens.Length;
                  NativeApi.llama_sample_repetition_penalty(ctx, ref native, (llama_token*)history.Pointer, historyCount, penalty);
              }
          }

          /// <summary>
          /// Legacy overload of the OpenAI-style frequency and presence penalties
          /// (https://platform.openai.com/docs/api-reference/parameter-details).
          /// Delegates to the overload without <paramref name="last_tokens_size"/>.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens the penalties are applied to.</param>
          /// <param name="last_tokens">Token history; its full length is used.</param>
          /// <param name="last_tokens_size">Ignored — the length of <paramref name="last_tokens"/> is used instead.</param>
          /// <param name="alpha_frequency">Frequency penalty coefficient forwarded to the native call.</param>
          /// <param name="alpha_presence">Presence penalty coefficient forwarded to the native call.</param>
          [Obsolete("last_tokens_size parameter is no longer needed")]
          public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
              => llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);

          /// <summary>
          /// Frequency and presence penalties as described by the OpenAI API
          /// (https://platform.openai.com/docs/api-reference/parameter-details).
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens the penalties are applied to.</param>
          /// <param name="last_tokens">Token history; the whole span is pinned and its <c>Length</c> is passed as the count.</param>
          /// <param name="alpha_frequency">Frequency penalty coefficient forwarded to the native call.</param>
          /// <param name="alpha_presence">Presence penalty coefficient forwarded to the native call.</param>
          public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              using (var history = last_tokens.Pin())
              {
                  var historyCount = (ulong)last_tokens.Length;
                  NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref native, (llama_token*)history.Pointer, historyCount, alpha_frequency, alpha_presence);
              }
          }

          /// <summary>
          /// Sorts the candidates by logit in descending order and computes their probabilities from the logits.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to sort and normalise.</param>
          public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  NativeApi.llama_sample_softmax(ctx, ref native);
              }
          }

          /// <summary>
          /// Top-K sampling, as referenced from "The Curious Case of Neural Text Degeneration" (https://arxiv.org/abs/1904.09751).
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to filter.</param>
          /// <param name="k">The K of top-K, forwarded to the native call.</param>
          /// <param name="min_keep">Forwarded as the native <c>min_keep</c> argument (presumably a lower bound on surviving candidates).</param>
          public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  NativeApi.llama_sample_top_k(ctx, ref native, k, min_keep);
              }
          }

          /// <summary>
          /// Nucleus (top-p) sampling from "The Curious Case of Neural Text Degeneration" (https://arxiv.org/abs/1904.09751).
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to filter.</param>
          /// <param name="p">The p of top-p, forwarded to the native call.</param>
          /// <param name="min_keep">Forwarded as the native <c>min_keep</c> argument (presumably a lower bound on surviving candidates).</param>
          public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  NativeApi.llama_sample_top_p(ctx, ref native, p, min_keep);
              }
          }

          /// <summary>
          /// Tail Free Sampling, described at https://www.trentonbricken.com/Tail-Free-Sampling/.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to filter.</param>
          /// <param name="z">Tail-free parameter z, forwarded to the native call.</param>
          /// <param name="min_keep">Forwarded as the native <c>min_keep</c> argument (presumably a lower bound on surviving candidates).</param>
          public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  NativeApi.llama_sample_tail_free(ctx, ref native, z, min_keep);
              }
          }

          /// <summary>
          /// Locally Typical Sampling, described in https://arxiv.org/abs/2202.00666.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to filter.</param>
          /// <param name="p">Typical-sampling parameter p, forwarded to the native call.</param>
          /// <param name="min_keep">Forwarded as the native <c>min_keep</c> argument (presumably a lower bound on surviving candidates).</param>
          public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  NativeApi.llama_sample_typical(ctx, ref native, p, min_keep);
              }
          }

          /// <summary>
          /// Temperature sampling. Higher temperatures make the prediction more diverse,
          /// but also more prone to hallucination (tokens that are plausible but not factual).
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to adjust.</param>
          /// <param name="temp">Temperature value forwarded to the native call.</param>
          public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  NativeApi.llama_sample_temperature(ctx, ref native, temp);
              }
          }

          /// <summary>
          /// Mirostat 1.0 (https://arxiv.org/abs/2007.14966), operating on tokens rather than words.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens with their probabilities (p) and log-odds (logit) for the current position.</param>
          /// <param name="tau">Target cross-entropy (surprise): higher values yield less predictable text, lower values more predictable text.</param>
          /// <param name="eta">Learning rate for updating <paramref name="mu"/> from the gap between target and observed surprise; larger values adapt faster.</param>
          /// <param name="m">Number of tokens used to estimate <c>s_hat</c>, which in turn determines <c>k</c>; the paper uses 100.</param>
          /// <param name="mu">Maximum cross-entropy, passed by reference so the native algorithm can update it; typically initialised to <c>2 * tau</c>.</param>
          /// <returns>The token selected by the native sampler.</returns>
          public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  return NativeApi.llama_sample_token_mirostat(ctx, ref native, tau, eta, m, ref mu);
              }
          }

          /// <summary>
          /// Mirostat 2.0 (https://arxiv.org/abs/2007.14966), operating on tokens rather than words.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens with their probabilities (p) and log-odds (logit) for the current position.</param>
          /// <param name="tau">Target cross-entropy (surprise): higher values yield less predictable text, lower values more predictable text.</param>
          /// <param name="eta">Learning rate for updating <paramref name="mu"/> from the gap between target and observed surprise; larger values adapt faster.</param>
          /// <param name="mu">Maximum cross-entropy, passed by reference so the native algorithm can update it; typically initialised to <c>2 * tau</c>.</param>
          /// <returns>The token selected by the native sampler.</returns>
          public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  return NativeApi.llama_sample_token_mirostat_v2(ctx, ref native, tau, eta, ref mu);
              }
          }

          /// <summary>
          /// Picks the candidate with the highest probability.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to choose from.</param>
          /// <returns>The token selected by the native greedy sampler.</returns>
          public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  return NativeApi.llama_sample_token_greedy(ctx, ref native);
              }
          }

          /// <summary>
          /// Draws a token at random, weighted by the candidates' probabilities.
          /// </summary>
          /// <param name="ctx">Context handle forwarded to the native call.</param>
          /// <param name="candidates">Candidate tokens to choose from.</param>
          /// <returns>The token selected by the native sampler.</returns>
          public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
          {
              using (LLamaTokenDataArrayNative.Create(candidates, out var native))
              {
                  return NativeApi.llama_sample_token(ctx, ref native);
              }
          }
      }
  }