
LLamaModel.cs

using LLama.Exceptions;
using LLama.Extensions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using LLama.Common;
#pragma warning disable
namespace LLama.OldVersion
{
    using llama_token = Int32;

    [Obsolete("The entire LLama.OldVersion namespace will be removed")]
    public class LLamaModel
        : IChatModel, IDisposable
    {
        LLamaParams _params;
        SafeLLamaContextHandle _ctx;
        string _path_session;
        List<llama_token> _session_tokens;
        List<llama_token> _embed_inp;
        int _n_ctx;
        List<llama_token> _inp_pfx;
        List<llama_token> _inp_sfx;
        List<llama_token> _llama_token_newline;
        List<llama_token> _last_n_tokens;
        bool _is_interacting;
        bool _is_antiprompt;
        bool _input_echo;
        bool _verbose;

        // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
        // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
        // initial prompt so it doesn't need to be an exact match.
        bool _need_to_save_session;

        int _n_past;
        int _n_remain;
        int _n_consumed;
        int _n_session_consumed;
        List<llama_token> _embed;

        public string Name { get; set; }

        public bool Verbose
        {
            get => _verbose;
            set => _verbose = value;
        }

        public SafeLLamaContextHandle NativeHandle => _ctx;
        /// <summary>
        /// Please refer to `LLamaParams` for the meaning of each argument. Be sure to set `n_gpu_layers`;
        /// otherwise 20 layers will be loaded to the GPU by default.
        /// </summary>
        /// <param name="model_path">The model file path.</param>
        /// <param name="model_name">The model name.</param>
        /// <param name="verbose">Whether to print details when running the model.</param>
        /// <param name="seed"></param>
        /// <param name="n_threads"></param>
        /// <param name="n_predict"></param>
        /// <param name="n_ctx"></param>
        /// <param name="n_batch"></param>
        /// <param name="n_keep"></param>
        /// <param name="n_gpu_layers"></param>
        /// <param name="logit_bias"></param>
        /// <param name="top_k"></param>
        /// <param name="top_p"></param>
        /// <param name="tfs_z"></param>
        /// <param name="typical_p"></param>
        /// <param name="temp"></param>
        /// <param name="repeat_penalty"></param>
        /// <param name="repeat_last_n"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="mirostat"></param>
        /// <param name="mirostat_tau"></param>
        /// <param name="mirostat_eta"></param>
        /// <param name="prompt"></param>
        /// <param name="path_session"></param>
        /// <param name="input_prefix"></param>
        /// <param name="input_suffix"></param>
        /// <param name="antiprompt"></param>
        /// <param name="lora_adapter"></param>
        /// <param name="lora_base"></param>
        /// <param name="memory_f16"></param>
        /// <param name="random_prompt"></param>
        /// <param name="use_color"></param>
        /// <param name="interactive"></param>
        /// <param name="embedding"></param>
        /// <param name="interactive_first"></param>
        /// <param name="prompt_cache_all"></param>
        /// <param name="instruct"></param>
        /// <param name="penalize_nl"></param>
        /// <param name="perplexity"></param>
        /// <param name="use_mmap"></param>
        /// <param name="use_mlock"></param>
        /// <param name="mem_test"></param>
        /// <param name="verbose_prompt"></param>
        /// <param name="encoding"></param>
        public LLamaModel(string model_path, string model_name, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1,
            int n_ctx = 512, int n_batch = 512, int n_keep = 0, int n_gpu_layers = -1,
            Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
            float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
            int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
            int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f, string prompt = "",
            string path_session = "", string input_prefix = "", string input_suffix = "",
            List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
            bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
            bool embedding = false, bool interactive_first = false, bool prompt_cache_all = false, bool instruct = false, bool penalize_nl = true,
            bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
            bool verbose_prompt = false, string encoding = "UTF-8")
            : this(new LLamaParams(seed: seed,
                n_threads: n_threads,
                n_predict: n_predict,
                n_ctx: n_ctx,
                n_batch: n_batch,
                n_keep: n_keep,
                n_gpu_layers: n_gpu_layers,
                logit_bias: logit_bias,
                top_k: top_k,
                top_p: top_p,
                tfs_z: tfs_z,
                typical_p: typical_p,
                temp: temp,
                repeat_penalty: repeat_penalty,
                repeat_last_n: repeat_last_n,
                frequency_penalty: frequency_penalty,
                presence_penalty: presence_penalty,
                mirostat: mirostat,
                mirostat_tau: mirostat_tau,
                mirostat_eta: mirostat_eta,
                model: model_path,
                prompt: prompt,
                path_session: path_session,
                input_prefix: input_prefix,
                input_suffix: input_suffix,
                antiprompt: antiprompt,
                lora_adapter: lora_adapter,
                lora_base: lora_base,
                memory_f16: memory_f16,
                random_prompt: random_prompt,
                use_color: use_color,
                interactive: interactive,
                embedding: embedding,
                interactive_first: interactive_first,
                prompt_cache_all: prompt_cache_all,
                instruct: instruct,
                penalize_nl: penalize_nl,
                perplexity: perplexity,
                use_mmap: use_mmap,
                use_mlock: use_mlock,
                mem_test: mem_test,
                verbose_prompt: verbose_prompt),
                model_name, verbose, encoding)
        {
        }
        /// <summary>
        /// Please refer to `LLamaParams` for the meaning of each argument. Be sure to set `n_gpu_layers`;
        /// otherwise 20 layers will be loaded to the GPU by default.
        /// </summary>
        /// <param name="params">The LLamaModel params.</param>
        /// <param name="name">Model name.</param>
        /// <param name="verbose">Whether to output detailed info.</param>
        /// <param name="encoding"></param>
        /// <exception cref="RuntimeError"></exception>
        public unsafe LLamaModel(LLamaParams @params, string name = "", bool verbose = false, string encoding = "UTF-8")
        {
            Name = name;
            _params = @params;
            _verbose = verbose;
            _ctx = Utils.llama_init_from_gpt_params(ref _params);

            _session_tokens = new List<llama_token>();
            _path_session = @params.path_session;
            if (!string.IsNullOrEmpty(_path_session))
            {
                if (verbose)
                {
                    LLamaDefaultLogger.Default.Info($"Attempting to load saved session from '{_path_session}'");
                }
                if (!File.Exists(_path_session))
                {
                    LLamaDefaultLogger.Default.Warn("Session file does not exist, will create it.");
                }
                llama_token[] session_tokens = new llama_token[@params.n_ctx];
                ulong n_token_count_out = 0;
                if (!NativeApi.llama_load_session_file(_ctx, _path_session, session_tokens, (ulong)@params.n_ctx, &n_token_count_out))
                {
                    throw new RuntimeError($"Failed to load session file {_path_session}");
                }
                _session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
                if (verbose)
                {
                    LLamaDefaultLogger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
                }
            }
            _n_ctx = NativeApi.llama_n_ctx(_ctx);
            WithPrompt(_params.prompt);

            // prefix & suffix for instruct mode
            _inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true, encoding);
            _inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false, encoding);

            // in instruct mode, we inject a prefix and a suffix to each input by the user
            if (_params.instruct)
            {
                _params.interactive_first = true;
                _params.antiprompt.Add("### Instruction:\n\n");
            }
            // enable interactive mode if reverse prompt or interactive start is specified
            if (_params.interactive_first)
            {
                _params.interactive = true;
            }
            // determine newline token
            _llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false, encoding);
            if (_params.verbose_prompt)
            {
                LLamaDefaultLogger.Default.Info("\n");
                LLamaDefaultLogger.Default.Info($"prompt: '{_params.prompt}'");
                LLamaDefaultLogger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
                for (int i = 0; i < _embed_inp.Count; i++)
                {
                    LLamaDefaultLogger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
                }
                if (_params.n_keep > 0)
                {
                    LLamaDefaultLogger.Default.Info($"static prompt based on n_keep: '");
                    for (int i = 0; i < _params.n_keep; i++)
                    {
                        LLamaDefaultLogger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
                    }
                    LLamaDefaultLogger.Default.Info("\n");
                }
                LLamaDefaultLogger.Default.Info("\n");
            }
            if (_params.interactive && verbose)
            {
                LLamaDefaultLogger.Default.Info("interactive mode on.");
            }
            if (verbose)
            {
                LLamaDefaultLogger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
                    $"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
                    $"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}," +
                    $" top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}," +
                    $" mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
                LLamaDefaultLogger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
                    $"n_keep = {_params.n_keep}");
                LLamaDefaultLogger.Default.Info("\n");
            }
            _last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();
            if (_params.interactive)
            {
                if (verbose)
                {
                    LLamaDefaultLogger.Default.Info("== Running in interactive mode. ==");
                }
                _is_interacting = _params.interactive_first;
            }
            _is_antiprompt = false;
            _input_echo = false;
            _n_past = 0;
            _n_remain = _params.n_predict;
            _n_consumed = 0;
            _n_session_consumed = 0;
            _embed = new List<llama_token>();
        }
        /// <summary>
        /// Apply a prompt to the model.
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="encoding"></param>
        /// <returns></returns>
        /// <exception cref="ArgumentException"></exception>
        public LLamaModel WithPrompt(string prompt, string encoding = "UTF-8")
        {
            // Add a space in front of the first character to match OG llama tokenizer behavior
            _params.prompt = prompt.Insert(0, " ");
            _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true, encoding);
            if (_embed_inp.Count > _n_ctx - 4)
            {
                throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
            }
            ulong n_matching_session_tokens = 0;
            if (_session_tokens.Count > 0)
            {
                foreach (var id in _session_tokens)
                {
                    if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
                    {
                        break;
                    }
                    n_matching_session_tokens++;
                }
                if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
                {
                    LLamaDefaultLogger.Default.Info("Session file has exact match for prompt!");
                }
                else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
                {
                    LLamaDefaultLogger.Default.Warn($"Session file has low similarity to prompt ({n_matching_session_tokens} " +
                        $"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
                }
                else
                {
                    LLamaDefaultLogger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
                        $"tokens of prompt.");
                }
            }
            // number of tokens to keep when resetting context
            if (_params.n_keep < 0 || _params.n_keep > _embed_inp.Count || _params.instruct)
            {
                _params.n_keep = _embed_inp.Count;
            }
            if (_embed_inp.Count > _n_ctx - 4)
            {
                throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
            }
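            // Re-save the session only when less than 75% of the prompt was recovered from it;
            // the 3/4 threshold below implements the similarity heuristic described in the HACK
            // note at the top of the class.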
            _need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
            return this;
        }
        /// <summary>
        /// Apply the prompt file to the model.
        /// </summary>
        /// <param name="promptFileName"></param>
        /// <returns></returns>
        public LLamaModel WithPromptFile(string promptFileName)
        {
            return WithPrompt(File.ReadAllText(promptFileName));
        }
        private void ProcessTextBeforeInfer(string text, string encoding)
        {
            if (!string.IsNullOrEmpty(_params.input_prefix))
            {
                text = _params.input_prefix + text;
            }
            //if (!text.EndsWith("\n"))
            //{
            //    text += "\n";
            //}
            if (text.Length > 1)
            {
                // append input suffix if any
                if (!string.IsNullOrEmpty(_params.input_suffix))
                {
                    text += _params.input_suffix;
                    //yield return _params.input_suffix;
                }
                // instruct mode: insert instruction prefix
                if (_params.instruct && !_is_antiprompt)
                {
                    _n_consumed = _embed_inp.Count;
                    _embed_inp.AddRange(_inp_pfx);
                }
                var line_inp = Utils.llama_tokenize(_ctx, text, false, encoding);
                _embed_inp.AddRange(line_inp);
                // instruct mode: insert response suffix
                if (_params.instruct)
                {
                    _embed_inp.AddRange(_inp_sfx);
                }
                _n_remain -= line_inp.Count;
            }
        }
        public void InitChatPrompt(string prompt, string encoding = "UTF-8")
        {
            WithPrompt(prompt, encoding);
        }

        public void InitChatAntiprompt(string[] antiprompt)
        {
            _params.antiprompt = antiprompt.ToList();
        }
        /// <summary>
        /// Chat with the LLaMa model under interactive mode.
        /// </summary>
        /// <param name="text"></param>
        /// <param name="prompt"></param>
        /// <param name="encoding"></param>
        /// <returns></returns>
        /// <exception cref="ArgumentException"></exception>
        public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
        {
            if (!_params.interactive)
            {
                throw new ArgumentException("The chat API can only be used in interactive mode.");
            }
            _input_echo = false;
            if (!string.IsNullOrEmpty(prompt))
            {
                WithPrompt(prompt);
            }
            return Call(text, encoding);
        }
        /// <summary>
        /// Save the state to the specified path.
        /// </summary>
        /// <param name="filename"></param>
        public void SaveState(string filename)
        {
            var stateSize = NativeApi.llama_get_state_size(_ctx);
            byte[] stateMemory = new byte[stateSize];
            NativeApi.llama_copy_state_data(_ctx, stateMemory);
            File.WriteAllBytes(filename, stateMemory);
        }

        /// <summary>
        /// Load the state from the specified path.
        /// </summary>
        /// <param name="filename"></param>
        /// <param name="clearPreviousEmbed">Whether to clear the previously processed prompt state of this model.</param>
        /// <exception cref="RuntimeError"></exception>
        public void LoadState(string filename, bool clearPreviousEmbed = true)
        {
            var stateMemory = File.ReadAllBytes(filename);
            int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
            if (stateMemory.Length != stateSize)
            {
                throw new RuntimeError("Failed to validate state size.");
            }
            NativeApi.llama_set_state_data(_ctx, stateMemory);
            if (clearPreviousEmbed)
            {
                WithPrompt(_params.prompt);
            }
        }
        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The UTF-8 encoded string to tokenize.</param>
        /// <param name="encoding">The name of the encoding used to marshal the text.</param>
        /// <returns>A list of tokens.</returns>
        /// <exception cref="RuntimeError">If the tokenization failed.</exception>
        public List<llama_token> Tokenize(string text, string encoding = "UTF-8")
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            var tokens = new llama_token[n_ctx];
            var n_tokens = NativeApi.llama_tokenize(_ctx, text, Encoding.GetEncoding(encoding), tokens, n_ctx, true);
            if (n_tokens < 0)
            {
                throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
            }
            return tokens.Take(n_tokens).ToList();
        }
        /// <summary>
        /// Detokenize a list of tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens to detokenize.</param>
        /// <returns>The detokenized string.</returns>
        public string DeTokenize(IEnumerable<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var output = new StringBuilder();
            foreach (var token in tokens)
            {
                output.Append(Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)));
            }
            return output.ToString();
        }
        /// <summary>
        /// Call the model to run inference.
        /// </summary>
        /// <param name="text"></param>
        /// <param name="encoding"></param>
        /// <returns></returns>
        /// <exception cref="RuntimeError"></exception>
        public IEnumerable<string> Call(string text, string encoding = "UTF-8")
        {
            _is_antiprompt = false;
            if (_n_past > 0)
            {
                _is_interacting = false;
            }
            if (_is_interacting)
            {
                if (_verbose)
                {
                    LLamaDefaultLogger.Default.Warn("The model was still in an interacting state when called; automatically reset it.");
                }
                _is_interacting = false;
            }
            ProcessTextBeforeInfer(text, encoding);
            while ((_n_remain != 0 || _params.interactive) && !_is_interacting)
            {
                if (_embed.Count > 0)
                {
                    // infinite text generation via context swapping
                    // if we run out of context:
                    // - take the n_keep first tokens from the original prompt (via n_past)
                    // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
                    if (_n_past + _embed.Count > _n_ctx)
                    {
                        int n_left = _n_past - _params.n_keep;
                        _n_past = Math.Max(1, _params.n_keep);
                        // insert n_left/2 tokens at the start of embed from last_n_tokens
                        _embed.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embed.Count).Skip(_n_ctx - n_left / 2 - _embed.Count));
                        // stop saving session if we run out of context
                        _path_session = "";
                    }
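                    // Illustrative arithmetic for the swap above (assumed numbers, not from the original file):
                    // with n_ctx = 512, n_keep = 48 and n_past = 512, n_left = 464, so n_past resets to 48 and
                    // n_left/2 = 232 of the most recent tokens are re-inserted into _embed for re-evaluation.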
                    // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
                    // REVIEW
                    if (_n_session_consumed < _session_tokens.Count)
                    {
                        int i = 0;
                        for (; i < _embed.Count; i++)
                        {
                            if (_embed[i] != _session_tokens[_n_session_consumed])
                            {
                                _session_tokens = _session_tokens.Take(_n_session_consumed).ToList();
                                break;
                            }
                            _n_past++;
                            _n_session_consumed++;
                            if (_n_session_consumed >= _session_tokens.Count)
                            {
                                i++;
                                break;
                            }
                        }
                        if (i > 0)
                        {
                            _embed.RemoveRange(0, i);
                        }
                    }
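                    // The loop above walks _embed and the loaded session tokens in lockstep: every token that
                    // already matches the session advances _n_past without re-evaluation, while the first
                    // mismatch truncates the stale session suffix so later tokens are evaluated (and saved) normally.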
                    // evaluate tokens in batches
                    // embed is typically prepared beforehand to fit within a batch, but not always
                    for (int i = 0; i < _embed.Count; i += _params.n_batch)
                    {
                        int n_eval = _embed.Count - i;
                        if (n_eval > _params.n_batch)
                        {
                            n_eval = _params.n_batch;
                        }
                        var array = _embed.Skip(i).ToArray();
                        if (NativeApi.llama_eval(_ctx, array, n_eval, _n_past, _params.n_threads) != 0)
                        {
                            LLamaDefaultLogger.Default.Error($"Failed to eval.");
                            throw new RuntimeError("Failed to eval.");
                        }
                        _n_past += n_eval;
                    }
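                    // For example (illustrative numbers, not from the original file): with 1200 pending tokens
                    // and n_batch = 512, the loop above evaluates chunks of 512, 512 and 176 tokens,
                    // advancing _n_past by n_eval after each llama_eval call.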
                    if (_embed.Count > 0 && !string.IsNullOrEmpty(_path_session))
                    {
                        _session_tokens.AddRange(_embed);
                        _n_session_consumed = _session_tokens.Count;
                    }
                }
                _embed.Clear();
                if (_embed_inp.Count <= _n_consumed && !_is_interacting)
                {
                    var temp = _params.temp;
                    var top_k = _params.top_k <= 0 ? NativeApi.llama_n_vocab(_ctx) : _params.top_k;
                    var top_p = _params.top_p;
                    var tfs_z = _params.tfs_z;
                    var typical_p = _params.typical_p;
                    var repeat_last_n = _params.repeat_last_n < 0 ? _n_ctx : _params.repeat_last_n;
                    var repeat_penalty = _params.repeat_penalty;
                    var alpha_presence = _params.presence_penalty;
                    var alpha_frequency = _params.frequency_penalty;
                    var mirostat = _params.mirostat;
                    var mirostat_tau = _params.mirostat_tau;
                    var mirostat_eta = _params.mirostat_eta;
                    var penalize_nl = _params.penalize_nl;
                    // optionally save the session on first sample (for faster prompt loading next time)
                    if (!string.IsNullOrEmpty(_path_session) && _need_to_save_session)
                    {
                        _need_to_save_session = false;
                        NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count);
                    }
                    llama_token id = 0;
                    {
                        var n_vocab = NativeApi.llama_n_vocab(_ctx);
                        var logits = Utils.llama_get_logits(_ctx, n_vocab);
                        // Apply params.logit_bias map
                        foreach (var (key, value) in _params.logit_bias)
                        {
                            logits[key] += value;
                        }
                        var candidates = new LLamaTokenData[n_vocab];
                        for (llama_token token_id = 0; token_id < n_vocab; token_id++)
                        {
                            candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
                        }
                        LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);
                        // Apply penalties
                        float nl_logit = logits[NativeApi.llama_token_nl()];
                        var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
                        SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
                            _last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
                            (ulong)last_n_repeat, repeat_penalty);
                        SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
                            _last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
                            (ulong)last_n_repeat, alpha_frequency, alpha_presence);
                        if (!penalize_nl)
                        {
                            logits[NativeApi.llama_token_nl()] = nl_logit;
                        }
                        if (temp <= 0)
                        {
                            // Greedy sampling
                            id = SamplingApi.llama_sample_token_greedy(_ctx, candidates_p);
                        }
                        else
                        {
                            if (mirostat == 1)
                            {
                                float mirostat_mu = 2.0f * mirostat_tau;
                                const int mirostat_m = 100;
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates_p, mirostat_tau, mirostat_eta, mirostat_m, ref mirostat_mu);
                            }
                            else if (mirostat == 2)
                            {
                                float mirostat_mu = 2.0f * mirostat_tau;
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates_p, mirostat_tau, mirostat_eta, ref mirostat_mu);
                            }
                            else
                            {
                                // Temperature sampling
                                SamplingApi.llama_sample_top_k(_ctx, candidates_p, top_k, 1);
                                SamplingApi.llama_sample_tail_free(_ctx, candidates_p, tfs_z, 1);
                                SamplingApi.llama_sample_typical(_ctx, candidates_p, typical_p, 1);
                                SamplingApi.llama_sample_top_p(_ctx, candidates_p, top_p, 1);
                                SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
                                id = SamplingApi.llama_sample_token(_ctx, candidates_p);
                            }
                        }
                        _last_n_tokens.RemoveAt(0);
                        _last_n_tokens.Add(id);
                    }
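                    // To summarize the block above: logit biases and repetition/frequency/presence penalties
                    // are applied first, then the token is picked by greedy decoding (temp <= 0), mirostat v1/v2,
                    // or the default chain of top-k -> tail-free -> typical -> top-p filtering plus temperature sampling.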
                    // replace end of text token with newline token when in interactive mode
                    if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct)
                    {
                        id = _llama_token_newline[0];
                        if (_params.antiprompt.Count != 0)
                        {
                            // tokenize and inject first reverse prompt
                            var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false, encoding);
                            _embed_inp.AddRange(first_antiprompt);
                        }
                    }
                    // add it to the context
                    _embed.Add(id);
                    // echo this to console
                    _input_echo = true;
                    // decrement remaining sampling budget
                    _n_remain--;
                }
                else
                {
                    while (_embed_inp.Count > _n_consumed)
                    {
                        _embed.Add(_embed_inp[_n_consumed]);
                        _last_n_tokens.RemoveAt(0);
                        _last_n_tokens.Add(_embed_inp[_n_consumed]);
                        _n_consumed++;
                        if (_embed.Count >= _params.n_batch)
                        {
                            break;
                        }
                    }
                }
                if (_input_echo && !_is_interacting)
                {
                    foreach (var id in _embed)
                    {
                        var res = Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
                        yield return res;
                    }
                }
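                // Check whether the recent output ends with one of the reverse prompts; if so,
                // hand control back to the user by switching into the interacting state.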
                if (_params.interactive && _embed_inp.Count <= _n_consumed)
                {
                    if (_params.antiprompt.Count > 0)
                    {
                        string last_output = "";
                        foreach (var id in _last_n_tokens)
                        {
                            last_output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
                        }
                        _is_antiprompt = false;
                        foreach (var antiprompt in _params.antiprompt)
                        {
                            if (last_output.EndsWith(antiprompt))
                            {
                                _is_interacting = true;
                                _is_antiprompt = true;
                                break;
                            }
                        }
                    }
                    if (_n_past > 0 && _is_interacting)
                    {
                        if (_params.instruct)
                        {
                            yield return "\n> ";
                        }
                        _input_echo = false;
                        break;
                    }
                    if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos())
                    {
                        if (_params.instruct)
                        {
                            _is_interacting = true;
                        }
                        else
                        {
                            LLamaDefaultLogger.Default.Info(" [end of text]");
                        }
                    }
                    if (_params.interactive && _n_remain <= 0 && _params.n_predict != -1)
                    {
                        _n_remain = _params.n_predict;
                        _is_interacting = true;
                    }
                }
            }
            if (!string.IsNullOrEmpty(_path_session) && _params.prompt_cache_all)
            {
                LLamaDefaultLogger.Default.Info($"saving final output to session file {_path_session}");
                var session_token_array = _session_tokens.ToArray();
                NativeApi.llama_save_session_file(_ctx, _path_session, session_token_array, (ulong)session_token_array.Length);
            }
        }

        /// <inheritdoc />
        public void Dispose()
        {
            _ctx.Dispose();
        }
    }
}
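
// A minimal usage sketch (illustrative only; the model path, prompt and parameter values below are
// assumptions, not part of the original file):
//
//   using (var model = new LLamaModel(@"path/to/model.bin", "assistant", verbose: true,
//                                     interactive: true, antiprompt: new List<string> { "User:" }))
//   {
//       model.WithPrompt("Transcript of a dialog between a User and an Assistant.\nUser:");
//       foreach (var piece in model.Chat("Hello, who are you?"))
//       {
//           Console.Write(piece);
//       }
//   }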