
NativeApi.Sampling.cs 8.3 kB

using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace LLama.Native
{
    using llama_token = Int32;

    internal unsafe partial class NativeApi
    {
        /// <summary>
        /// Repetition penalty described in the CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens"></param>
        /// <param name="last_tokens_size"></param>
        /// <param name="penalty"></param>
        [DllImport(libraryName)]
        public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty);

        /// <summary>
        /// Frequency and presence penalties described in the OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="last_tokens"></param>
        /// <param name="last_tokens_size"></param>
        /// <param name="alpha_frequency"></param>
        /// <param name="alpha_presence"></param>
        [DllImport(libraryName)]
        public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);

        /// <summary>
        /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        [DllImport(libraryName)]
        public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, IntPtr candidates);

        /// <summary>
        /// Top-K sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="k"></param>
        /// <param name="min_keep"></param>
        [DllImport(libraryName)]
        public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, IntPtr candidates, int k, ulong min_keep);

        /// <summary>
        /// Nucleus sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p"></param>
        /// <param name="min_keep"></param>
        [DllImport(libraryName)]
        public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep);

        /// <summary>
        /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="z"></param>
        /// <param name="min_keep"></param>
        [DllImport(libraryName)]
        public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, IntPtr candidates, float z, ulong min_keep);

        /// <summary>
        /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="p"></param>
        /// <param name="min_keep"></param>
        [DllImport(libraryName)]
        public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep);

        /// <summary>
        /// Applies temperature to the candidates' logits.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <param name="temp"></param>
        [DllImport(libraryName)]
        public static extern void llama_sample_temperature(SafeLLamaContextHandle ctx, IntPtr candidates, float temp);

        /// <summary>
        /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        [DllImport(libraryName)]
        public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float* mu);

        /// <summary>
        /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.</param>
        /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
        /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
        /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
        /// <returns></returns>
        [DllImport(libraryName)]
        public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float* mu);

        /// <summary>
        /// Selects the token with the highest probability.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
        [DllImport(libraryName)]
        public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, IntPtr candidates);

        /// <summary>
        /// Randomly selects a token from the candidates based on their probabilities.
        /// </summary>
        /// <param name="ctx"></param>
        /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
        /// <returns></returns>
        [DllImport(libraryName)]
        public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, IntPtr candidates);
    }
}
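
Below is a minimal usage sketch, not part of NativeApi.Sampling.cs, showing how the declarations above might be chained into a conventional sampling pass: penalties first, then top-k/top-p truncation, then temperature, then a draw. It assumes the helper is compiled into the same assembly (NativeApi is internal) and that the caller has already filled a native LLamaTokenDataArray with the current logits and holds a pinned IntPtr to it. The class name SamplingSketch, the namespace, and the default parameter values are illustrative only.

using System;
using LLama.Native;

using llama_token = System.Int32;

namespace LLama.Examples // hypothetical location; must live in the LLamaSharp assembly to see the internal NativeApi
{
    internal static class SamplingSketch
    {
        /// <summary>Runs one conventional sampling pass over an already-prepared candidate array.</summary>
        public static llama_token SampleNext(
            SafeLLamaContextHandle ctx,
            IntPtr candidates,          // pinned native LLamaTokenDataArray, prepared by the caller (assumed)
            llama_token[] lastTokens,   // recently generated tokens, consumed by the penalty functions
            float repeatPenalty = 1.1f,
            float alphaFrequency = 0.0f,
            float alphaPresence = 0.0f,
            int topK = 40,
            float topP = 0.95f,
            float temperature = 0.8f)
        {
            var lastCount = (ulong)lastTokens.Length;

            // 1. Penalize recently generated tokens before the distribution is truncated.
            NativeApi.llama_sample_repetition_penalty(ctx, candidates, lastTokens, lastCount, repeatPenalty);
            NativeApi.llama_sample_frequency_and_presence_penalties(ctx, candidates, lastTokens, lastCount, alphaFrequency, alphaPresence);

            // 2. Truncate the candidate set (keeping at least one token), then apply temperature.
            NativeApi.llama_sample_top_k(ctx, candidates, topK, 1);
            NativeApi.llama_sample_top_p(ctx, candidates, topP, 1);
            NativeApi.llama_sample_temperature(ctx, candidates, temperature);

            // 3. Draw a token from the remaining candidates according to their probabilities.
            return NativeApi.llama_sample_token(ctx, candidates);
        }
    }
}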
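
A second sketch covers the Mirostat entry points. Per the documentation above, `mu` must persist across sampling calls and should start at `2 * tau`; the hypothetical wrapper below only shows how a managed `ref float` can be pinned and passed to the native `float*` parameter.

using System;
using LLama.Native;

using llama_token = System.Int32;

namespace LLama.Examples // hypothetical location; must live in the LLamaSharp assembly to see the internal NativeApi
{
    internal static class MirostatSketch
    {
        /// <summary>Samples one token with Mirostat 2.0, updating the caller-owned `mu` state in place.</summary>
        public static unsafe llama_token SampleMirostatV2(
            SafeLLamaContextHandle ctx,
            IntPtr candidates,   // pinned native LLamaTokenDataArray, prepared by the caller (assumed)
            ref float mu,        // persistent state; initialize to 2 * tau before the first call
            float tau = 5.0f,
            float eta = 0.1f)
        {
            // Pin the managed float so its address can be handed to the native float* parameter.
            fixed (float* pMu = &mu)
            {
                return NativeApi.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, pMu);
            }
        }
    }
}

A caller would declare `float mu = 2 * tau;` once per generation and pass it by ref on every step, letting the native code adjust it toward the target surprise.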

An easy-to-use, high-performance LLM inference framework for C#/.NET, supporting the LLaMA and LLaVA model families.
