You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaModel.cs 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830
  1. using LLama.Exceptions;
  2. using LLama.Native;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Configuration;
  6. using System.Diagnostics;
  7. using System.IO;
  8. using System.Linq;
  9. using System.Runtime.CompilerServices;
  10. using System.Text;
  11. using LLama.Types;
  12. using System.Runtime.InteropServices;
  13. using System.Text.RegularExpressions;
  14. using System.Collections;
  15. namespace LLama
  16. {
  17. using llama_token = Int32;
  18. /// <summary>
  19. /// High-level Wrapper of a llama.cpp model for inference.
  20. /// </summary>
  21. [Obsolete]
  22. public class LLamaModelV1
  23. {
  // Path of the model file this instance was constructed with.
  private string _model_path;
  // Native context parameters; populated in the constructor.
  LLamaContextParams _params;
  // Thread count forwarded to llama_eval and llama_apply_lora_from_file.
  private int _n_threads;
  // Number of tokens submitted per llama_eval call in Eval.
  private int _n_batch;
  // Length of the recent-token window handed to the repetition penalty in Sample.
  private int _last_n_tokens_size;
  // Optional LoRA base-model path, as passed to the constructor.
  private string? _lora_base;
  // Optional LoRA adapter path, as passed to the constructor.
  private string? _lora_path;
  // When true, diagnostic output (system info, timings, cache/prefix messages) is emitted.
  private bool _verbose;
  // Tokens evaluated so far (appended by Eval, cleared by Reset).
  private Queue<llama_token> _eval_tokens;
  // Logit rows copied out of the native context by Eval (one float[] per row).
  private Queue<float[]> _eval_logits;
  // Optional state cache consulted and updated by CreateCompletionInternal; null when unset.
  private LLamaCache? _cache;
  // Handle to the native llama context.
  private SafeLLamaContextHandle _ctx;
  // (num, pattern) pairs used by the multibyte check in CreateCompletionInternal.
  // NOTE(review): the values look like UTF-8 lead-byte masks (110xxxxx, 1110xxxx, 11110xxx)
  // paired with their sequence lengths — confirm.
  private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };
  /// <summary>
  /// Load a llama.cpp model from the path.
  /// </summary>
  /// <remarks>Note that the API is still unstable. The order of the parameters is likely to
  /// be changed in the future. It's recommended to specify the parameter name when
  /// building your app. We use the cpp style parameter names here because it introduces
  /// convenience for searching the docs.</remarks>
  /// <param name="model_path">Path to the model.</param>
  /// <param name="n_ctx">Maximum context size.</param>
  /// <param name="n_parts">Number of parts to split the model into. If -1, the number of parts is automatically determined.</param>
  /// <param name="seed">Random seed. 0 for random.</param>
  /// <param name="f16_kv">Use half-precision for key/value cache.</param>
  /// <param name="logits_all">Return logits for all tokens, not just the last token.</param>
  /// <param name="vocab_only">Only load the vocabulary no weights.</param>
  /// <param name="use_mmap">Use mmap if possible. Forced to false whenever <paramref name="lora_path"/> is given.</param>
  /// <param name="use_mlock">Force the system to keep the model in RAM.</param>
  /// <param name="embedding">Embedding mode only.</param>
  /// <param name="n_threads">Number of threads to use. If -1, half of the logical processors (at least one) is used.</param>
  /// <param name="n_batch">Maximum number of prompt tokens to batch together when calling llama_eval. Clamped to <paramref name="n_ctx"/>.</param>
  /// <param name="last_n_tokens_size">Maximum number of tokens to keep in the last_n_tokens deque.</param>
  /// <param name="lora_base">Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
  /// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
  /// <param name="verbose">Print verbose output to stderr.</param>
  /// <exception cref="FileNotFoundException">If <paramref name="model_path"/> is neither an existing file nor an existing directory.</exception>
  /// <exception cref="RuntimeError">If the native context could not be created or the LoRA adapter could not be applied.</exception>
  public LLamaModelV1(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
      bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
      bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
      int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
  {
      _verbose = verbose;
      _model_path = model_path;

      _params = NativeApi.llama_context_default_params();
      _params.n_ctx = n_ctx;
      _params.n_parts = n_parts;
      _params.seed = seed;
      _params.f16_kv = f16_kv;
      _params.logits_all = logits_all;
      _params.vocab_only = vocab_only;
      // mmap is switched off whenever a LoRA adapter is going to be applied.
      _params.use_mmap = lora_path is null && use_mmap;
      _params.use_mlock = use_mlock;
      _params.embedding = embedding;

      _last_n_tokens_size = last_n_tokens_size;
      _n_batch = Math.Min(n_ctx, n_batch);

      _eval_tokens = new Queue<llama_token>(capacity: n_ctx);
      _eval_logits = new Queue<float[]>(logits_all ? n_ctx : 1);

      _cache = null;

      // -1 means "pick for me": use half of the logical processors, but never fewer than one.
      _n_threads = n_threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : n_threads;

      _lora_base = lora_base;
      _lora_path = lora_path;

      if (!File.Exists(model_path) && !Directory.Exists(model_path))
      {
          throw new FileNotFoundException($"Model path does not exist: {model_path}");
      }

      // The path is passed straight through; the previous UTF-8 encode/decode round-trip
      // produced an identical string and had no effect.
      _ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(model_path, _params));

      // Debug.Assert is compiled out of release builds, so a failed load must be reported explicitly;
      // otherwise every later native call would run against a null context.
      if (_ctx.DangerousGetHandle() == IntPtr.Zero)
      {
          throw new RuntimeError($"Failed to load model from path: {model_path}");
      }

      if (_lora_path is not null)
      {
          if (NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
          {
              throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
          }
      }

      if (_verbose)
      {
          Logger.Default.Info(Utils.PtrToStringUTF8(NativeApi.llama_print_system_info()));
      }
  }
  /// <summary>
  /// Create a new wrapper from an existing one.
  /// </summary>
  /// <remarks>
  /// The native context handle is shared with <paramref name="other"/>, not duplicated.
  /// The evaluated-token and logit queues are copied into new queues, and the cache
  /// of <paramref name="other"/> is not carried over.
  /// </remarks>
  /// <param name="other">The instance to copy settings and evaluation history from.</param>
  public LLamaModelV1(LLamaModelV1 other)
  {
      // Shared native state and immutable settings.
      _ctx = other._ctx;
      _params = other._params;
      _model_path = other._model_path;
      _lora_path = other._lora_path;
      _lora_base = other._lora_base;

      // Scalar options.
      _verbose = other._verbose;
      _n_batch = other._n_batch;
      _n_threads = other._n_threads;
      _last_n_tokens_size = other._last_n_tokens_size;

      // Evaluation history gets its own queue instances.
      _eval_tokens = new Queue<llama_token>(other._eval_tokens);
      _eval_logits = new Queue<float[]>(other._eval_logits);
  }
  /// <summary>
  /// Convert text into model tokens.
  /// </summary>
  /// <param name="text">The text to tokenize.</param>
  /// <returns>The tokens produced by the native tokenizer, in order.</returns>
  /// <exception cref="RuntimeError">If the native tokenizer reports a negative token count.</exception>
  public List<llama_token> Tokenize(string text)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      // The output buffer is sized to the context length reported by the native library.
      var capacity = NativeApi.llama_n_ctx(_ctx);
      var buffer = new llama_token[capacity];

      var written = NativeApi.llama_tokenize(_ctx, text, buffer, capacity, true);
      if (written >= 0)
      {
          return buffer.Take(written).ToList();
      }

      throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={written}");
  }
  /// <summary>
  /// Detokenize a list of tokens.
  /// </summary>
  /// <param name="tokens">The list of tokens to detokenize.</param>
  /// <returns>The concatenation of the text of every token, in order.</returns>
  public string DeTokenize(IEnumerable<llama_token> tokens)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      // A StringBuilder keeps this linear in the output size. Completion code calls this
      // with the whole token list after every sampled token, so repeated string
      // concatenation here would be quadratic.
      var output = new StringBuilder();
      foreach (var token in tokens)
      {
          // Appending a null string is a no-op, matching the previous "+=" behavior.
          output.Append(Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)));
      }
      return output.ToString();
  }
  /// <summary>
  /// Detokenize a single token.
  /// </summary>
  /// <param name="token">The token to convert to text.</param>
  /// <returns>The text of the token, or an empty string when the conversion yields null.</returns>
  public string DeTokenize(llama_token token)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      var piece = Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
      if (piece is null)
      {
          return "";
      }
      return piece;
  }
  /// <summary>
  /// Replace the state cache used by this instance.
  /// </summary>
  /// <param name="cache">The cache to use, or null to stop using one.</param>
  public void SetCache(LLamaCache? cache) => _cache = cache;
  /// <summary>
  /// Forget the evaluation history: empties both the stored logit rows and the evaluated tokens.
  /// </summary>
  public void Reset()
  {
      _eval_logits.Clear();
      _eval_tokens.Clear();
  }
  /// <summary>
  /// Evaluate a list of tokens.
  /// </summary>
  /// <remarks>
  /// The tokens are submitted to llama_eval in slices of at most the configured batch size.
  /// After each slice the tokens are appended to the evaluated-token history and the logits
  /// are copied out of the native context (one row per token when logits_all is set,
  /// otherwise a single row).
  /// </remarks>
  /// <param name="tokens">The list of tokens to evaluate.</param>
  /// <exception cref="RuntimeError">If llama_eval returns a non-zero code, or the batch size is not positive.</exception>
  public unsafe void Eval(List<llama_token> tokens)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      // A non-positive batch size would never advance the loop below.
      if (tokens.Count > 0 && _n_batch <= 0)
      {
          throw new RuntimeError($"Invalid batch size: {_n_batch}");
      }

      var n_ctx = NativeApi.llama_n_ctx(_ctx);
      // The vocabulary size does not change between batches, so query it once.
      llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
      var cols = n_vocab;

      for (int i = 0; i < tokens.Count; i += _n_batch)
      {
          // Materialize the slice once instead of re-running a lazy Take/Skip query
          // for every Count(), ToArray() and foreach.
          llama_token n_tokens = Math.Min(_n_batch, tokens.Count - i);
          llama_token[] batch = tokens.GetRange(i, n_tokens).ToArray();
          llama_token n_past = Math.Min(n_ctx - n_tokens, _eval_tokens.Count);

          llama_token return_code = NativeApi.llama_eval(
              ctx: _ctx,
              tokens: batch,
              n_tokens: n_tokens,
              n_past: n_past,
              n_threads: _n_threads
          );
          if (return_code != 0)
          {
              throw new RuntimeError($"llama_eval returned {return_code}");
          }

          foreach (var b in batch)
          {
              _eval_tokens.Enqueue(b);
          }

          // Copy the native logits into managed rows.
          int rows = _params.logits_all ? n_tokens : 1;
          var logits_view = NativeApi.llama_get_logits(_ctx);
          for (int j = 0; j < rows; j++)
          {
              float[] logit = new float[cols];
              for (int k = 0; k < cols; k++)
              {
                  logit[k] = logits_view[j * cols + k];
              }
              _eval_logits.Enqueue(logit);
          }
      }
  }
  /// <summary>
  /// Pick the next token from the most recently stored logit row.
  /// </summary>
  /// <remarks>
  /// The repetition penalty is always applied. A temperature of exactly zero selects the
  /// greedy sampler; any other value runs top-k, tail-free, typical, top-p and temperature
  /// before sampling. <paramref name="frequency_penalty"/> and <paramref name="presence_penalty"/>
  /// are accepted but not applied by this method.
  /// </remarks>
  /// <param name="last_n_tokens_data">Recent tokens used by the repetition penalty.</param>
  /// <param name="last_n_tokens_size">Number of entries of <paramref name="last_n_tokens_data"/> to consider.</param>
  /// <param name="top_k">The top-k sampling parameter.</param>
  /// <param name="top_p">The top-p sampling parameter.</param>
  /// <param name="temp">The temperature parameter.</param>
  /// <param name="repeat_penalty">The repeat penalty parameter.</param>
  /// <param name="frequency_penalty">Currently unused.</param>
  /// <param name="presence_penalty">Currently unused.</param>
  /// <returns>The sampled token.</returns>
  private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
      float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
      Debug.Assert(_eval_logits.Count > 0);

      llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
      var latestRow = _eval_logits.Last();

      // One candidate per vocabulary entry, carrying its logit and a zero third field.
      var entries = new LLamaTokenData[n_vocab];
      for (int id = 0; id < n_vocab; id++)
      {
          entries[id] = new LLamaTokenData(id, latestRow[id], .0f);
      }
      LLamaTokenDataArray candidates = new(entries, (ulong)n_vocab, false);

      SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
          repeat_penalty);

      if (temp != .0f)
      {
          SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
          SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
          SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
          SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
          SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
          return SamplingApi.llama_sample_token(_ctx, candidates);
      }

      return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
  }
  /// <summary>
  /// Sample the next token using the evaluation history of this instance.
  /// </summary>
  /// <remarks>
  /// Builds a fixed-size window of the most recently evaluated tokens, left-padded with
  /// zeros when fewer tokens than the window size have been evaluated, and hands it to the
  /// internal sampler.
  /// </remarks>
  /// <param name="top_k">The top-k sampling parameter.</param>
  /// <param name="top_p">The top-p sampling parameter.</param>
  /// <param name="temp">The temperature parameter.</param>
  /// <param name="repeat_penalty">The repeat penalty parameter.</param>
  /// <param name="frequency_penalty">Forwarded to the internal sampler.</param>
  /// <param name="presence_penalty">Forwarded to the internal sampler.</param>
  /// <returns>The sampled token.</returns>
  public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty, float frequency_penalty = .0f,
      float presence_penalty = .0f)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      // A freshly allocated array is already zero-filled, which provides the left padding.
      llama_token[] window = new llama_token[_last_n_tokens_size];

      int evaluated = _eval_tokens.Count;
      int kept = Math.Min(evaluated, _last_n_tokens_size);
      int writeAt = _last_n_tokens_size - kept;

      // Copy the newest "kept" tokens into the tail of the window, oldest first.
      foreach (var recent in _eval_tokens.Skip(evaluated - kept))
      {
          window[writeAt] = recent;
          writeAt++;
      }

      return SampleInternal(window, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
  }
  /// <summary>
  /// Create a generator of tokens from a prompt.
  /// </summary>
  /// <remarks>
  /// The returned sequence never ends on its own: the caller decides when to stop
  /// enumerating (for example on an end-of-sequence token or a length limit).
  /// When <paramref name="reset"/> is true and the start of <paramref name="tokens"/> matches
  /// the already evaluated tokens, the matching prefix is reused and only the rest is evaluated.
  /// </remarks>
  /// <example>
  /// <code>
  /// var llama = new LLamaModelV1("models/ggml-7b.bin");
  /// var tokens = llama.Tokenize("Hello, world!");
  /// foreach (var token in llama.Generate(tokens, top_k: 40, top_p: 0.95f, temp: 1.0f, repeat_penalty: 1.1f))
  /// {
  ///     Console.WriteLine(llama.DeTokenize(new[] { token }));
  /// }
  /// </code>
  /// </example>
  /// <param name="tokens">The prompt tokens.</param>
  /// <param name="top_k">The top-k sampling parameter.</param>
  /// <param name="top_p">The top-p sampling parameter.</param>
  /// <param name="temp">The temperature parameter.</param>
  /// <param name="repeat_penalty">The repeat penalty parameter.</param>
  /// <param name="frequency_penalty">Forwarded to <see cref="Sample"/>.</param>
  /// <param name="presence_penalty">Forwarded to <see cref="Sample"/>.</param>
  /// <param name="reset">Whether to reset the evaluation history (unless a prefix match is found).</param>
  /// <returns>A lazy, unbounded sequence of sampled tokens.</returns>
  public IEnumerable<llama_token> Generate(IEnumerable<llama_token> tokens, int top_k, float top_p, float temp,
      float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      // Materialize once: the input may be a lazy or single-use sequence.
      List<llama_token> pending = tokens.ToList();

      if (reset && _eval_tokens.Count > 0)
      {
          // The last prompt token is excluded so that at least one token is always evaluated.
          int longest_prefix = LongestTokenPrefix(_eval_tokens, pending.Take(pending.Count - 1));
          if (longest_prefix > 0)
          {
              if (_verbose)
              {
                  Logger.Default.Info("Llama.generate: prefix-match hit");
              }
              reset = false;
              pending.RemoveRange(0, longest_prefix);

              // Keep the matching prefix and drop everything evaluated after it.
              // Queue.Dequeue removes from the front, which would discard the prefix
              // itself, so the queues are rebuilt from their leading entries instead.
              int stale = _eval_tokens.Count - longest_prefix;
              _eval_tokens = new Queue<llama_token>(_eval_tokens.Take(longest_prefix));
              _eval_logits = new Queue<float[]>(_eval_logits.Take(Math.Max(0, _eval_logits.Count - stale)));
          }
      }

      if (reset)
      {
          Reset();
      }

      while (true)
      {
          Eval(pending);
          // Argument order follows Sample's signature: repeat_penalty comes before the
          // frequency and presence penalties.
          var token = Sample(top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
          yield return token;
          // Only the newly sampled token still needs evaluating on the next step;
          // the prompt is already part of the evaluation history.
          pending = new List<llama_token> { token };
      }
  }
  /// <summary>
  /// Compute the embedding of a string.
  /// </summary>
  /// <remarks>
  /// Resets the evaluation history, evaluates the tokenized input and copies the embedding
  /// vector out of the native context.
  /// </remarks>
  /// <param name="input">The text to embed.</param>
  /// <returns>An embedding object holding one embedding entry and the token usage.</returns>
  /// <exception cref="RuntimeError">If the model was not created in embedding mode.</exception>
  public unsafe Embedding CreateEmbedding(string input)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
      if (!_params.embedding)
      {
          throw new RuntimeError("Llama model must be created with embedding=True to call this method");
      }

      if (_verbose)
      {
          NativeApi.llama_reset_timings(_ctx);
      }

      var inputTokens = Tokenize(input);
      Reset();
      Eval(inputTokens);
      int tokenCount = inputTokens.Count;

      // Copy the native embedding buffer element by element into a managed array.
      var nativeVector = NativeApi.llama_get_embeddings(_ctx);
      int dimension = NativeApi.llama_n_embd(_ctx);
      float[] vector = new float[dimension];
      for (int idx = 0; idx < dimension; idx++)
      {
          vector[idx] = nativeVector[idx];
      }

      if (_verbose)
      {
          NativeApi.llama_print_timings(_ctx);
      }

      var entries = new[] { new EmbeddingData(0, "embedding", vector) };
      var usage = new EmbeddingUsage(tokenCount, tokenCount);
      return new Embedding("list", _model_path, entries, usage);
  }
  /// <summary>
  /// Compute the embedding vector of a string.
  /// </summary>
  /// <param name="input">The text to embed.</param>
  /// <returns>The vector of the first entry returned by <see cref="CreateEmbedding"/>.</returns>
  public float[] Embed(string input)
  {
      var result = CreateEmbedding(input);
      return result.Data[0].Embedding;
  }
  /// <summary>
  /// Stream a text completion for a prompt, one chunk per generated step.
  /// </summary>
  /// <remarks>
  /// Generation ends when the end-of-sequence token is sampled, when a stop string appears
  /// in the generated text, or when <paramref name="max_tokens"/> tokens have been generated.
  /// Text that could be the beginning of a stop string is held back until it is known not
  /// to be one. A final chunk carrying any remaining text and the finish reason
  /// ("stop" or "length") is always emitted last.
  /// <paramref name="suffix"/> and <paramref name="echo"/> do not affect the streamed chunks,
  /// and <paramref name="logprobs"/> is only validated; log-probabilities are not part of the stream.
  /// </remarks>
  /// <param name="prompt">The prompt to generate text from.</param>
  /// <param name="suffix">Not used by the streamed output.</param>
  /// <param name="max_tokens">The maximum number of tokens to generate.</param>
  /// <param name="temperature">The temperature to use for sampling.</param>
  /// <param name="top_p">The top-p value to use for sampling.</param>
  /// <param name="logprobs">-1 to disable; any other value requires a model created with logits_all.</param>
  /// <param name="echo">Not used by the streamed output.</param>
  /// <param name="stop">Strings that end generation when they appear in the generated text.</param>
  /// <param name="frequency_penalty">Forwarded to <see cref="Sample"/>.</param>
  /// <param name="presence_penalty">Forwarded to <see cref="Sample"/>.</param>
  /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
  /// <param name="top_k">The top-k value to use for sampling.</param>
  /// <returns>A lazy sequence of completion chunks.</returns>
  /// <exception cref="ArgumentException">If prompt plus <paramref name="max_tokens"/> exceeds the context
  /// window, or log-probabilities are requested from a model created without logits_all.</exception>
  private IEnumerable<CompletionChunk> CreateCompletionInternal(string prompt, string? suffix = null, int max_tokens = 16, float temperature = 0.8f,
      float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
      float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
      string completionId = $"cmpl-{Guid.NewGuid()}";
      // Creation time as a Unix timestamp in seconds. (DateTime.Now.Millisecond is only the
      // 0-999 millisecond component of the current time, not a timestamp.)
      var created = (int)DateTimeOffset.UtcNow.ToUnixTimeSeconds();
      List<llama_token> completionTokens = new List<llama_token>();
      var promptTokens = Tokenize($" {prompt}");
      string text = "";
      int returnedCharacters = 0;
      stop ??= Array.Empty<string>();

      if (_verbose)
      {
          NativeApi.llama_reset_timings(_ctx);
      }

      int contextSize = NativeApi.llama_n_ctx(_ctx);
      if (promptTokens.Count + max_tokens > contextSize)
      {
          throw new ArgumentException($"Requested tokens exceed context window of {contextSize}");
      }

      if (logprobs != -1 && !_params.logits_all)
      {
          throw new ArgumentException("logprobs is not supported for models created with logits_all=False");
      }

      if (_cache is not null)
      {
          try
          {
              // TODO(Rinne): revise it since it will compare reference instead of elements.
              var cacheItem = _cache[promptTokens.ToArray()];
              // Compare how much of the prompt the cached state covers against how much the
              // live state covers; only a strictly better cached state is worth loading.
              var cachePrefixLen = LongestTokenPrefix(cacheItem.EvalTokens, promptTokens);
              var evalPrefixLen = LongestTokenPrefix(_eval_tokens, promptTokens);
              if (cachePrefixLen > evalPrefixLen)
              {
                  LoadState(cacheItem);
                  if (_verbose)
                  {
                      Logger.Default.Info("Llama._create_completion: cache hit");
                  }
              }
          }
          catch (KeyNotFoundException)
          {
              if (_verbose)
              {
                  Logger.Default.Warn("Llama._create_completion: cache miss");
              }
          }
      }

      string finishReason = "length";
      int multibyteFix = 0;
      bool reset = true;
      List<llama_token> tokens = new(promptTokens);

      if (_eval_tokens.Count > 0)
      {
          // The last prompt token is excluded so that at least one token is always evaluated.
          int longest_prefix = LongestTokenPrefix(_eval_tokens, tokens.Take(tokens.Count - 1));
          if (longest_prefix > 0)
          {
              if (_verbose)
              {
                  Logger.Default.Info("Llama.generate: prefix-match hit");
              }
              reset = false;
              tokens.RemoveRange(0, longest_prefix);

              // Keep the matching prefix and drop everything evaluated after it.
              // Queue.Dequeue removes from the front, which would discard the prefix
              // itself, so the queues are rebuilt from their leading entries instead.
              int stale = _eval_tokens.Count - longest_prefix;
              _eval_tokens = new Queue<llama_token>(_eval_tokens.Take(longest_prefix));
              _eval_logits = new Queue<float[]>(_eval_logits.Take(Math.Max(0, _eval_logits.Count - stale)));
          }
      }

      if (reset)
      {
          Reset();
      }

      while (true)
      {
          Eval(tokens);
          var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
          // Only the newly sampled token needs evaluating on the next step.
          tokens.Clear();
          tokens.Add(token);

          if (token == NativeApi.llama_token_eos())
          {
              text = DeTokenize(completionTokens);
              finishReason = "stop";
              break;
          }

          completionTokens.Add(token);
          string allText = DeTokenize(completionTokens);

          // Inspect the last (up to) three characters; k is the distance of the character
          // from the end of the text (1 = last character).
          // NOTE(review): this check was designed for raw UTF-8 bytes but runs on UTF-16
          // chars here, so it can also trigger on ordinary non-ASCII characters — confirm.
          int cut = Math.Min(3, allText.Length);
          int tailStart = allText.Length - cut;
          for (int i = tailStart; i < allText.Length; i++)
          {
              int c = allText[i];
              int k = cut - (i - tailStart);
              foreach (var (num, pattern) in _numAndPatterns)
              {
                  if (num > k && (pattern & c) == pattern)
                  {
                      multibyteFix = num - k;
                  }
              }
          }

          if (multibyteFix > 0)
          {
              // Hold this step back; nothing is emitted for it.
              multibyteFix--;
          }
          else
          {
              // Ordinal comparison throughout, so Contains, IndexOf and EndsWith agree.
              var firstStop = stop.FirstOrDefault(s => allText.Contains(s));
              if (firstStop is not null)
              {
                  text = allText.Substring(0, allText.IndexOf(firstStop, StringComparison.Ordinal));
                  finishReason = "stop";
                  break;
              }

              // Hold back the longest tail of the text that could still grow into a stop string.
              int longest = 0;
              foreach (var s in stop)
              {
                  for (int i = s.Length; i > 0; i--)
                  {
                      if (allText.EndsWith(s.Substring(0, i), StringComparison.Ordinal))
                      {
                          if (i > longest)
                          {
                              longest = i;
                          }
                          break;
                      }
                  }
              }

              text = allText.Substring(0, allText.Length - longest);
              int emitFrom = Math.Min(returnedCharacters, text.Length);
              returnedCharacters = Math.Max(returnedCharacters, text.Length);
              yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
              {
                  new CompletionChoice(text.Substring(emitFrom), 0, null, finishReason)
              });
          }

          // Enforce the requested length limit.
          if (completionTokens.Count >= max_tokens)
          {
              text = DeTokenize(completionTokens);
              finishReason = "length";
              break;
          }
      }

      if (_cache is not null)
      {
          if (_verbose)
          {
              Logger.Default.Info("Llama._create_completion: cache save");
          }
          _cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
      }

      // Final chunk: whatever text has not been emitted yet, together with the real finish reason.
      int finalFrom = Math.Min(returnedCharacters, text.Length);
      yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
      {
          new CompletionChoice(text.Substring(finalFrom), 0, null, finishReason)
      });

      if (_verbose)
      {
          NativeApi.llama_print_timings(_ctx);
      }
  }
  /// <summary>
  /// Stream a text completion for a prompt.
  /// </summary>
  /// <remarks>Forwards every argument unchanged to the internal completion routine.</remarks>
  /// <param name="prompt">The prompt to generate text from.</param>
  /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
  /// <param name="max_tokens">The maximum number of tokens to generate.</param>
  /// <param name="temperature">The temperature to use for sampling.</param>
  /// <param name="top_p">The top-p value to use for sampling.</param>
  /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
  /// <param name="echo">Whether to echo the prompt.</param>
  /// <param name="stop">A list of strings to stop generation when encountered.</param>
  /// <param name="frequency_penalty">Forwarded to the sampler.</param>
  /// <param name="presence_penalty">Forwarded to the sampler.</param>
  /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
  /// <param name="top_k">The top-k value to use for sampling.</param>
  /// <returns>A lazy sequence of completion chunks.</returns>
  public IEnumerable<CompletionChunk> CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
      float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
      float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
      => CreateCompletionInternal(
          prompt: prompt,
          suffix: suffix,
          max_tokens: max_tokens,
          temperature: temperature,
          top_p: top_p,
          logprobs: logprobs,
          echo: echo,
          stop: stop,
          frequency_penalty: frequency_penalty,
          presence_penalty: presence_penalty,
          repeat_penalty: repeat_penalty,
          top_k: top_k);
  /// <summary>
  /// Stream a text completion for a prompt; an alias of <see cref="CreateCompletion"/>.
  /// </summary>
  /// <param name="prompt">The prompt to generate text from.</param>
  /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
  /// <param name="max_tokens">The maximum number of tokens to generate.</param>
  /// <param name="temperature">The temperature to use for sampling.</param>
  /// <param name="top_p">The top-p value to use for sampling.</param>
  /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
  /// <param name="echo">Whether to echo the prompt.</param>
  /// <param name="stop">A list of strings to stop generation when encountered.</param>
  /// <param name="frequency_penalty">Forwarded to the sampler.</param>
  /// <param name="presence_penalty">Forwarded to the sampler.</param>
  /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
  /// <param name="top_k">The top-k value to use for sampling.</param>
  /// <returns>A lazy sequence of completion chunks.</returns>
  public IEnumerable<CompletionChunk> Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
      float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
      float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
      => CreateCompletion(
          prompt: prompt,
          suffix: suffix,
          max_tokens: max_tokens,
          temperature: temperature,
          top_p: top_p,
          logprobs: logprobs,
          echo: echo,
          stop: stop,
          frequency_penalty: frequency_penalty,
          presence_penalty: presence_penalty,
          repeat_penalty: repeat_penalty,
          top_k: top_k);
  /// <summary>
  /// Wrap a finished text completion as a chat completion with a single assistant message.
  /// </summary>
  /// <param name="completion">The text completion; only its first choice is used.</param>
  /// <returns>A chat completion whose id is the original id prefixed with "chat".</returns>
  private ChatCompletion ConvertTextCompletionToChat(Completion completion)
  {
      string chatId = $"chat{completion.Id}";
      var reply = new ChatCompletionMessage(ChatRole.Assistant, completion.Choices[0].Text);
      var choices = new[] { new ChatCompletionChoice(0, reply, completion.Choices[0].FinishReason) };
      return new ChatCompletion(chatId, "chat.completion", completion.Created, completion.Model, choices, completion.Usage);
  }
  /// <summary>
  /// Translate streamed text-completion chunks into chat-completion chunks.
  /// </summary>
  /// <remarks>
  /// Before the first content chunk, an extra chunk is emitted whose delta carries only the
  /// "assistant" role. Every input chunk then produces one chunk carrying the text and
  /// finish reason of its first choice.
  /// </remarks>
  /// <param name="chunks">The text-completion chunks to convert.</param>
  /// <returns>A lazy sequence of chat-completion chunks.</returns>
  private IEnumerable<ChatCompletionChunk> ConvertTextCompletionChunksToChat(IEnumerable<CompletionChunk> chunks)
  {
      bool roleAnnounced = false;
      foreach (var chunk in chunks)
      {
          if (!roleAnnounced)
          {
              var roleDelta = new ChatCompletionChunkDelta("assistant", null);
              yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
                  new[] { new ChatCompletionChunkChoice(0, roleDelta, null) });
              roleAnnounced = true;
          }

          var contentDelta = new ChatCompletionChunkDelta(null, chunk.Choices[0].Text);
          yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
              new[] { new ChatCompletionChunkChoice(0, contentDelta, chunk.Choices[0].FinishReason) });
      }
  }
  /// <summary>
  /// Stream a chat completion for a conversation.
  /// </summary>
  /// <remarks>
  /// The conversation is flattened into a single prompt of "### Human:" / "### Assistant:"
  /// turns ending with an open assistant turn, and both markers are added to the stop strings.
  /// NOTE(review): the prompt is cut to its last <paramref name="max_tokens"/> characters,
  /// which mixes a character count with a token count — confirm this is intended.
  /// </remarks>
  /// <param name="messages">A list of messages to generate a response for.</param>
  /// <param name="temperature">The temperature to use for sampling.</param>
  /// <param name="top_p">The top-p value to use for sampling.</param>
  /// <param name="top_k">The top-k value to use for sampling.</param>
  /// <param name="stop">A list of strings to stop generation when encountered.</param>
  /// <param name="max_tokens">The maximum number of tokens to generate.</param>
  /// <param name="presence_penalty">Forwarded to the completion call.</param>
  /// <param name="frequency_penalty">Forwarded to the completion call.</param>
  /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
  /// <returns>A lazy sequence of chat-completion chunks.</returns>
  public IEnumerable<ChatCompletionChunk> CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
      int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
      float repeat_penalty = 1.1f)
  {
      stop ??= new string[0];

      // Every message that is not from the human is rendered as an assistant turn.
      var turns = messages.Select(m =>
      {
          string speaker = message_role(m);
          return $"### {speaker}:{m.Content}";
      });
      var prompt = string.Concat(turns) + "### Assistant:";

      // Keep only the trailing max_tokens characters of the prompt.
      if (prompt.Length > max_tokens)
      {
          prompt = prompt.Substring(prompt.Length - max_tokens);
      }

      var markers = new[] { "### Assistant:", "### Human:" };
      var promptStop = markers.Concat(stop).ToArray();

      var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
          repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
      return ConvertTextCompletionChunksToChat(completion);

      // Maps a message to the speaker label used in the prompt.
      string message_role(ChatCompletionMessage message)
      {
          if (message.Role == ChatRole.Human)
          {
              return "Human";
          }
          return "Assistant";
      }
  }
  /// <summary>
  /// Capture the current state: the native context state plus copies of the
  /// evaluated-token and logit queues.
  /// </summary>
  /// <returns>A state object that can later be passed to <see cref="LoadState"/>.</returns>
  /// <exception cref="RuntimeError">If the native library reports more bytes written than the buffer size it announced.</exception>
  public LLamaState SaveState()
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      ulong stateSize = NativeApi.llama_get_state_size(_ctx);
      byte[] llamaState = new byte[stateSize];
      ulong nBytes = NativeApi.llama_copy_state_data(_ctx, llamaState);
      if (nBytes > stateSize)
      {
          throw new RuntimeError("Failed to copy llama state data");
      }

      // Trim the buffer to the bytes actually written. A single Array.Copy replaces the
      // LINQ Take/ToArray/CopyTo chain, which made two extra full-size copies; when the
      // buffer is already exactly full it is reused as-is.
      byte[] llamaStateCompact = llamaState;
      if (nBytes != stateSize)
      {
          llamaStateCompact = new byte[nBytes];
          Array.Copy(llamaState, llamaStateCompact, (long)nBytes);
      }

      if (_verbose)
      {
          Logger.Default.Info($"Llama.save_state: saving {nBytes} bytes of llama state");
      }

      return new LLamaState(new Queue<llama_token>(_eval_tokens), new Queue<float[]>(_eval_logits),
          llamaStateCompact, (int)nBytes);
  }
  /// <summary>
  /// Restore a state previously produced by <see cref="SaveState"/>.
  /// </summary>
  /// <remarks>
  /// The token and logit queues are replaced with copies of the saved ones before the native
  /// state is written, so they remain replaced even if the native call fails.
  /// </remarks>
  /// <param name="state">The state to restore.</param>
  /// <exception cref="RuntimeError">If the native library does not report exactly the saved size.</exception>
  public void LoadState(LLamaState state)
  {
      Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

      _eval_tokens = new Queue<llama_token>(state.EvalTokens);
      _eval_logits = new Queue<float[]>(state.EvalLogits);

      var bytesRead = NativeApi.llama_set_state_data(_ctx, state.State);
      var expected = (ulong)state.Size;
      if (bytesRead == expected)
      {
          return;
      }

      throw new RuntimeError($"Failed to set llama state data");
  }
  /// <summary>
  /// Convert raw logits into log-probabilities (log-softmax).
  /// </summary>
  /// <remarks>
  /// Uses the log-sum-exp formulation with the maximum logit subtracted first, so large
  /// logits do not overflow to infinity and produce NaN. The input is enumerated exactly once.
  /// </remarks>
  /// <param name="logits">The raw logits.</param>
  /// <returns>One log-probability per input logit, in the same order; empty for empty input.</returns>
  private static IEnumerable<float> LogitsToLogprobs(IEnumerable<float> logits)
  {
      float[] values = logits.ToArray();
      if (values.Length == 0)
      {
          return Array.Empty<float>();
      }

      // log(sum(exp(x))) == max + log(sum(exp(x - max))); every exponent is <= 0.
      double max = values.Max();
      double sum = 0.0;
      foreach (float value in values)
      {
          sum += Math.Exp(value - max);
      }
      double logSumExp = max + Math.Log(sum);

      var logprobs = new float[values.Length];
      for (int i = 0; i < values.Length; i++)
      {
          logprobs[i] = (float)(values[i] - logSumExp);
      }
      return logprobs;
  }
  /// <summary>
  /// Count how many leading tokens two sequences have in common.
  /// </summary>
  /// <param name="a">The first token sequence.</param>
  /// <param name="b">The second token sequence.</param>
  /// <returns>The number of positions, starting at the first, at which both sequences hold the
  /// same token before the first difference or the end of the shorter sequence.</returns>
  internal static int LongestTokenPrefix(IEnumerable<llama_token> a, IEnumerable<llama_token> b)
  {
      // Pair the sequences position by position and count matches until the first mismatch.
      return a.Zip(b, (left, right) => left == right)
              .TakeWhile(same => same)
              .Count();
  }
  766. }
  767. }

C#/.NET上易用的LLM高性能推理框架,支持LLaMA和LLaVA系列模型。