
LLamaModel.cs

using LLama.Exceptions;
using LLama.Native;
using LLama.OldVersion;
using LLama.Extensions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.IO;
using LLama.Common;
namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// The abstraction of a LLama model, which holds the context in the native library.
    /// </summary>
    public class LLamaModel : IDisposable
    {
        // TODO: expose more properties.
        private readonly ILLamaLogger? _logger;
        private readonly Encoding _encoding;
        private readonly SafeLLamaContextHandle _ctx;
        /// <summary>
        /// The context size.
        /// </summary>
        public int ContextSize { get; }

        /// <summary>
        /// The model params set for this model.
        /// </summary>
        public ModelParams Params { get; set; }

        /// <summary>
        /// The native handle, which is passed to the native APIs. Avoid using it
        /// directly unless you know exactly how the native API is used.
        /// </summary>
        public SafeLLamaContextHandle NativeHandle => _ctx;

        /// <summary>
        /// The encoding set for this model to deal with text input.
        /// </summary>
        public Encoding Encoding => _encoding;
        /// <summary>
        /// Initialize a LLamaModel from the given model parameters.
        /// </summary>
        /// <param name="Params">Model params.</param>
        /// <param name="encoding">Encoding to deal with text input.</param>
        /// <param name="logger">The logger.</param>
        public LLamaModel(ModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
        {
            _logger = logger;
            this.Params = Params;
            _encoding = Encoding.GetEncoding(encoding);
            _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
            _ctx = Utils.InitLLamaContextFromModelParams(this.Params);
            ContextSize = NativeApi.llama_n_ctx(_ctx);
        }
        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The text to tokenize.</param>
        /// <param name="addBos">Whether to prepend a BOS (beginning-of-sequence) token.</param>
        /// <returns>The tokens of the text.</returns>
        public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
        {
            // TODO: reconsider whether to convert to array here.
            return Utils.Tokenize(_ctx, text, addBos, _encoding);
        }

        /// <summary>
        /// Detokenize the tokens to text.
        /// </summary>
        /// <param name="tokens">The tokens to convert back to text.</param>
        /// <returns>The detokenized text.</returns>
        public string DeTokenize(IEnumerable<llama_token> tokens)
        {
            StringBuilder sb = new();
            foreach (var token in tokens)
            {
                sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding));
            }
            return sb.ToString();
        }
        /// <summary>
        /// Save the state to the specified path.
        /// </summary>
        /// <param name="filename">The file to write the state to.</param>
        public void SaveState(string filename)
        {
            File.WriteAllBytes(filename, GetStateData());
        }

        /// <summary>
        /// Get the state data as a byte array.
        /// </summary>
        /// <returns>The native state (KV cache, RNG, logits) as bytes.</returns>
        public byte[] GetStateData()
        {
            var stateSize = NativeApi.llama_get_state_size(_ctx);
            byte[] stateMemory = new byte[stateSize];
            NativeApi.llama_copy_state_data(_ctx, stateMemory);
            return stateMemory;
        }

        /// <summary>
        /// Load the state from the specified path.
        /// </summary>
        /// <param name="filename">The file to read the state from.</param>
        /// <exception cref="RuntimeError"></exception>
        public void LoadState(string filename)
        {
            var stateMemory = File.ReadAllBytes(filename);
            LoadState(stateMemory);
        }

        /// <summary>
        /// Load the state from memory.
        /// </summary>
        /// <param name="stateData">State data produced by <see cref="GetStateData"/>.</param>
        /// <exception cref="RuntimeError">Thrown when the state size does not match the current context.</exception>
        public void LoadState(byte[] stateData)
        {
            int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
            if (stateData.Length != stateSize)
            {
                throw new RuntimeError("Failed to validate state size.");
            }
            NativeApi.llama_set_state_data(_ctx, stateData);
        }
        /// <summary>
        /// Perform sampling on the candidate tokens. Please don't use it unless you fully know what it does.
        /// </summary>
        /// <param name="candidates">The candidate tokens, e.g. produced by <see cref="ApplyPenalty"/>.</param>
        /// <param name="temperature">Sampling temperature; values &lt;= 0 pick the most likely token greedily.</param>
        /// <param name="mirostat">Which mirostat sampling mode to use, if any.</param>
        /// <param name="mirostatTau">Mirostat target entropy.</param>
        /// <param name="mirostatEta">Mirostat learning rate.</param>
        /// <param name="topK">Keep only the top-k most likely tokens.</param>
        /// <param name="topP">Keep the smallest set of tokens whose cumulative probability exceeds p.</param>
        /// <param name="tfsZ">Tail-free sampling parameter z.</param>
        /// <param name="typicalP">Locally typical sampling parameter p.</param>
        /// <returns>The id of the sampled token.</returns>
        public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
            float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
        {
            llama_token id = 0;
            if (temperature <= 0)
            {
                // Greedy sampling: always pick the token with the highest logit.
                id = SamplingApi.llama_sample_token_greedy(_ctx, candidates);
            }
            else
            {
                if (mirostat == MiroStateType.MiroState)
                {
                    float mirostat_mu = 2.0f * mirostatTau;
                    const int mirostat_m = 100;
                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                    id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
                }
                else if (mirostat == MiroStateType.MiroState2)
                {
                    float mirostat_mu = 2.0f * mirostatTau;
                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                    id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
                }
                else
                {
                    // Temperature sampling: filter the candidates with top-k,
                    // tail-free, typical and top-p sampling, then draw a token.
                    SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
                    SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
                    SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
                    SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                    id = SamplingApi.llama_sample_token(_ctx, candidates);
                }
            }
            return id;
        }
        /// <summary>
        /// Apply repetition, frequency and presence penalties to the current logits,
        /// producing the candidate array for <see cref="Sample"/>. Please don't use it
        /// unless you fully know what it does.
        /// </summary>
        /// <param name="lastTokens">The most recently generated tokens.</param>
        /// <param name="logitBias">Optional per-token logit offsets, applied before the penalties.</param>
        /// <param name="repeatLastTokensCount">How many of the last tokens to penalize.</param>
        /// <param name="repeatPenalty">The repetition penalty factor.</param>
        /// <param name="alphaFrequency">The frequency penalty coefficient.</param>
        /// <param name="alphaPresence">The presence penalty coefficient.</param>
        /// <param name="penalizeNL">Whether the newline token may also be penalized.</param>
        /// <returns>The penalized candidate tokens.</returns>
        public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
            bool penalizeNL = true)
        {
            var n_vocab = NativeApi.llama_n_vocab(_ctx);
            var logits = Utils.GetLogits(_ctx, n_vocab);

            // Apply params.logit_bias map.
            if (logitBias is not null)
            {
                foreach (var (key, value) in logitBias)
                {
                    logits[key] += value;
                }
            }

            // Build one candidate entry per vocabulary token.
            var candidates = new List<LLamaTokenData>();
            candidates.Capacity = n_vocab;
            for (llama_token token_id = 0; token_id < n_vocab; token_id++)
            {
                candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f));
            }
            LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false);

            // Apply penalties over the last `last_n_repeat` tokens.
            float nl_logit = logits[NativeApi.llama_token_nl()];
            int lastTokensCount = lastTokens.Count();
            var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize);
            SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
                lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
                (ulong)last_n_repeat, repeatPenalty);
            SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
                lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
                (ulong)last_n_repeat, alphaFrequency, alphaPresence);

            if (!penalizeNL)
            {
                // Restore the newline logit in the candidate array itself; writing it
                // back into `logits` would have no effect, since the candidates hold copies.
                for (ulong i = 0; i < candidates_p.size; i++)
                {
                    if (candidates_p.data[i].id == NativeApi.llama_token_nl())
                    {
                        candidates_p.data[i].logit = nl_logit;
                        break;
                    }
                }
            }
            return candidates_p;
        }
        /// <summary>
        /// Eval the tokens in batches of <c>Params.BatchSize</c>, updating the model state.
        /// </summary>
        /// <param name="tokens">The tokens to evaluate.</param>
        /// <param name="pastTokensCount">The number of tokens already evaluated in this context.</param>
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError"></exception>
        public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
        {
            int total = tokens.Length;
            for (int i = 0; i < total; i += Params.BatchSize)
            {
                // Evaluate at most BatchSize tokens per native call.
                int n_eval = total - i;
                if (n_eval > Params.BatchSize)
                {
                    n_eval = Params.BatchSize;
                }
                if (Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
                {
                    _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                    throw new RuntimeError("Failed to eval.");
                }
                pastTokensCount += n_eval;
            }
            return pastTokensCount;
        }
        /// <summary>
        /// Convert the token ids to their text representations, one string per token.
        /// </summary>
        internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
        {
            foreach (var id in ids)
            {
                yield return Utils.TokenToString(id, _ctx, _encoding);
            }
        }
        /// <summary>
        /// Dispose the native context handle.
        /// </summary>
        public void Dispose()
        {
            _ctx.Dispose();
        }
    }
}
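
The members above compose into a token-by-token generation loop: tokenize the prompt, Eval it, then repeatedly ApplyPenalty, Sample, and Eval the sampled token. Below is a minimal sketch of that loop; the model path and generation settings are placeholders, and it assumes a ModelParams constructor taking the model path and a NativeApi.llama_token_eos() binding, as in LLamaSharp of this vintage.

using System;
using System.Linq;
using LLama;
using LLama.Common;
using LLama.Native;

// A minimal generation loop built from the members above.
// The model path and settings are placeholders.
class GenerationExample
{
    static void Main()
    {
        using var model = new LLamaModel(new ModelParams("path/to/model.bin"));

        var prompt = "Question: What is a llama? Answer:";
        var lastTokens = model.Tokenize(prompt).ToList();

        // Evaluate the prompt; `pastTokensCount` tracks how many tokens
        // the native context has already seen.
        var pastTokensCount = model.Eval(lastTokens.ToArray(), 0);

        for (int i = 0; i < 64; i++)
        {
            // Re-weight the current logits, then sample the next token.
            var candidates = model.ApplyPenalty(lastTokens);
            var id = model.Sample(candidates, temperature: 0.8f, topK: 40, topP: 0.95f);
            if (id == NativeApi.llama_token_eos())
            {
                break;
            }

            Console.Write(model.DeTokenize(new[] { id }));
            lastTokens.Add(id);

            // Feed the sampled token back so the next step can see it.
            pastTokensCount = model.Eval(new[] { id }, pastTokensCount);
        }
    }
}

Because Eval splits its input into Params.BatchSize chunks internally, the whole prompt can safely be fed in a single call.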

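Since evaluating a long prompt is expensive, SaveState/GetStateData and LoadState can snapshot and restore the native context instead. A short sketch, again with placeholder paths, assuming the same model and context size on both sides:

using System.Linq;
using LLama;
using LLama.Common;

class StateExample
{
    static void Main()
    {
        using var model = new LLamaModel(new ModelParams("path/to/model.bin"));

        // Evaluate an expensive prompt once...
        var tokens = model.Tokenize("A long system prompt ...").ToArray();
        model.Eval(tokens, 0);

        // ...snapshot the full native state (KV cache, RNG, logits)...
        model.SaveState("prompt.state");

        // ...and restore it later instead of re-evaluating the prompt.
        // LoadState throws RuntimeError if the size does not match the context.
        model.LoadState("prompt.state");
    }
}
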
An easy-to-use, high-performance LLM inference framework for C#/.NET, supporting the LLaMA and LLaVA series of models.