
LLamaModel.cs 42 kB

using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using LLama.Types;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// High-level wrapper of a llama.cpp model for inference.
    /// </summary>
    public class LLamaModel
    {
        private string _model_path;
        LLamaContextParams _params;
        private int _n_threads;
        private int _n_batch;
        private int _last_n_tokens_size;
        private string? _lora_base;
        private string? _lora_path;
        private bool _verbose;
        private Queue<llama_token> _eval_tokens;
        private Queue<float[]> _eval_logits;
        private LLamaCache? _cache;
        private SafeLLamaContextHandle _ctx;
        private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };
        /// <summary>
        /// Load a llama.cpp model from the path.
        /// </summary>
        /// <remarks>Note that the API is still unstable and the order of the parameters is likely to
        /// change in the future, so it is recommended to specify parameter names when
        /// building your app. We use the cpp-style parameter names here because it makes
        /// searching the llama.cpp docs more convenient.</remarks>
        /// <param name="model_path">Path to the model.</param>
        /// <param name="n_ctx">Maximum context size.</param>
        /// <param name="n_parts">Number of parts to split the model into. If -1, the number of parts is determined automatically.</param>
        /// <param name="seed">Random seed. 0 for random.</param>
        /// <param name="f16_kv">Use half-precision for the key/value cache.</param>
        /// <param name="logits_all">Return logits for all tokens, not just the last token.</param>
        /// <param name="vocab_only">Only load the vocabulary, no weights.</param>
        /// <param name="use_mmap">Use mmap if possible.</param>
        /// <param name="use_mlock">Force the system to keep the model in RAM.</param>
        /// <param name="embedding">Embedding mode only.</param>
        /// <param name="n_threads">Number of threads to use. If not specified, the number of threads is determined automatically.</param>
        /// <param name="n_batch">Maximum number of prompt tokens to batch together when calling llama_eval.</param>
        /// <param name="last_n_tokens_size">Maximum number of tokens to keep in the last_n_tokens deque.</param>
        /// <param name="lora_base">Optional path to a base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
        /// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
        /// <param name="verbose">Print verbose output to stderr.</param>
        public LLamaModel(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
            bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
            bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
            int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
        {
            _verbose = verbose;
            _model_path = model_path;
            _params = NativeApi.llama_context_default_params();
            _params.n_ctx = n_ctx;
            _params.n_parts = n_parts;
            _params.seed = seed;
            _params.f16_kv = f16_kv;
            _params.logits_all = logits_all;
            _params.vocab_only = vocab_only;
            _params.use_mmap = lora_path is null ? use_mmap : false;
            _params.use_mlock = use_mlock;
            _params.embedding = embedding;
            _last_n_tokens_size = last_n_tokens_size;
            _n_batch = Math.Min(n_ctx, n_batch);
            _eval_tokens = new Queue<llama_token>(capacity: n_ctx);
            _eval_logits = new Queue<float[]>(logits_all ? n_ctx : 1);
            _cache = null;
            _n_threads = n_threads;
            if (_n_threads == -1)
            {
                _n_threads = Math.Max(Environment.ProcessorCount / 2, 1);
            }
            _lora_base = lora_base;
            _lora_path = lora_path;
            if (!File.Exists(model_path) && !Directory.Exists(model_path))
            {
                throw new FileNotFoundException($"Model path does not exist: {model_path}");
            }
            // Round-trip the path through UTF-8 bytes before handing it to the native call.
            _ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(Encoding.UTF8.GetString(Encoding.UTF8.GetBytes(model_path)), _params));
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            if (_lora_path is not null)
            {
                if (NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
                {
                    throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
                }
            }
            if (_verbose)
            {
#if NET6_0_OR_GREATER
                Logger.Default.Info(Marshal.PtrToStringUTF8(NativeApi.llama_print_system_info()));
#endif
            }
        }
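
        // Illustrative usage sketch (not part of the class): constructing a model with named
        // parameters, as the remarks above recommend. The model path and parameter values are
        // placeholders chosen for demonstration.
        //
        //     var model = new LLamaModel("models/ggml-model-q4_0.bin", n_ctx: 2048, seed: 42,
        //         n_batch: 256, n_threads: 8, verbose: false);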
        public LLamaModel(LLamaModel other)
        {
            _ctx = other._ctx;
            _model_path = other._model_path;
            _params = other._params;
            _last_n_tokens_size = other._last_n_tokens_size;
            _n_threads = other._n_threads;
            _n_batch = other._n_batch;
            _verbose = other._verbose;
            _lora_base = other._lora_base;
            _lora_path = other._lora_path;
            _eval_logits = new Queue<float[]>(other._eval_logits);
            _eval_tokens = new Queue<llama_token>(other._eval_tokens);
        }
        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The utf-8 encoded string to tokenize.</param>
        /// <returns>A list of tokens.</returns>
        /// <exception cref="RuntimeError">If the tokenization failed.</exception>
        public List<llama_token> Tokenize(string text)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            var tokens = new llama_token[n_ctx];
            var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
            if (n_tokens < 0)
            {
                throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
            }
            return tokens.Take(n_tokens).ToList();
        }
        /// <summary>
        /// Detokenize a list of tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens to detokenize.</param>
        /// <returns>The detokenized string.</returns>
        public string DeTokenize(IEnumerable<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            string output = "";
            foreach (var token in tokens)
            {
#if NET6_0_OR_GREATER
                output += Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
#else
                output += Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, token));
#endif
            }
            return output;
        }
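
        // Illustrative round-trip sketch (assumes a "model" instance constructed as above):
        // tokenize a prompt and turn the token ids back into text.
        //
        //     var ids = model.Tokenize("Hello, world!");   // llama_tokenize adds the leading BOS token
        //     string text = model.DeTokenize(ids);         // should reproduce the original text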
        /// <summary>
        /// Set the cache.
        /// </summary>
        /// <param name="cache">The cache to set.</param>
        public void SetCache(LLamaCache? cache)
        {
            _cache = cache;
        }

        /// <summary>
        /// Reset the model state.
        /// </summary>
        public void Reset()
        {
            _eval_tokens.Clear();
            _eval_logits.Clear();
        }
        /// <summary>
        /// Evaluate a list of tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens to evaluate.</param>
        /// <exception cref="RuntimeError"></exception>
        public unsafe void Eval(List<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            for (int i = 0; i < tokens.Count; i += _n_batch)
            {
                var batch = tokens.Take(Math.Min(tokens.Count, i + _n_batch)).Skip(i);
                llama_token n_past = Math.Min(n_ctx - batch.Count(), _eval_tokens.Count);
                llama_token n_tokens = batch.Count();
                llama_token return_code = NativeApi.llama_eval(
                    ctx: _ctx,
                    tokens: batch.ToArray(),
                    n_tokens: n_tokens,
                    n_past: n_past,
                    n_threads: _n_threads
                );
                if (return_code != 0)
                {
                    throw new RuntimeError($"llama_eval returned {return_code}");
                }
                foreach (var b in batch)
                {
                    _eval_tokens.Enqueue(b);
                }
                int rows = _params.logits_all ? n_tokens : 1;
                llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
                var cols = n_vocab;
                var logits_view = NativeApi.llama_get_logits(_ctx);
                for (int j = 0; j < rows; j++)
                {
                    float[] logit = new float[cols];
                    for (int k = 0; k < cols; k++)
                    {
                        logit[k] = logits_view[j * cols + k];
                    }
                    _eval_logits.Enqueue(logit);
                }
            }
        }
        private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
            float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            Debug.Assert(_eval_logits.Count > 0);
            llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
            var logits = _eval_logits.Last();
            LLamaTokenData[] data = new LLamaTokenData[n_vocab];
            for (int i = 0; i < n_vocab; i++)
            {
                data[i] = new LLamaTokenData(i, logits[i], .0f);
            }
            ulong size = (ulong)n_vocab;
            bool sorted = false;
            LLamaTokenDataArray candidates = new(data, size, sorted);
            SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
                repeat_penalty);
            //SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
            //    frequency_penalty, presence_penalty);
            if (temp == .0f)
            {
                return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
            }
            else
            {
                SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
                SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
                SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
                SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
                SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
                return SamplingApi.llama_sample_token(_ctx, candidates);
            }
        }
        /// <summary>
        /// Sample a token from the model.
        /// </summary>
        /// <param name="top_k">The top-k sampling parameter.</param>
        /// <param name="top_p">The top-p sampling parameter.</param>
        /// <param name="temp">The temperature parameter.</param>
        /// <param name="repeat_penalty">The repeat penalty parameter.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <returns>The sampled token.</returns>
        public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty, float frequency_penalty = .0f,
            float presence_penalty = .0f)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var last_n_tokens_data = Enumerable.Repeat(0, Math.Max(0, _last_n_tokens_size - _eval_tokens.Count));
            last_n_tokens_data = last_n_tokens_data.Concat(_eval_tokens.ToList()
                .Skip(Math.Max(0, _eval_tokens.Count - _last_n_tokens_size)));
            llama_token[] tokens_data = new llama_token[_last_n_tokens_size];
            int i = 0;
            foreach (var data in last_n_tokens_data)
            {
                if (i < _last_n_tokens_size)
                {
                    tokens_data[i++] = data;
                }
                else
                {
                    break;
                }
            }
            return SampleInternal(tokens_data, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
        }
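
        // Illustrative low-level loop sketch (assumes a "model" instance and a tokenized prompt):
        // evaluate the prompt once, then alternate Sample/Eval one token at a time. Generate()
        // below wraps this same pattern; the sampling parameters are arbitrary demo values.
        //
        //     model.Eval(promptTokens);
        //     for (int step = 0; step < 32; step++)
        //     {
        //         var next = model.Sample(top_k: 40, top_p: 0.95f, temp: 0.8f, repeat_penalty: 1.1f);
        //         Console.Write(model.DeTokenize(new[] { next }));
        //         model.Eval(new List<int> { next });
        //     }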
        /// <summary>
        /// Create a generator of tokens from a prompt.
        /// </summary>
        /// <example>
        /// Example:
        /// <code>
        /// var llama = new LLamaModel("models/ggml-7b.bin");
        /// var tokens = llama.Tokenize("Hello, world!");
        /// foreach(var token in llama.Generate(tokens, top_k: 40, top_p: 0.95f, temp: 1.0f, repeat_penalty: 1.1f))
        /// {
        ///     Console.WriteLine(llama.DeTokenize(new[] { token }));
        /// }
        /// </code>
        /// </example>
        /// <param name="tokens"></param>
        /// <param name="top_k"></param>
        /// <param name="top_p"></param>
        /// <param name="temp"></param>
        /// <param name="repeat_penalty"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="reset"></param>
        /// <returns></returns>
        public IEnumerable<llama_token> Generate(IEnumerable<llama_token> tokens, int top_k, float top_p, float temp,
            float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            if (reset && _eval_tokens.Count > 0)
            {
                int longest_prefix = 0;
                foreach (var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count() - 1), (x, y) => (x, y)))
                {
                    if (a == b)
                    {
                        longest_prefix += 1;
                    }
                    else
                    {
                        break;
                    }
                }
                if (longest_prefix > 0)
                {
                    if (_verbose)
                    {
                        Logger.Default.Info("Llama.generate: prefix-match hit");
                    }
                    reset = false;
                    tokens = tokens.Skip(longest_prefix);
                    for (int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
                    {
                        _eval_tokens.Dequeue();
                        if (_eval_logits.Count > 0)
                        {
                            _eval_logits.Dequeue();
                        }
                    }
                }
            }
            if (reset)
            {
                Reset();
            }
            while (true)
            {
                Eval(tokens.ToList());
                // Pass repeat_penalty in its declared position; the arguments previously did not
                // match the Sample(top_k, top_p, temp, repeat_penalty, ...) signature.
                var token = Sample(top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
                yield return token;
                // Continue from the sampled token only (mirrors the loop in CreateCompletionInternal);
                // previously the original prompt tokens were re-evaluated on every iteration.
                tokens = new List<llama_token> { token };
                // TODO(Rinne): verify that this implementation is correct.
            }
        }
        /// <summary>
        /// Embed a string.
        /// </summary>
        /// <param name="input">The utf-8 encoded string to embed.</param>
        /// <returns>An embedding object.</returns>
        /// <exception cref="RuntimeError"></exception>
        public unsafe Embedding CreateEmbedding(string input)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            if (!_params.embedding)
            {
                throw new RuntimeError("Llama model must be created with embedding=true to call this method");
            }
            if (_verbose)
            {
                NativeApi.llama_reset_timings(_ctx);
            }
            var tokens = Tokenize(input);
            Reset();
            Eval(tokens);
            int n_tokens = tokens.Count;
            var embeddingPtr = NativeApi.llama_get_embeddings(_ctx);
            int cnt = NativeApi.llama_n_embd(_ctx);
            float[] embedding = new float[cnt];
            for (int i = 0; i < cnt; i++)
            {
                embedding[i] = embeddingPtr[i];
            }
            if (_verbose)
            {
                NativeApi.llama_print_timings(_ctx);
            }
            return new Embedding("list", _model_path, new[] { new EmbeddingData(0, "embedding", embedding) },
                new EmbeddingUsage(n_tokens, n_tokens));
        }
        public float[] Embed(string input)
        {
            return CreateEmbedding(input).Data[0].Embedding;
        }
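
        // Illustrative embedding sketch (assumes the model was constructed with embedding: true,
        // as CreateEmbedding requires; the model path is a placeholder):
        //
        //     var embedder = new LLamaModel("models/ggml-model-q4_0.bin", embedding: true);
        //     float[] vector = embedder.Embed("The quick brown fox");
        //     Console.WriteLine($"dimension = {vector.Length}");   // llama_n_embd of the model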
        /// <summary>
        /// Create a completion from a prompt (shared implementation behind the streaming and
        /// non-streaming completion methods).
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="suffix"></param>
        /// <param name="max_tokens"></param>
        /// <param name="temperature"></param>
        /// <param name="top_p"></param>
        /// <param name="logprobs"></param>
        /// <param name="echo"></param>
        /// <param name="stop"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty"></param>
        /// <param name="top_k"></param>
        /// <param name="stream"></param>
        /// <returns>IEnumerable of Completion and CompletionChunk</returns>
        /// <exception cref="ArgumentException"></exception>
        private IEnumerable<object> CreateCompletionInternal(string prompt, string? suffix = null, int max_tokens = 16, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40, bool stream = false)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            string completionId = $"cmpl-{Guid.NewGuid()}";
            var created = DateTime.Now.Millisecond;
            List<llama_token> completionTokens = new List<llama_token>();
            var promptTokens = Tokenize($" {prompt}");
            string text = "";
            int returnedCharacters = 0;
            if (stop is null)
            {
                stop = new string[0];
            }
            if (_verbose)
            {
                NativeApi.llama_reset_timings(_ctx);
            }
            if (promptTokens.Count + max_tokens > NativeApi.llama_n_ctx(_ctx))
            {
                throw new ArgumentException($"Requested tokens exceed context window of {NativeApi.llama_n_ctx(_ctx)}");
            }
            if (logprobs != -1 && !_params.logits_all)
            {
                throw new ArgumentException("logprobs is not supported for models created with logits_all=false");
            }
            if (_cache is not null)
            {
                try
                {
                    // TODO(Rinne): revise this; the indexer currently compares references instead of elements.
                    var cacheItem = _cache[promptTokens.ToArray()];
                    // Compare the cached state's tokens (not the live eval tokens) against the prompt.
                    var cachePrefixLen = LongestTokenPrefix(cacheItem.EvalTokens, promptTokens);
                    var evalPrefixLen = LongestTokenPrefix(_eval_tokens.AsEnumerable(), promptTokens);
                    if (cachePrefixLen > evalPrefixLen)
                    {
                        LoadState(cacheItem);
                        if (_verbose)
                        {
                            Logger.Default.Info("Llama._create_completion: cache hit");
                        }
                    }
                }
                catch (KeyNotFoundException)
                {
                    if (_verbose)
                    {
                        Logger.Default.Warn("Llama._create_completion: cache miss");
                    }
                }
            }
            string finishReason = "length";
            int multibyteFix = 0;
            bool reset = true;
            List<llama_token> tokens = new(promptTokens);
            if (reset && _eval_tokens.Count > 0)
            {
                int longest_prefix = 0;
                foreach (var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count - 1), (x, y) => (x, y)))
                {
                    if (a == b)
                    {
                        longest_prefix += 1;
                    }
                    else
                    {
                        break;
                    }
                }
                if (longest_prefix > 0)
                {
                    if (_verbose)
                    {
                        Logger.Default.Info("Llama.generate: prefix-match hit");
                    }
                    reset = false;
                    tokens = tokens.Skip(longest_prefix).ToList();
                    for (int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
                    {
                        _eval_tokens.Dequeue();
                        if (_eval_logits.Count > 0)
                        {
                            _eval_logits.Dequeue();
                        }
                    }
                }
            }
            if (reset)
            {
                Reset();
            }
            //foreach (var token in Generate(promptTokens, top_k, top_p, temperature, frequency_penalty, presence_penalty, repeat_penalty))
            while (true)
            {
                Eval(tokens);
                var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
                tokens.Clear();
                tokens.Add(token);
                if (token == NativeApi.llama_token_eos())
                {
                    text = DeTokenize(completionTokens);
                    finishReason = "stop";
                    break;
                }
                completionTokens.Add(token);
                string allText = DeTokenize(completionTokens);
                int cut = Math.Min(3, allText.Length);
                for (int i = allText.Length - cut; i < allText.Length; i++)
                {
                    var c = (int)allText[i];
                    // k is the distance from the end of the text; the previous "cut - i" was only
                    // correct when the text was at most three characters long.
                    int k = allText.Length - i;
                    foreach (var (num, pattern) in _numAndPatterns)
                    {
                        if (num > k && (pattern & c) == pattern)
                        {
                            multibyteFix = num - k;
                        }
                    }
                }
                if (multibyteFix > 0)
                {
                    multibyteFix--;
                    continue;
                }
                var anyStop = stop.Where(s => allText.Contains(s));
                if (anyStop.Count() > 0)
                {
                    var firstStop = anyStop.First();
                    text = allText.Substring(0, allText.IndexOf(firstStop));
                    finishReason = "stop";
                    break;
                }
                if (stream)
                {
                    var start = returnedCharacters;
                    int longest = 0;
                    foreach (var s in stop)
                    {
                        for (int i = s.Length; i > 0; i--)
                        {
                            if (allText.EndsWith(s.Substring(0, i)))
                            {
                                if (i > longest)
                                {
                                    longest = i;
                                }
                                break;
                            }
                        }
                    }
                    text = allText.Substring(0, allText.Length - longest);
                    returnedCharacters += text.Skip(start).Count();
                    // Yield only the text that has not been returned yet (from "start" onwards);
                    // using returnedCharacters here would always yield an empty string because it
                    // was just advanced to the end of "text".
                    yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
                    {
                        new CompletionChoice(text.Substring(start), 0, null, finishReason)
                    });
                }
            }
            if (_cache is not null)
            {
                if (_verbose)
                {
                    Logger.Default.Info("Llama._create_completion: cache save");
                }
                _cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
            }
            if (stream)
            {
                yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
                {
                    new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
                });
                // Streaming consumers receive only chunks; stop here instead of also yielding the
                // final Completion object below.
                yield break;
            }
            string textStr = text;
            if (echo)
            {
                textStr = prompt + textStr;
            }
            if (suffix is not null)
            {
                textStr = textStr + suffix;
            }
            CompletionLogprobs? logProbs = null;
            if (logprobs != -1)
            {
                int textOffset = 0;
                List<int> textOffsets = new();
                List<float> tokenLogprobs = new();
                List<string> tokenStrs = new();
                List<Dictionary<string, float>> topLogprobs = new();
                var allTokens = promptTokens.Concat(completionTokens).ToArray();
                var allTokenStrs = allTokens.Select(t => DeTokenize(new[] { t }));
                var allLogProbs = _eval_logits.Select(row => LogitsToLogprobs(row));
                foreach (var (token, tokenStr, logProbsToken) in allTokens.Zip(allTokenStrs, (x, y) => (x, y))
                    .Zip(allLogProbs, (x, y) => (x.x, x.y, y)))
                {
                    textOffsets.Add(textOffset);
                    textOffset += tokenStr.Length;
                    tokenStrs.Add(tokenStr);
                    var sortedLogprobs = logProbsToken.Zip(Enumerable.Range(0, logProbsToken.Count()), (x, y) => (x, y))
                        .OrderByDescending(x => x.x).ToList();
                    tokenLogprobs.Add(sortedLogprobs[token].x);
                    var topLogprob = sortedLogprobs.Take(logprobs).ToDictionary(t => DeTokenize(new[] { t.y }), t => t.x);
                    topLogprob[tokenStr] = sortedLogprobs[token].x;
                    topLogprobs.Add(topLogprob);
                }
                logProbs = new(textOffsets.ToArray(), tokenLogprobs.ToArray(), tokenStrs.ToArray(), topLogprobs.ToArray());
            }
            if (_verbose)
            {
                NativeApi.llama_print_timings(_ctx);
            }
            // Return the final text with echo/suffix applied (textStr), not the raw generated text.
            yield return new Completion(completionId, "text_completion", created, _model_path, new CompletionChoice[]
            {
                new CompletionChoice(textStr, 0, logProbs, finishReason)
            }, new CompletionUsage(promptTokens.Count, completionTokens.Count, promptTokens.Count + completionTokens.Count));
        }
        /// <summary>
        /// Generate text from a prompt and yield return the result.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
        public IEnumerable<CompletionChunk> CreateCompletionStream(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            // Yield every chunk produced by the shared implementation instead of casting the
            // enumerable itself, which would throw an InvalidCastException at runtime.
            foreach (var chunk in CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k, true))
            {
                yield return (CompletionChunk)chunk;
            }
        }
        /// <summary>
        /// Generate text from a prompt.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
        public Completion CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            var completion = CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k, false).First();
            return (Completion)completion;
        }
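
        // Illustrative completion sketch (assumes a "model" instance; the prompt, stop sequence and
        // sampling parameters are arbitrary demo values):
        //
        //     var result = model.CreateCompletion("Q: What is the capital of France?\nA:",
        //         max_tokens: 32, temperature: 0.2f, stop: new[] { "\n" });
        //     Console.WriteLine(result.Choices[0].Text);
        //
        //     // Streaming variant:
        //     foreach (var chunk in model.CreateCompletionStream("Once upon a time", max_tokens: 64))
        //     {
        //         Console.Write(chunk.Choices[0].Text);
        //     }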
        /// <summary>
        /// Generate text from a prompt.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
        public Completion Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }
        /// <summary>
        /// Generate text from a prompt and yield return the result.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
        public IEnumerable<CompletionChunk> StreamCall(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            return CreateCompletionStream(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }
        private ChatCompletion ConvertTextCompletionToChat(Completion completion)
        {
            return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
                new[] { new ChatCompletionChoice(0, new ChatCompletionMessage("assistant", completion.Choices[0].Text, null),
                completion.Choices[0].FinishReason) }, completion.Usage);
        }

        private IEnumerable<ChatCompletionChunk> ConvertTextCompletionChunksToChat(IEnumerable<CompletionChunk> chunks)
        {
            bool isFirst = true;
            foreach (var chunk in chunks)
            {
                if (isFirst)
                {
                    yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
                        new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta("assistant", null), null) });
                    isFirst = false;
                }
                yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
                    new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta(null, chunk.Choices[0].Text),
                    chunk.Choices[0].FinishReason) });
            }
        }
        /// <summary>
        /// Generate a chat completion from a list of messages.
        /// </summary>
        /// <param name="messages">A list of messages to generate a response for.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="presence_penalty"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <returns></returns>
        public ChatCompletion CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
            int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
            float repeat_penalty = 1.1f)
        {
            if (stop is null)
            {
                stop = new string[0];
            }
            string GetRole(ChatCompletionMessage message)
            {
                return message.Role == "user" ? "Human" : "Assistant";
            }
            string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
            var prompt = chatHistory + "### Assistant:";
            var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
            var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
                repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
            return ConvertTextCompletionToChat(completion);
        }
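
        // Illustrative chat sketch (assumes a "model" instance; also assumes ChatCompletionChoice
        // exposes the assistant reply via a Message property, so adjust the member name if the
        // actual type differs; the message text is a placeholder):
        //
        //     var messages = new[]
        //     {
        //         new ChatCompletionMessage("user", "Explain what a LoRA adapter is in one sentence.", null)
        //     };
        //     var chat = model.CreateChatCompletion(messages, temperature: 0.2f, max_tokens: 64);
        //     Console.WriteLine(chat.Choices[0].Message.Content);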
        /// <summary>
        /// Generate a chat completion from a list of messages and yield return the result.
        /// </summary>
        /// <param name="messages">A list of messages to generate a response for.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="presence_penalty"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <returns></returns>
        public IEnumerable<ChatCompletionChunk> CreateChatCompletionStream(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
            int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
            float repeat_penalty = 1.1f)
        {
            if (stop is null)
            {
                stop = new string[0];
            }
            string GetRole(ChatCompletionMessage message)
            {
                return message.Role == "user" ? "Human" : "Assistant";
            }
            string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
            var prompt = chatHistory + "### Assistant:";
            var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
            var completion = StreamCall(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
                repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
            return ConvertTextCompletionChunksToChat(completion);
        }
        public LLamaState SaveState()
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            ulong stateSize = NativeApi.llama_get_state_size(_ctx);
            byte[] llamaState = new byte[stateSize];
            ulong nBytes = NativeApi.llama_copy_state_data(_ctx, llamaState);
            if (nBytes > stateSize)
            {
                throw new RuntimeError("Failed to copy llama state data");
            }
            byte[] llamaStateCompact = new byte[nBytes];
            llamaState.Take((int)nBytes).ToArray().CopyTo(llamaStateCompact, 0);
            if (_verbose)
            {
                Logger.Default.Info($"Llama.save_state: saving {nBytes} bytes of llama state");
            }
            return new LLamaState(new Queue<llama_token>(_eval_tokens), new Queue<float[]>(_eval_logits),
                llamaStateCompact, (int)nBytes);
        }

        public void LoadState(LLamaState state)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            _eval_tokens = new Queue<llama_token>(state.EvalTokens);
            _eval_logits = new Queue<float[]>(state.EvalLogits);
            if (NativeApi.llama_set_state_data(_ctx, state.State) != (ulong)state.Size)
            {
                throw new RuntimeError($"Failed to set llama state data");
            }
        }
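
        // Illustrative state checkpoint sketch (assumes a "model" instance that has already
        // evaluated a prompt): snapshot the evaluation state, continue generating, then rewind.
        //
        //     var checkpoint = model.SaveState();
        //     // ... run Eval/Sample or CreateCompletion here ...
        //     model.LoadState(checkpoint);   // restores the eval token/logit queues and the native state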
        private static IEnumerable<float> LogitsToLogprobs(IEnumerable<float> logits)
        {
            var exps = logits.Select(x => (float)Math.Exp(x));
            var sumExps = exps.Sum();
            return exps.Select(x => (float)Math.Log(x / sumExps));
        }

        internal static int LongestTokenPrefix(IEnumerable<llama_token> a, IEnumerable<llama_token> b)
        {
            int longestPrefix = 0;
            foreach (var (x, y) in a.Zip(b, (x, y) => (x, y)))
            {
                if (x == y)
                {
                    longestPrefix++;
                }
                else
                {
                    break;
                }
            }
            return longestPrefix;
        }
    }
}

An easy-to-use, high-performance LLM inference framework for C#/.NET, supporting the LLaMA and LLaVA model families.
