You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SamplingApi.cs 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. using System;
  2. #pragma warning disable IDE1006 // Naming Styles
  3. namespace LLama.Native
  4. {
  5. using llama_token = Int32;
  /// <summary>
  /// Legacy static entry points named after the functions of the llama.cpp sampling API.
  /// Every member is obsolete and simply forwards its arguments to the instance method
  /// of <see cref="LLamaTokenDataArray"/> that is named in its <c>Obsolete</c> message.
  /// </summary>
  public class SamplingApi
  {
      /// <summary>
      /// Applies grammar rules to the candidate tokens.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.ApplyGrammar</c>.</param>
      /// <param name="candidates">The candidate tokens the grammar is applied to.</param>
      /// <param name="grammar">Handle of the grammar whose rules are applied.</param>
      [Obsolete("use LLamaTokenDataArray ApplyGrammar method")]
      public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
          => candidates.ApplyGrammar(ctx, grammar);

      /// <summary>
      /// Sorts the candidate tokens by logit in descending order and derives their probabilities from the logits.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.Softmax</c>.</param>
      /// <param name="candidates">The candidate tokens to sort and normalise.</param>
      [Obsolete("use LLamaTokenDataArray Softmax method")]
      public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
          => candidates.Softmax(ctx);

      /// <summary>
      /// Top-K sampling, as discussed in the paper "The Curious Case of Neural Text Degeneration"
      /// (https://arxiv.org/abs/1904.09751).
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.TopK</c>.</param>
      /// <param name="candidates">The candidate tokens to filter.</param>
      /// <param name="k">The K of Top-K sampling, forwarded unchanged.</param>
      /// <param name="min_keep">Forwarded unchanged; presumably the minimum number of candidates to keep — confirm against <c>TopK</c>.</param>
      [Obsolete("use LLamaTokenDataArray TopK method")]
      public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
          => candidates.TopK(ctx, k, min_keep);

      /// <summary>
      /// Nucleus (Top-P) sampling, as described in the paper "The Curious Case of Neural Text Degeneration"
      /// (https://arxiv.org/abs/1904.09751).
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.TopP</c>.</param>
      /// <param name="candidates">The candidate tokens to filter.</param>
      /// <param name="p">The nucleus threshold P, forwarded unchanged.</param>
      /// <param name="min_keep">Forwarded unchanged; presumably the minimum number of candidates to keep — confirm against <c>TopP</c>.</param>
      [Obsolete("use LLamaTokenDataArray TopP method")]
      public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
          => candidates.TopP(ctx, p, min_keep);

      /// <summary>
      /// Tail Free Sampling, as described at https://www.trentonbricken.com/Tail-Free-Sampling/.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.TailFree</c>.</param>
      /// <param name="candidates">The candidate tokens to filter.</param>
      /// <param name="z">The tail-free parameter z, forwarded unchanged.</param>
      /// <param name="min_keep">Forwarded unchanged; presumably the minimum number of candidates to keep — confirm against <c>TailFree</c>.</param>
      [Obsolete("use LLamaTokenDataArray TailFree method")]
      public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
          => candidates.TailFree(ctx, z, min_keep);

      /// <summary>
      /// Locally Typical Sampling, as described in the paper https://arxiv.org/abs/2202.00666.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.LocallyTypical</c>.</param>
      /// <param name="candidates">The candidate tokens to filter.</param>
      /// <param name="p">The typical-sampling parameter p, forwarded unchanged.</param>
      /// <param name="min_keep">Forwarded unchanged; presumably the minimum number of candidates to keep — confirm against <c>LocallyTypical</c>.</param>
      [Obsolete("use LLamaTokenDataArray LocallyTypical method")]
      public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
          => candidates.LocallyTypical(ctx, p, min_keep);

      /// <summary>
      /// Applies temperature to the candidates. Higher temperatures make predictions more diverse,
      /// but also more prone to hallucination: tokens that are sensible but not factual.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.Temperature</c>.</param>
      /// <param name="candidates">The candidate tokens to rescale.</param>
      /// <param name="temp">The temperature value, forwarded unchanged.</param>
      [Obsolete("use LLamaTokenDataArray Temperature() method")]
      public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
          => candidates.Temperature(ctx, temp);

      /// <summary>
      /// Mirostat 1.0, as described in the paper https://arxiv.org/abs/2007.14966, operating on tokens rather than words.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.SampleTokenMirostat</c>.</param>
      /// <param name="candidates">The candidate tokens with their probabilities (p) and log-odds (logit) for the current position.</param>
      /// <param name="tau">Target cross-entropy (surprise). Higher values give more surprising text; lower values give more predictable text.</param>
      /// <param name="eta">Learning rate for updating <paramref name="mu"/> from the error between target and observed surprise. Larger values adapt faster.</param>
      /// <param name="m">How many tokens are used to estimate <c>s_hat</c>, which in turn determines <c>k</c>. The paper uses <c>m = 100</c>.</param>
      /// <param name="mu">Maximum cross-entropy, passed by reference so the algorithm can update it. Conventionally initialised to <c>2 * tau</c>.</param>
      /// <returns>The token chosen by <c>SampleTokenMirostat</c>.</returns>
      [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
      public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
          => candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);

      /// <summary>
      /// Mirostat 2.0, as described in the paper https://arxiv.org/abs/2007.14966, operating on tokens rather than words.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.SampleTokenMirostat2</c>.</param>
      /// <param name="candidates">The candidate tokens with their probabilities (p) and log-odds (logit) for the current position.</param>
      /// <param name="tau">Target cross-entropy (surprise). Higher values give more surprising text; lower values give more predictable text.</param>
      /// <param name="eta">Learning rate for updating <paramref name="mu"/> from the error between target and observed surprise. Larger values adapt faster.</param>
      /// <param name="mu">Maximum cross-entropy, passed by reference so the algorithm can update it. Conventionally initialised to <c>2 * tau</c>.</param>
      /// <returns>The token chosen by <c>SampleTokenMirostat2</c>.</returns>
      [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
      public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
          => candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);

      /// <summary>
      /// Picks the candidate token with the highest probability.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.SampleTokenGreedy</c>.</param>
      /// <param name="candidates">The candidate tokens to choose from.</param>
      /// <returns>The token chosen by <c>SampleTokenGreedy</c>.</returns>
      [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
      public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
          => candidates.SampleTokenGreedy(ctx);

      /// <summary>
      /// Picks a candidate token at random, weighted by the candidates' probabilities.
      /// </summary>
      /// <param name="ctx">Context handle, passed through to <c>LLamaTokenDataArray.SampleToken</c>.</param>
      /// <param name="candidates">The candidate tokens to choose from.</param>
      /// <returns>The token chosen by <c>SampleToken</c>.</returns>
      [Obsolete("use LLamaTokenDataArray SampleToken() method")]
      public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
          => candidates.SampleToken(ctx);
  }
  144. }