
LLamaModel.cs

using LLama.Exceptions;
using LLama.Types;
using LLama.Extensions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;

namespace LLama
{
    using llama_token = Int32;

    public class LLamaModel : IChatModel, IDisposable
    {
        LLamaParams _params;
        SafeLLamaContextHandle _ctx;
        string _path_session;
        List<llama_token> _session_tokens;
        List<llama_token> _embed_inp;
        int _n_ctx;
        List<llama_token> _inp_pfx;
        List<llama_token> _inp_sfx;
        List<llama_token> _llama_token_newline;
        List<llama_token> _last_n_tokens;
        bool _is_interacting;
        bool _is_antiprompt;
        bool _input_echo;
        bool _verbose;
        // HACK: because session saving incurs a non-negligible delay, for now skip re-saving the
        // session if we loaded one with at least 75% similarity. It is currently only used to speed
        // up the initial prompt, so it does not need to be an exact match.
        bool _need_to_save_session;
        int _n_past;
        int _n_remain;
        int _n_consumed;
        int _n_session_consumed;
        List<llama_token> _embed;

        public string Name { get; set; }
        public bool Verbose
        {
            get => _verbose;
            set => _verbose = value;
        }
        public SafeLLamaContextHandle NativeHandle => _ctx;
        /// <summary>
        /// Please refer to `LLamaParams` for the meaning of each argument. Be sure to set `n_gpu_layers`;
        /// otherwise 20 layers are loaded to the GPU by default.
        /// </summary>
        /// <param name="model_path">The model file path.</param>
        /// <param name="model_name">The model name.</param>
        /// <param name="verbose">Whether to print details when running the model.</param>
        /// <param name="seed"></param>
        /// <param name="n_threads"></param>
        /// <param name="n_predict"></param>
        /// <param name="n_ctx"></param>
        /// <param name="n_batch"></param>
        /// <param name="n_keep"></param>
        /// <param name="n_gpu_layers"></param>
        /// <param name="logit_bias"></param>
        /// <param name="top_k"></param>
        /// <param name="top_p"></param>
        /// <param name="tfs_z"></param>
        /// <param name="typical_p"></param>
        /// <param name="temp"></param>
        /// <param name="repeat_penalty"></param>
        /// <param name="repeat_last_n"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="mirostat"></param>
        /// <param name="mirostat_tau"></param>
        /// <param name="mirostat_eta"></param>
        /// <param name="prompt"></param>
        /// <param name="path_session"></param>
        /// <param name="input_prefix"></param>
        /// <param name="input_suffix"></param>
        /// <param name="antiprompt"></param>
        /// <param name="lora_adapter"></param>
        /// <param name="lora_base"></param>
        /// <param name="memory_f16"></param>
        /// <param name="random_prompt"></param>
        /// <param name="use_color"></param>
        /// <param name="interactive"></param>
        /// <param name="embedding"></param>
        /// <param name="interactive_first"></param>
        /// <param name="prompt_cache_all"></param>
        /// <param name="instruct"></param>
        /// <param name="penalize_nl"></param>
        /// <param name="perplexity"></param>
        /// <param name="use_mmap"></param>
        /// <param name="use_mlock"></param>
        /// <param name="mem_test"></param>
        /// <param name="verbose_prompt"></param>
        /// <param name="encoding"></param>
        public LLamaModel(string model_path, string model_name, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1,
            int n_ctx = 512, int n_batch = 512, int n_keep = 0, int n_gpu_layers = -1,
            Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
            float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
            int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
            int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f, string prompt = "",
            string path_session = "", string input_prefix = "", string input_suffix = "",
            List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
            bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
            bool embedding = false, bool interactive_first = false, bool prompt_cache_all = false, bool instruct = false, bool penalize_nl = true,
            bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
            bool verbose_prompt = false, string encoding = "UTF-8")
            : this(new LLamaParams(seed: seed,
                n_threads: n_threads,
                n_predict: n_predict,
                n_ctx: n_ctx,
                n_batch: n_batch,
                n_keep: n_keep,
                n_gpu_layers: n_gpu_layers,
                logit_bias: logit_bias,
                top_k: top_k,
                top_p: top_p,
                tfs_z: tfs_z,
                typical_p: typical_p,
                temp: temp,
                repeat_penalty: repeat_penalty,
                repeat_last_n: repeat_last_n,
                frequency_penalty: frequency_penalty,
                presence_penalty: presence_penalty,
                mirostat: mirostat,
                mirostat_tau: mirostat_tau,
                mirostat_eta: mirostat_eta,
                model: model_path,
                prompt: prompt,
                path_session: path_session,
                input_prefix: input_prefix,
                input_suffix: input_suffix,
                antiprompt: antiprompt,
                lora_adapter: lora_adapter,
                lora_base: lora_base,
                memory_f16: memory_f16,
                random_prompt: random_prompt,
                use_color: use_color,
                interactive: interactive,
                embedding: embedding,
                interactive_first: interactive_first,
                prompt_cache_all: prompt_cache_all,
                instruct: instruct,
                penalize_nl: penalize_nl,
                perplexity: perplexity,
                use_mmap: use_mmap,
                use_mlock: use_mlock,
                mem_test: mem_test,
                verbose_prompt: verbose_prompt),
                model_name, verbose, encoding)
        {
        }
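
        // Illustrative only (not part of the original file): constructing a model through this
        // convenience overload. The model path and the chosen parameter values are placeholder
        // assumptions; a sketch of typical usage rather than a prescribed configuration.
        //
        //   using var model = new LLamaModel("path/to/ggml-model.bin", "assistant",
        //       n_ctx: 2048, n_gpu_layers: 32, interactive: true,
        //       antiprompt: new List<string> { "User:" });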
        /// <summary>
        /// Please refer to `LLamaParams` for the meaning of each argument. Be sure to set `n_gpu_layers`;
        /// otherwise 20 layers are loaded to the GPU by default.
        /// </summary>
        /// <param name="params">The LLamaModel params.</param>
        /// <param name="name">Model name.</param>
        /// <param name="verbose">Whether to output detailed info.</param>
        /// <param name="encoding"></param>
        /// <exception cref="RuntimeError"></exception>
        public unsafe LLamaModel(LLamaParams @params, string name = "", bool verbose = false, string encoding = "UTF-8")
        {
            Name = name;
            _params = @params;
            _verbose = verbose;
            _ctx = Utils.llama_init_from_gpt_params(ref _params);
            _session_tokens = new List<llama_token>();
            _path_session = @params.path_session;
            if (!string.IsNullOrEmpty(_path_session))
            {
                if (verbose)
                {
                    LLamaLogger.Default.Info($"Attempting to load saved session from '{_path_session}'");
                }
                if (!File.Exists(_path_session))
                {
                    LLamaLogger.Default.Warn("Session file does not exist, will create.");
                }
                llama_token[] session_tokens = new llama_token[@params.n_ctx];
                ulong n_token_count_out = 0;
                if (!NativeApi.llama_load_session_file(_ctx, _path_session, session_tokens, (ulong)@params.n_ctx, &n_token_count_out))
                {
                    throw new RuntimeError($"Failed to load session file {_path_session}");
                }
                _session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
                if (verbose)
                {
                    LLamaLogger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
                }
            }
            _n_ctx = NativeApi.llama_n_ctx(_ctx);
            WithPrompt(_params.prompt);
            // prefix & suffix for instruct mode
            _inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true, encoding);
            _inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false, encoding);
            // in instruct mode, we inject a prefix and a suffix into each user input
            if (_params.instruct)
            {
                _params.interactive_first = true;
                _params.antiprompt.Add("### Instruction:\n\n");
            }
            // enable interactive mode if a reverse prompt or interactive start is specified
            if (_params.interactive_first)
            {
                _params.interactive = true;
            }
            // determine the newline token
            _llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false, encoding);
            if (_params.verbose_prompt)
            {
                LLamaLogger.Default.Info("\n");
                LLamaLogger.Default.Info($"prompt: '{_params.prompt}'");
                LLamaLogger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
                for (int i = 0; i < _embed_inp.Count; i++)
                {
                    LLamaLogger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
                }
                if (_params.n_keep > 0)
                {
                    LLamaLogger.Default.Info("static prompt based on n_keep: '");
                    for (int i = 0; i < _params.n_keep; i++)
                    {
                        LLamaLogger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
                    }
                    LLamaLogger.Default.Info("'\n");
                }
                LLamaLogger.Default.Info("\n");
            }
            if (_params.interactive && verbose)
            {
                LLamaLogger.Default.Info("interactive mode on.");
            }
            if (verbose)
            {
                LLamaLogger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
                    $"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
                    $"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}, " +
                    $"top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}, " +
                    $"mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
                LLamaLogger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
                    $"n_keep = {_params.n_keep}");
                LLamaLogger.Default.Info("\n");
            }
            _last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();
            if (_params.interactive)
            {
                if (verbose)
                {
                    LLamaLogger.Default.Info("== Running in interactive mode. ==");
                }
                _is_interacting = _params.interactive_first;
            }
            _is_antiprompt = false;
            _input_echo = false;
            _n_past = 0;
            _n_remain = _params.n_predict;
            _n_consumed = 0;
            _n_session_consumed = 0;
            _embed = new List<llama_token>();
        }
        /// <summary>
        /// Apply a prompt to the model.
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="encoding"></param>
        /// <returns></returns>
        /// <exception cref="ArgumentException"></exception>
        public LLamaModel WithPrompt(string prompt, string encoding = "UTF-8")
        {
            // Add a space in front of the first character to match OG llama tokenizer behavior
            _params.prompt = prompt.Insert(0, " ");
            _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true, encoding);
            if (_embed_inp.Count > _n_ctx - 4)
            {
                throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
            }
            ulong n_matching_session_tokens = 0;
            if (_session_tokens.Count > 0)
            {
                foreach (var id in _session_tokens)
                {
                    if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
                    {
                        break;
                    }
                    n_matching_session_tokens++;
                }
                if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
                {
                    LLamaLogger.Default.Info("Session file has exact match for prompt!");
                }
                else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
                {
                    LLamaLogger.Default.Warn($"Session file has low similarity to prompt ({n_matching_session_tokens} " +
                        $"/ {_embed_inp.Count} tokens); will mostly be re-evaluated.");
                }
                else
                {
                    LLamaLogger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
                        $"tokens of prompt.");
                }
            }
            // number of tokens to keep when resetting context
            if (_params.n_keep < 0 || _params.n_keep > (int)_embed_inp.Count || _params.instruct)
            {
                _params.n_keep = _embed_inp.Count;
            }
            _need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
            return this;
        }
        /// <summary>
        /// Apply the prompt file to the model.
        /// </summary>
        /// <param name="promptFileName"></param>
        /// <returns></returns>
        public LLamaModel WithPromptFile(string promptFileName)
        {
            return WithPrompt(File.ReadAllText(promptFileName));
        }
        private void ProcessTextBeforeInfer(string text, string encoding)
        {
            if (!string.IsNullOrEmpty(_params.input_prefix))
            {
                text = _params.input_prefix + text;
            }
            //if (!text.EndsWith("\n"))
            //{
            //    text += "\n";
            //}
            if (text.Length > 1)
            {
                // append input suffix if any
                if (!string.IsNullOrEmpty(_params.input_suffix))
                {
                    text += _params.input_suffix;
                    //yield return _params.input_suffix;
                }
                // instruct mode: insert instruction prefix
                if (_params.instruct && !_is_antiprompt)
                {
                    _n_consumed = _embed_inp.Count;
                    _embed_inp.AddRange(_inp_pfx);
                }
                var line_inp = Utils.llama_tokenize(_ctx, text, false, encoding);
                _embed_inp.AddRange(line_inp);
                // instruct mode: insert response suffix
                if (_params.instruct)
                {
                    _embed_inp.AddRange(_inp_sfx);
                }
                _n_remain -= line_inp.Count;
            }
        }
        public void InitChatPrompt(string prompt, string encoding = "UTF-8")
        {
            WithPrompt(prompt, encoding);
        }
        public void InitChatAntiprompt(string[] antiprompt)
        {
            _params.antiprompt = antiprompt.ToList();
        }
        /// <summary>
        /// Chat with the LLaMa model in interactive mode.
        /// </summary>
        /// <param name="text"></param>
        /// <param name="prompt"></param>
        /// <param name="encoding"></param>
        /// <returns></returns>
        /// <exception cref="ArgumentException"></exception>
        public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
        {
            if (!_params.interactive)
            {
                throw new ArgumentException("The chat API can only be used in interactive mode.");
            }
            _input_echo = false;
            if (!string.IsNullOrEmpty(prompt))
            {
                WithPrompt(prompt);
            }
            return Call(text, encoding);
        }
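
        // Illustrative only (not part of the original file): streaming a chat reply from a
        // model constructed with interactive: true. The prompt text is a placeholder assumption.
        //
        //   foreach (var piece in model.Chat("Hello!", prompt: "Below is a dialog.\nUser:"))
        //   {
        //       Console.Write(piece);
        //   }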
        /// <summary>
        /// Save the state to the specified path.
        /// </summary>
        /// <param name="filename"></param>
        public void SaveState(string filename)
        {
            var stateSize = NativeApi.llama_get_state_size(_ctx);
            byte[] stateMemory = new byte[stateSize];
            NativeApi.llama_copy_state_data(_ctx, stateMemory);
            File.WriteAllBytes(filename, stateMemory);
        }
        /// <summary>
        /// Load the state from the specified path.
        /// </summary>
        /// <param name="filename"></param>
        /// <param name="clearPreviousEmbed">Whether to clear the previously applied prompt (embedding input) of this model.</param>
        /// <exception cref="RuntimeError"></exception>
        public void LoadState(string filename, bool clearPreviousEmbed = true)
        {
            var stateMemory = File.ReadAllBytes(filename);
            int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
            if (stateMemory.Length != stateSize)
            {
                throw new RuntimeError("Failed to validate state size.");
            }
            NativeApi.llama_set_state_data(_ctx, stateMemory);
            if (clearPreviousEmbed)
            {
                WithPrompt(_params.prompt);
            }
        }
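
        // Illustrative only (not part of the original file): persisting and restoring the native
        // context state between runs. The file name is a placeholder assumption; LoadState expects
        // a model created with the same parameters as the one that saved the state.
        //
        //   model.SaveState("chat.state");
        //   // ... later, on a compatible model instance:
        //   model.LoadState("chat.state");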
        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The utf-8 encoded string to tokenize.</param>
        /// <returns>A list of tokens.</returns>
        /// <exception cref="RuntimeError">If the tokenization failed.</exception>
        public List<llama_token> Tokenize(string text)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            var tokens = new llama_token[n_ctx];
            var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
            if (n_tokens < 0)
            {
                throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
            }
            return tokens.Take(n_tokens).ToList();
        }
        /// <summary>
        /// Detokenize a list of tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens to detokenize.</param>
        /// <returns>The detokenized string.</returns>
        public string DeTokenize(IEnumerable<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            // use a StringBuilder instead of repeated string concatenation
            var output = new StringBuilder();
            foreach (var token in tokens)
            {
                output.Append(Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)));
            }
            return output.ToString();
        }
        /// <summary>
        /// Call the model to run inference.
        /// </summary>
        /// <param name="text"></param>
        /// <param name="encoding"></param>
        /// <returns></returns>
        /// <exception cref="RuntimeError"></exception>
        public IEnumerable<string> Call(string text, string encoding = "UTF-8")
        {
            _is_antiprompt = false;
            if (_n_past > 0)
            {
                _is_interacting = false;
            }
            if (_is_interacting)
            {
                if (_verbose)
                {
                    LLamaLogger.Default.Warn("The model was still in the interacting state when called; automatically reset it.");
                }
                _is_interacting = false;
            }
            ProcessTextBeforeInfer(text, encoding);
            while ((_n_remain != 0 || _params.interactive) && !_is_interacting)
            {
                if (_embed.Count > 0)
                {
                    // Infinite text generation via context swapping.
                    // If we run out of context:
                    // - take the first n_keep tokens from the original prompt (via n_past)
                    // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
                    if (_n_past + _embed.Count > _n_ctx)
                    {
                        int n_left = _n_past - _params.n_keep;
                        _n_past = Math.Max(1, _params.n_keep);
                        // insert n_left/2 tokens at the start of embed from last_n_tokens
                        _embed.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embed.Count).Skip(_n_ctx - n_left / 2 - _embed.Count));
                        // stop saving the session if we run out of context
                        _path_session = "";
                    }
                    // try to reuse a matching prefix from the loaded session instead of re-evaluating (via n_past)
                    if (_n_session_consumed < _session_tokens.Count)
                    {
                        int i = 0;
                        for (; i < _embed.Count; i++)
                        {
                            if (_embed[i] != _session_tokens[_n_session_consumed])
                            {
                                _session_tokens = _session_tokens.Take(_n_session_consumed).ToList();
                                break;
                            }
                            _n_past++;
                            _n_session_consumed++;
                            if (_n_session_consumed >= _session_tokens.Count)
                            {
                                i++;
                                break;
                            }
                        }
                        if (i > 0)
                        {
                            _embed.RemoveRange(0, i);
                        }
                    }
                    // Evaluate tokens in batches. embed is typically prepared beforehand to fit
                    // within a batch, but not always.
                    for (int i = 0; i < _embed.Count; i += _params.n_batch)
                    {
                        int n_eval = _embed.Count - i;
                        if (n_eval > _params.n_batch)
                        {
                            n_eval = _params.n_batch;
                        }
                        var array = _embed.Skip(i).ToArray();
                        if (NativeApi.llama_eval(_ctx, array, n_eval, _n_past, _params.n_threads) != 0)
                        {
                            LLamaLogger.Default.Error("Failed to eval.");
                            throw new RuntimeError("Failed to eval.");
                        }
                        _n_past += n_eval;
                    }
                    if (_embed.Count > 0 && !string.IsNullOrEmpty(_path_session))
                    {
                        _session_tokens.AddRange(_embed);
                        _n_session_consumed = _session_tokens.Count;
                    }
                }
                _embed.Clear();
                if (_embed_inp.Count <= _n_consumed && !_is_interacting)
                {
                    var temp = _params.temp;
                    var top_k = _params.top_k <= 0 ? NativeApi.llama_n_vocab(_ctx) : _params.top_k;
                    var top_p = _params.top_p;
                    var tfs_z = _params.tfs_z;
                    var typical_p = _params.typical_p;
                    var repeat_last_n = _params.repeat_last_n < 0 ? _n_ctx : _params.repeat_last_n;
                    var repeat_penalty = _params.repeat_penalty;
                    var alpha_presence = _params.presence_penalty;
                    var alpha_frequency = _params.frequency_penalty;
                    var mirostat = _params.mirostat;
                    var mirostat_tau = _params.mirostat_tau;
                    var mirostat_eta = _params.mirostat_eta;
                    var penalize_nl = _params.penalize_nl;
                    // optionally save the session on the first sample (for faster prompt loading next time)
                    if (!string.IsNullOrEmpty(_path_session) && _need_to_save_session)
                    {
                        _need_to_save_session = false;
                        NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count);
                    }
                    llama_token id = 0;
                    {
                        var n_vocab = NativeApi.llama_n_vocab(_ctx);
                        var logits = Utils.llama_get_logits(_ctx, n_vocab);
                        // apply the params.logit_bias map
                        foreach (var (key, value) in _params.logit_bias)
                        {
                            logits[key] += value;
                        }
                        var candidates = new List<LLamaTokenData>();
                        candidates.Capacity = n_vocab;
                        for (llama_token token_id = 0; token_id < n_vocab; token_id++)
                        {
                            candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f));
                        }
                        LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false);
                        // apply penalties
                        float nl_logit = logits[NativeApi.llama_token_nl()];
                        var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
                        SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
                            _last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
                            (ulong)last_n_repeat, repeat_penalty);
                        SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
                            _last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
                            (ulong)last_n_repeat, alpha_frequency, alpha_presence);
                        if (!penalize_nl)
                        {
                            logits[NativeApi.llama_token_nl()] = nl_logit;
                        }
                        if (temp <= 0)
                        {
                            // greedy sampling
                            id = SamplingApi.llama_sample_token_greedy(_ctx, candidates_p);
                        }
                        else
                        {
                            if (mirostat == 1)
                            {
                                float mirostat_mu = 2.0f * mirostat_tau;
                                const int mirostat_m = 100;
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates_p, mirostat_tau, mirostat_eta, mirostat_m, ref mirostat_mu);
                            }
                            else if (mirostat == 2)
                            {
                                float mirostat_mu = 2.0f * mirostat_tau;
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates_p, mirostat_tau, mirostat_eta, ref mirostat_mu);
                            }
                            else
                            {
                                // temperature sampling
                                SamplingApi.llama_sample_top_k(_ctx, candidates_p, top_k, 1);
                                SamplingApi.llama_sample_tail_free(_ctx, candidates_p, tfs_z, 1);
                                SamplingApi.llama_sample_typical(_ctx, candidates_p, typical_p, 1);
                                SamplingApi.llama_sample_top_p(_ctx, candidates_p, top_p, 1);
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token(_ctx, candidates_p);
                            }
                        }
                        _last_n_tokens.RemoveAt(0);
                        _last_n_tokens.Add(id);
                    }
                    // replace the end-of-text token with a newline token when in interactive mode
                    if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct)
                    {
                        id = _llama_token_newline[0];
                        if (_params.antiprompt.Count != 0)
                        {
                            // tokenize and inject the first reverse prompt
                            var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false, encoding);
                            _embed_inp.AddRange(first_antiprompt);
                        }
                    }
                    // add it to the context
                    _embed.Add(id);
                    // echo this to the console
                    _input_echo = true;
                    // decrement the remaining sampling budget
                    _n_remain--;
                }
                else
                {
                    // some prompt input has not been consumed yet; feed it in batches
                    while (_embed_inp.Count > _n_consumed)
                    {
                        _embed.Add(_embed_inp[_n_consumed]);
                        _last_n_tokens.RemoveAt(0);
                        _last_n_tokens.Add(_embed_inp[_n_consumed]);
                        _n_consumed++;
                        if (_embed.Count >= _params.n_batch)
                        {
                            break;
                        }
                    }
                }
                if (_input_echo && !_is_interacting)
                {
                    foreach (var id in _embed)
                    {
                        var res = Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
                        yield return res;
                    }
                }
                if (_params.interactive && _embed_inp.Count <= _n_consumed)
                {
                    // check whether any reverse prompt appears at the end of the output
                    if (_params.antiprompt.Count > 0)
                    {
                        string last_output = "";
                        foreach (var id in _last_n_tokens)
                        {
                            last_output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
                        }
                        _is_antiprompt = false;
                        foreach (var antiprompt in _params.antiprompt)
                        {
                            if (last_output.EndsWith(antiprompt))
                            {
                                _is_interacting = true;
                                _is_antiprompt = true;
                                break;
                            }
                        }
                    }
                    if (_n_past > 0 && _is_interacting)
                    {
                        if (_params.instruct)
                        {
                            yield return "\n> ";
                        }
                        _input_echo = false;
                        break;
                    }
                    if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos())
                    {
                        if (_params.instruct)
                        {
                            _is_interacting = true;
                        }
                        else
                        {
                            LLamaLogger.Default.Info(" [end of text]");
                        }
                    }
                    if (_params.interactive && _n_remain <= 0 && _params.n_predict != -1)
                    {
                        _n_remain = _params.n_predict;
                        _is_interacting = true;
                    }
                }
            }
            if (!string.IsNullOrEmpty(_path_session) && _params.prompt_cache_all)
            {
                LLamaLogger.Default.Info($"Saving final output to session file {_path_session}");
                var session_token_array = _session_tokens.ToArray();
                NativeApi.llama_save_session_file(_ctx, _path_session, session_token_array, (ulong)session_token_array.Length);
            }
        }
        public void Dispose()
        {
            _ctx.Dispose();
        }
    }
}

An easy-to-use, high-performance LLM inference framework for C#/.NET, supporting the LLaMA and LLaVA model families.
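
For orientation, here is a minimal end-to-end sketch of driving this class, using only the constructor, WithPrompt, and Chat defined above. The model path, prompt text, and parameter values are placeholder assumptions, not part of this file.

    using System;
    using System.Collections.Generic;
    using LLama;

    class Example
    {
        static void Main()
        {
            // placeholder path: point this at a local quantized LLaMA model file
            using var model = new LLamaModel("path/to/model.bin", "assistant",
                n_ctx: 2048, n_gpu_layers: 32, interactive: true,
                antiprompt: new List<string> { "User:" });

            // seed the conversation; Chat requires interactive mode
            model.WithPrompt("Transcript of a dialog between a User and an Assistant.\nUser:");

            while (true)
            {
                Console.Write("User: ");
                var input = Console.ReadLine();
                if (string.IsNullOrEmpty(input))
                    break;
                // Chat yields the reply piece by piece as tokens are sampled
                foreach (var piece in model.Chat(input))
                    Console.Write(piece);
                Console.WriteLine();
            }
        }
    }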