
LLamaTokenDataArray.cs

using System;
using System.Buffers;
using System.Runtime.InteropServices;

using llama_token = System.Int32;

namespace LLama.Native
{
    /// <summary>
    /// Contains an array of LLamaTokenData, potentially sorted.
    /// </summary>
    public struct LLamaTokenDataArray
    {
        /// <summary>
        /// The LLamaTokenData
        /// </summary>
        public readonly Memory<LLamaTokenData> data;

        /// <summary>
        /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
        /// </summary>
        public bool sorted;

        /// <summary>
        /// Create a new LLamaTokenDataArray
        /// </summary>
        /// <param name="tokens">The token data to wrap</param>
        /// <param name="isSorted">Whether the token data is already sorted by logits in descending order</param>
        public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
        {
            data = tokens;
            sorted = isSorted;
        }

        /// <summary>
        /// Create a new LLamaTokenDataArray, copying the data from the given logits
        /// </summary>
        /// <param name="logits">Raw logits, indexed by token id</param>
        /// <returns>A new, unsorted array with one entry per logit</returns>
        public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
        {
            var candidates = new LLamaTokenData[logits.Length];
            for (var token_id = 0; token_id < logits.Length; token_id++)
                candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);

            return new LLamaTokenDataArray(candidates);
        }
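
        // A minimal usage sketch. How the logits are obtained is an assumption for
        // illustration; any vocabulary-sized span of floats works:
        //
        //     ReadOnlySpan<float> logits = /* one float per token in the vocabulary */;
        //     var candidates = LLamaTokenDataArray.Create(logits);
        //     // candidates.data now holds one LLamaTokenData per token id, unsorted.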

        #region sampling
        /// <summary>
        /// Apply grammar rules to candidate tokens
        /// </summary>
        /// <param name="ctx">Context handle to sample with</param>
        /// <param name="grammar">Grammar to constrain the candidate tokens with</param>
        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_grammar(ctx, ref st, grammar);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="k">Number of tokens to keep</param>
        /// <param name="minKeep">Minimum number of tokens to keep</param>
        public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="p">Cumulative probability threshold; the smallest set of tokens whose probabilities sum to at least this is kept</param>
        /// <param name="minKeep">Minimum number of tokens to keep</param>
        public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
                sorted = st.sorted;
            }
        }
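
        // A sketch of chaining the filters above, mutating the candidate set in place.
        // The parameter values are illustrative, not recommendations; `ctx` is an assumed
        // SafeLLamaContextHandle:
        //
        //     candidates.TopK(ctx, 40);       // keep the 40 most likely tokens
        //     candidates.TopP(ctx, 0.95f);    // then apply nucleus filtering to the survivors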

        /// <summary>
        /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="p">All tokens with probability greater than this will be kept</param>
        /// <param name="minKeep">Minimum number of tokens to keep</param>
        public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="z">Tail free sampling parameter (1.0 disables the filter)</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="p">Typicality threshold (1.0 disables the filter)</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_typical(context, ref st, p, min_keep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
        /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="last_tokens">Recent tokens to penalize</param>
        /// <param name="penalty_repeat">Repetition penalty (1.0 = no penalty)</param>
        /// <param name="penalty_freq">Frequency penalty (0.0 = no penalty)</param>
        /// <param name="penalty_present">Presence penalty (0.0 = no penalty)</param>
        public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
        {
            unsafe
            {
                using (LLamaTokenDataArrayNative.Create(this, out var st))
                using (var last_tokens_handle = last_tokens.Pin())
                {
                    NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
                    sorted = st.sorted;
                }
            }
        }
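
        // A sketch of applying the penalties before any filtering. The recent-token window
        // (here the last 64 tokens of an assumed `history` collection) and the penalty
        // values are illustrative assumptions:
        //
        //     llama_token[] recent = history.TakeLast(64).ToArray();
        //     candidates.RepetitionPenalty(ctx, recent,
        //         penalty_repeat: 1.1f, penalty_freq: 0.0f, penalty_present: 0.0f);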

        /// <summary>
        /// Sample with temperature.
        /// As temperature increases the prediction becomes more diverse, but also more vulnerable to hallucinations (generating tokens that are sensible but not factual).
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="temp">Temperature value to apply to the logits</param>
        public void Temperature(SafeLLamaContextHandle context, float temp)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_temperature(context, ref st, temp);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        public void Softmax(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_softmax(context, ref st);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Randomly selects a token from the candidates based on their probabilities.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <returns>The selected token id</returns>
        public int SampleToken(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token(context, ref st);
                sorted = st.sorted;
                return token;
            }
        }

        /// <summary>
        /// Selects the token with the highest probability.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <returns>The selected token id</returns>
        public int SampleTokenGreedy(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_greedy(context, ref st);
                sorted = st.sorted;
                return token;
            }
        }
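
        // A sketch of a complete sampling pass, continuing from a LLamaTokenDataArray
        // produced by Create above. All parameter values are illustrative assumptions:
        //
        //     candidates.TopK(ctx, 40);
        //     candidates.TopP(ctx, 0.95f);
        //     candidates.Temperature(ctx, 0.8f);
        //     var token = candidates.SampleToken(ctx);   // or SampleTokenGreedy(ctx) for argmax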

        /// <summary>
        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns>The selected token id</returns>
        public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
                sorted = st.sorted;
                return token;
            }
        }

        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="context">Context handle to sample with</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns>The selected token id</returns>
        public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
                sorted = st.sorted;
                return token;
            }
        }
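
        // A sketch of mirostat sampling. `mu` must persist across calls and, per the doc
        // comment above, is initialized to twice the target cross-entropy; the tau and eta
        // values here are illustrative:
        //
        //     const float tau = 5.0f, eta = 0.1f;
        //     var mu = 2 * tau;                  // initialize once, before the first call
        //     var token = candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);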
        #endregion
    }

    /// <summary>
    /// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct LLamaTokenDataArrayNative
    {
        /// <summary>
        /// A pointer to an array of LLamaTokenData
        /// </summary>
        /// <remarks>Memory must be pinned in place for as long as this LLamaTokenDataArrayNative is in use</remarks>
        public IntPtr data;

        /// <summary>
        /// Number of LLamaTokenData in the array
        /// </summary>
        public ulong size;

        /// <summary>
        /// Indicates if the items in the array are sorted
        /// </summary>
        public bool sorted
        {
            get => Convert.ToBoolean(_sorted);
            set => _sorted = Convert.ToSByte(value);
        }
        private sbyte _sorted;

        /// <summary>
        /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
        /// </summary>
        /// <param name="array">Data source</param>
        /// <param name="native">Created native array</param>
        /// <returns>A memory handle, pinning the data in place until disposed</returns>
        public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native)
        {
            var handle = array.data.Pin();

            unsafe
            {
                native = new LLamaTokenDataArrayNative
                {
                    data = new IntPtr(handle.Pointer),
                    size = (ulong)array.data.Length,
                    sorted = array.sorted
                };
            }

            return handle;
        }
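
        // A usage sketch: the returned MemoryHandle keeps the managed buffer pinned, so it
        // must stay alive (e.g. in a using block) for the duration of the native call.
        // `tokenArray` and `ctx` are assumed to exist; this mirrors the pattern used by the
        // sampling methods above:
        //
        //     using (LLamaTokenDataArrayNative.Create(tokenArray, out var native))
        //     {
        //         NativeApi.llama_sample_softmax(ctx, ref native);
        //         tokenArray.sorted = native.sorted;
        //     }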
    }
}