
LLamaTokenDataArray.cs

using System;
using System.Buffers;
using System.Runtime.InteropServices;

namespace LLama.Native
{
    /// <summary>
    /// Contains an array of LLamaTokenData, potentially sorted.
    /// </summary>
    public struct LLamaTokenDataArray
    {
        /// <summary>
        /// The LLamaTokenData
        /// </summary>
        public readonly Memory<LLamaTokenData> data;

        /// <summary>
        /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
        /// </summary>
        public bool sorted;

        /// <summary>
        /// Create a new LLamaTokenDataArray
        /// </summary>
        /// <param name="tokens"></param>
        /// <param name="isSorted"></param>
        public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
        {
            data = tokens;
            sorted = isSorted;
        }

        /// <summary>
        /// Create a new LLamaTokenDataArray, copying the data from the given logits
        /// </summary>
        /// <param name="logits"></param>
        /// <returns></returns>
        public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
        {
            var candidates = new LLamaTokenData[logits.Length];
            for (var token_id = 0; token_id < logits.Length; token_id++)
                candidates[token_id] = new LLamaTokenData((LLamaToken)token_id, logits[token_id], 0.0f);
            return new LLamaTokenDataArray(candidates);
        }
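
        // Illustrative usage sketch, not part of the original file. Create() copies
        // the logits into a fresh buffer with token ids 0..n-1; the constructor
        // instead wraps an existing buffer without copying. `logits` is assumed to
        // hold one value per vocabulary slot.
        private static LLamaTokenDataArray ExampleCreateCandidates(ReadOnlySpan<float> logits)
        {
            // probabilities start at 0 and are only filled in by Softmax later
            return Create(logits);
        }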

        /// <summary>
        /// Overwrite the logit values for all given tokens
        /// </summary>
        /// <param name="values">tuples of token and logit value to overwrite</param>
        public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values)
        {
            if (values.Length == 0)
                return;

            var dataSpan = data.Span;
            foreach (var (token, value) in values)
            {
                for (var i = 0; i < data.Length; i++)
                {
                    if (dataSpan[i].id == token)
                    {
                        dataSpan[i].logit = value;
                        break;
                    }
                }
            }
            sorted = false;
        }
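
        // Illustrative usage sketch, not part of the original file: ban a token by
        // forcing its logit to negative infinity so it can never be sampled. Taking
        // the array by ref keeps the `sorted = false` update visible to the caller.
        private static void ExampleBanToken(ref LLamaTokenDataArray candidates, LLamaToken token)
        {
            candidates.OverwriteLogits(stackalloc[] { (token, float.NegativeInfinity) });
        }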

        #region sampling
        /// <summary>
        /// Apply grammar rules to candidate tokens
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="grammar"></param>
        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
        {
            if (grammar == null)
                return;

            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_grammar(ctx, ref st, grammar);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="context"></param>
        /// <param name="k">Number of tokens to keep</param>
        /// <param name="minKeep">Minimum number to keep</param>
        public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="context"></param>
        /// <param name="p"></param>
        /// <param name="minKeep"></param>
        public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        /// </summary>
        /// <param name="context"></param>
        /// <param name="p">All tokens with probability greater than this will be kept</param>
        /// <param name="minKeep"></param>
        public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="z"></param>
        /// <param name="min_keep"></param>
        public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="p"></param>
        /// <param name="min_keep"></param>
        public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_typical(context, ref st, p, min_keep);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
        /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="last_tokens"></param>
        /// <param name="penalty_repeat"></param>
        /// <param name="penalty_freq"></param>
        /// <param name="penalty_present"></param>
        public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLamaToken> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
        {
            unsafe
            {
                using (LLamaTokenDataArrayNative.Create(this, out var st))
                {
                    fixed (LLamaToken* last_tokens_handle = last_tokens)
                    {
                        NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
                        sorted = st.sorted;
                    }
                }
            }
        }
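
        // Illustrative usage sketch, not part of the original file: damp repetition
        // by penalizing tokens seen in recent history. The values shown mirror common
        // llama.cpp defaults (repeat 1.1, frequency/presence 0) and are assumptions,
        // not recommendations from this library.
        private static void ExamplePenalizeHistory(SafeLLamaContextHandle context, ref LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.RepetitionPenalty(context, lastTokens, 1.1f, 0.0f, 0.0f);
        }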

        /// <summary>
        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
        /// </summary>
        /// <param name="context"></param>
        /// <param name="guidanceLogits">Logits extracted from a separate context from the same model.
        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
        /// <param name="guidance">Guidance strength. 0 means no guidance, higher values apply stronger guidance</param>
        public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanceLogits, float guidance)
        {
            if (guidanceLogits.Length != data.Length)
                throw new ArgumentException("Guidance logits count must equal vocabulary size", nameof(guidanceLogits));

            if (guidance < 0)
                throw new ArgumentOutOfRangeException(nameof(guidance), "Guidance strength must be greater than or equal to zero");

            // This method accepts 0 (no guidance), higher means more. llama.cpp expects 1 (no guidance), higher means more.
            // Add one to move up to the llama.cpp baseline.
            guidance += 1;

            // We need a raw logits array, which we don't have at this point.
            // Copy the logits to a temporary array, apply guidance, then copy them back.
            var logits = ArrayPool<float>.Shared.Rent(context.VocabCount);
            try
            {
                var dataSpan = data.Span;

                // Copy logits into the temporary array
                for (var i = 0; i < data.Length; i++)
                {
                    ref var item = ref dataSpan[i];
                    logits[(int)item.id] = item.logit;
                }

                // Apply guidance
                NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance);

                // Copy logits back into the data array
                for (var i = 0; i < data.Length; i++)
                {
                    ref var item = ref dataSpan[i];
                    item.logit = logits[(int)item.id];
                }

                // No longer sorted since we just mutated the logits!
                sorted = false;
            }
            finally
            {
                ArrayPool<float>.Shared.Return(logits);
            }
        }
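
        // Illustrative usage sketch, not part of the original file: `negativeLogits`
        // is assumed to come from a second context primed with a negative prompt and
        // otherwise fed the same tokens as the main context. Strength 1 is a mild
        // starting point chosen purely for illustration.
        private static void ExampleGuidance(SafeLLamaContextHandle context, ref LLamaTokenDataArray candidates, ReadOnlySpan<float> negativeLogits)
        {
            candidates.Guidance(context, negativeLogits, 1.0f);
        }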

        /// <summary>
        /// Sample with temperature.
        /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
        /// </summary>
        /// <param name="context"></param>
        /// <param name="temp"></param>
        public void Temperature(SafeLLamaContextHandle context, float temp)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_temp(context, ref st, temp);
                sorted = st.sorted;
            }
        }

        /// <summary>
        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
        /// </summary>
        /// <param name="context"></param>
        public void Softmax(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                NativeApi.llama_sample_softmax(context, ref st);
                sorted = st.sorted;
            }
        }
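
        // Illustrative usage sketch, not part of the original file: after Softmax the
        // entries are sorted descending and (assuming LLamaTokenData exposes its
        // probability as `p`, the third constructor argument above) the most likely
        // token sits at index 0.
        private static float ExampleTopProbability(SafeLLamaContextHandle context, ref LLamaTokenDataArray candidates)
        {
            candidates.Softmax(context);
            return candidates.data.Span[0].p;
        }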

        /// <summary>
        /// Randomly selects a token from the candidates based on their probabilities.
        /// </summary>
        /// <param name="context"></param>
        /// <returns></returns>
        public LLamaToken SampleToken(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token(context, ref st);
                sorted = st.sorted;
                return token;
            }
        }
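
        // Illustrative usage sketch, not part of the original file: a typical
        // llama.cpp-style sampling chain built from the methods above. The filter
        // order and the values for k, p and temperature are common defaults chosen
        // only for illustration.
        private static LLamaToken ExampleSamplingPipeline(SafeLLamaContextHandle context, ReadOnlySpan<float> logits)
        {
            var candidates = Create(logits);
            candidates.TopK(context, 40);
            candidates.TopP(context, 0.95f);
            candidates.Temperature(context, 0.8f);
            return candidates.SampleToken(context);
        }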

        /// <summary>
        /// Selects the token with the highest probability.
        /// </summary>
        /// <param name="context"></param>
        /// <returns></returns>
        public LLamaToken SampleTokenGreedy(SafeLLamaContextHandle context)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_greedy(context, ref st);
                sorted = st.sorted;
                return token;
            }
        }

        /// <summary>
        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        public LLamaToken SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
                sorted = st.sorted;
                return token;
            }
        }
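
        // Illustrative usage sketch, not part of the original file: Mirostat keeps a
        // running `mu` across sampling calls, so the caller owns it and (per the
        // paper) initializes it to 2 * tau before the first token. tau = 5, eta = 0.1
        // and m = 100 are the paper's values, used here only as an example.
        private static LLamaToken ExampleMirostat(SafeLLamaContextHandle context, ref LLamaTokenDataArray candidates, ref float mu)
        {
            const float tau = 5.0f;  // target surprise
            const float eta = 0.1f;  // learning rate
            return candidates.SampleTokenMirostat(context, tau, eta, 100, ref mu);
        }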

        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="context"></param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        public LLamaToken SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
        {
            using (LLamaTokenDataArrayNative.Create(this, out var st))
            {
                var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
                sorted = st.sorted;
                return token;
            }
        }
        #endregion
    }

    /// <summary>
    /// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct LLamaTokenDataArrayNative
    {
        /// <summary>
        /// A pointer to an array of LlamaTokenData
        /// </summary>
        /// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use</remarks>
        public IntPtr data;

        /// <summary>
        /// Number of LLamaTokenData in the array
        /// </summary>
        public ulong size;

        /// <summary>
        /// Indicates if the items in the array are sorted
        /// </summary>
        public bool sorted
        {
            get => Convert.ToBoolean(_sorted);
            set => _sorted = Convert.ToSByte(value);
        }
        private sbyte _sorted;

        /// <summary>
        /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
        /// </summary>
        /// <param name="array">Data source</param>
        /// <param name="native">Created native array</param>
        /// <returns>A memory handle, pinning the data in place until disposed</returns>
        public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native)
        {
            var handle = array.data.Pin();

            unsafe
            {
                native = new LLamaTokenDataArrayNative
                {
                    data = new IntPtr(handle.Pointer),
                    size = (ulong)array.data.Length,
                    sorted = array.sorted
                };
            }

            return handle;
        }
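
        // Illustrative usage sketch, not part of the original file: the MemoryHandle
        // returned by Create pins the managed buffer, so it must stay alive for the
        // whole native call and be disposed afterwards to unpin the memory. This is
        // the same pattern the sampling methods above use.
        private static void ExamplePinnedCall(SafeLLamaContextHandle context, LLamaTokenDataArray array)
        {
            using (Create(array, out var native))
            {
                NativeApi.llama_sample_softmax(context, ref native);
            }
        }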
    }
}