
SamplingApi.cs

using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace LLama.Native
{
    using llama_token = Int32;

    internal unsafe class SamplingApi
    {
        /// <summary>
        /// Repetition penalty described in the CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens"></param>
        /// <param name="last_tokens_size"></param>
        /// <param name="penalty"></param>
        public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty)
        {
            // Pin the managed candidate buffer so the native code sees a stable pointer;
            // the `using` block unpins the memory once the native call returns.
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_repetition_penalty(ctx, new IntPtr(&st), last_tokens, last_tokens_size, penalty);
            }
        }

        /// <summary>
        /// Frequency and presence penalties described in the OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens"></param>
        /// <param name="last_tokens_size"></param>
        /// <param name="alpha_frequency"></param>
        /// <param name="alpha_presence"></param>
        public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_frequency_and_presence_penalties(ctx, new IntPtr(&st), last_tokens, last_tokens_size, alpha_frequency, alpha_presence);
            }
        }

        /// <summary>
        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_softmax(ctx, new IntPtr(&st));
            }
        }

        /// <summary>
        /// Top-K sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="k"></param>
        /// <param name="min_keep"></param>
        public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_top_k(ctx, new IntPtr(&st), k, min_keep);
            }
        }

        /// <summary>
        /// Nucleus sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p"></param>
        /// <param name="min_keep"></param>
        public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_top_p(ctx, new IntPtr(&st), p, min_keep);
            }
        }

        /// <summary>
        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="z"></param>
        /// <param name="min_keep"></param>
        public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_tail_free(ctx, new IntPtr(&st), z, min_keep);
            }
        }

        /// <summary>
        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p"></param>
        /// <param name="min_keep"></param>
        public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_typical(ctx, new IntPtr(&st), p, min_keep);
            }
        }

        /// <summary>
        /// Temperature sampling: scales each candidate logit by 1 / temp, so temp &lt; 1 sharpens the distribution and temp &gt; 1 flattens it.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="temp"></param>
        public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                NativeApi.llama_sample_temperature(ctx, new IntPtr(&st), temp);
            }
        }

        /// <summary>
        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `LLamaTokenData` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns>The sampled token id.</returns>
        public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                llama_token res;
                // `mu` is read and updated in place by the native sampler.
                fixed (float* pmu = &mu)
                {
                    res = NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, pmu);
                }
                return res;
            }
        }
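        // Illustrative usage (an assumption, not part of the original file): `mu` must
        // persist across sampling steps and is conventionally seeded with 2 * tau, e.g.
        //
        //     float mu = 2.0f * tau;   // once, before generation starts
        //     var token = SamplingApi.llama_sample_token_mirostat(ctx, candidates, tau, eta, 100, ref mu);
        //
        // where `ctx` is a valid context handle and `candidates` is rebuilt from the
        // current logits at every step.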
        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `LLamaTokenData` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns>The sampled token id.</returns>
        public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                llama_token res;
                fixed (float* pmu = &mu)
                {
                    res = NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, pmu);
                }
                return res;
            }
        }

        /// <summary>
        /// Selects the token with the highest probability.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns>The selected token id.</returns>
        public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                return NativeApi.llama_sample_token_greedy(ctx, new IntPtr(&st));
            }
        }

        /// <summary>
        /// Randomly selects a token from the candidates based on their probabilities.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns>The sampled token id.</returns>
        public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
            using (var handle = candidates.data.Pin())
            {
                var st = new LLamaTokenDataArrayNative();
                st.data = new IntPtr(handle.Pointer);
                st.size = candidates.size;
                st.sorted = candidates.sorted;
                return NativeApi.llama_sample_token(ctx, new IntPtr(&st));
            }
        }
    }
}
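For orientation, these primitives are typically chained into a pipeline: penalties first, then truncation filters (top-k, top-p, tail-free, typical), then temperature, then the final draw. The sketch below is a minimal illustration only; `SamplingPipelineExample`, the parameter values, and the prepared `ctx`, `candidates`, and `lastTokens` inputs are assumptions, not part of this file (SamplingApi is internal, so such a driver would have to live in the same assembly).

using LLama.Native;

internal static class SamplingPipelineExample
{
    // Hypothetical driver: `ctx` is a valid context handle and `candidates` is a
    // LLamaTokenDataArray freshly built from the current logits.
    internal static int SampleNext(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int[] lastTokens)
    {
        var historyLen = (ulong)lastTokens.Length;

        // 1. Penalize recently generated tokens to discourage repetition loops.
        SamplingApi.llama_sample_repetition_penalty(ctx, candidates, lastTokens, historyLen, 1.1f);

        // 2. Truncate the candidate set, then rescale the survivors.
        SamplingApi.llama_sample_top_k(ctx, candidates, 40, 1);
        SamplingApi.llama_sample_top_p(ctx, candidates, 0.95f, 1);
        SamplingApi.llama_sample_temperature(ctx, candidates, 0.8f);

        // 3. Draw one token from the remaining distribution.
        return SamplingApi.llama_sample_token(ctx, candidates);
    }
}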

An easy-to-use, high-performance LLM inference framework for C#/.NET, supporting the LLaMA and LLaVA series of models.