
SamplingApi.cs

using System;

namespace LLama.Native
{
    using llama_token = Int32;

    public unsafe class SamplingApi
    {
        /// <summary>
        /// Apply grammar rules to candidate tokens
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="grammar">The grammar to apply to the candidates</param>
        public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_grammar(ctx, ref st, grammar);
        }
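
        // Usage sketch (illustrative, not part of the original file): constraining
        // sampling with a grammar. Assumes `ctx` is a live context, `candidates` was
        // built from the current logits, and `grammar` is a SafeLLamaGrammarHandle
        // obtained elsewhere; grammar construction is outside this file.
        //
        //   SamplingApi.llama_sample_grammar(ctx, candidates, grammar);
        //   var token = SamplingApi.llama_sample_token(ctx, candidates);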

        /// <summary>
        /// Repetition penalty described in the CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens">The most recent tokens, which the penalty is applied to</param>
        /// <param name="last_tokens_size">Number of entries in last_tokens</param>
        /// <param name="penalty">Repetition penalty factor; 1.0 means no penalty</param>
        public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty);
        }

        /// <summary>
        /// Frequency and presence penalties described in the OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens">The most recent tokens, which the penalties are applied to</param>
        /// <param name="last_tokens_size">Number of entries in last_tokens</param>
        /// <param name="alpha_frequency">Penalty scaled by how often a token already occurs in last_tokens</param>
        /// <param name="alpha_presence">Flat penalty for any token that already occurs in last_tokens</param>
        public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence);
        }
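
        // Usage sketch (illustrative, not part of the original file): both penalty
        // functions are typically applied to the same window of recent tokens before
        // any truncation sampler runs. `lastTokens` is assumed to hold recent ids.
        //
        //   SamplingApi.llama_sample_repetition_penalty(
        //       ctx, candidates, lastTokens, (ulong)lastTokens.Length, penalty: 1.1f);
        //   SamplingApi.llama_sample_frequency_and_presence_penalties(
        //       ctx, candidates, lastTokens, (ulong)lastTokens.Length,
        //       alpha_frequency: 0.1f, alpha_presence: 0.1f);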

        /// <summary>
        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_softmax(ctx, ref st);
        }

        /// <summary>
        /// Top-K sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="k">Number of highest-probability tokens to keep</param>
        /// <param name="min_keep">Minimum number of tokens to keep, regardless of k</param>
        public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
        }

        /// <summary>
        /// Nucleus sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p">Cumulative probability cutoff</param>
        /// <param name="min_keep">Minimum number of tokens to keep, regardless of p</param>
        public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
        }
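
        // Usage sketch (illustrative, not part of the original file): top-k and top-p
        // are usually chained, with temperature applied last before drawing a token.
        // min_keep = 1 guarantees at least one candidate survives each filter.
        //
        //   SamplingApi.llama_sample_top_k(ctx, candidates, k: 40, min_keep: 1);
        //   SamplingApi.llama_sample_top_p(ctx, candidates, p: 0.95f, min_keep: 1);
        //   SamplingApi.llama_sample_temperature(ctx, candidates, temp: 0.8f);
        //   var token = SamplingApi.llama_sample_token(ctx, candidates);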

        /// <summary>
        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="z">Tail-free z parameter; 1.0 disables the filter</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
        }

        /// <summary>
        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p">Typicality cutoff; 1.0 disables the filter</param>
        /// <param name="min_keep">Minimum number of tokens to keep</param>
        public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
        }

        /// <summary>
        /// Apply temperature to the candidate logits, dividing them by temp; values below 1.0 sharpen the distribution, values above 1.0 flatten it.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="temp">Sampling temperature</param>
        public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            NativeApi.llama_sample_temperature(ctx, ref st, temp);
        }
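
        // Note (illustrative): tail-free and locally typical sampling are additional
        // truncation filters; llama.cpp's example sampling chain of this era applied
        // top-k, tail-free, typical, and top-p in sequence before temperature.
        // z = 1.0 and p = 1.0 leave the distribution untouched.
        //
        //   SamplingApi.llama_sample_tail_free(ctx, candidates, z: 1.0f, min_keep: 1);
        //   SamplingApi.llama_sample_typical(ctx, candidates, p: 1.0f, min_keep: 1);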

        /// <summary>
        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `LLamaTokenData` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            fixed (float* pmu = &mu)
            {
                return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu);
            }
        }

        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `LLamaTokenData` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            fixed (float* pmu = &mu)
            {
                return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu);
            }
        }
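
        // Usage sketch (illustrative, not part of the original file): mirostat keeps
        // `mu` as state across sampling steps. Per the parameter docs above it starts
        // at 2 * tau, and each call updates it in place through the ref parameter.
        //
        //   float tau = 5.0f, eta = 0.1f;
        //   float mu = 2.0f * tau;                      // initial maximum cross-entropy
        //   var token = SamplingApi.llama_sample_token_mirostat_v2(
        //       ctx, candidates, tau, eta, ref mu);     // mu carries over to the next step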

        /// <summary>
        /// Selects the token with the highest probability.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
        public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            return NativeApi.llama_sample_token_greedy(ctx, ref st);
        }
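
        // Note (illustrative): greedy selection is the deterministic counterpart of
        // llama_sample_token below; prefer it when reproducible output matters more
        // than diversity.
        //
        //   var token = SamplingApi.llama_sample_token_greedy(ctx, candidates);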

        /// <summary>
        /// Randomly selects a token from the candidates based on their probabilities.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
        public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
            using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
            return NativeApi.llama_sample_token(ctx, ref st);
        }
    }
}
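
// ---------------------------------------------------------------------------
// End-to-end sketch (illustrative, not part of the original file): how the
// SamplingApi calls above compose into one sampling step. Building the
// LLamaTokenDataArray from the context's logits happens elsewhere in
// LLamaSharp and is assumed to have been done by the caller.
// ---------------------------------------------------------------------------
namespace LLama.Examples
{
    using LLama.Native;
    using llama_token = System.Int32;

    public static class SamplingPipelineSketch
    {
        /// <summary>
        /// One sampling step: repetition penalty, truncation, temperature, token draw.
        /// </summary>
        public static llama_token SampleNext(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] lastTokens)
        {
            // Discourage verbatim repetition of the recent token window.
            SamplingApi.llama_sample_repetition_penalty(ctx, candidates, lastTokens, (ulong)lastTokens.Length, 1.1f);

            // Truncate the distribution, then rescale what remains.
            SamplingApi.llama_sample_top_k(ctx, candidates, 40, 1);
            SamplingApi.llama_sample_top_p(ctx, candidates, 0.95f, 1);
            SamplingApi.llama_sample_temperature(ctx, candidates, 0.8f);

            // Draw a token from the filtered distribution.
            return SamplingApi.llama_sample_token(ctx, candidates);
        }
    }
}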