You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

LLamaModel.cs 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. using LLama.Exceptions;
  2. using LLama.Native;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text;
  7. using System.IO;
  8. using System.IO.MemoryMappedFiles;
  9. using LLama.Common;
  10. using System.Runtime.InteropServices;
  11. using LLama.Extensions;
  12. using Microsoft.Win32.SafeHandles;
  13. using LLama.Abstractions;
  14. namespace LLama
  15. {
  16. using llama_token = Int32;
  17. /// <summary>
  18. /// The abstraction of a LLama model, which holds the context in the native library.
  19. /// </summary>
  20. public class LLamaModel: IDisposable
  21. {
  22. // TODO: expose more properties.
  23. ILLamaLogger? _logger;
  24. Encoding _encoding;
  25. SafeLLamaContextHandle _ctx;
  26. /// <summary>
  27. /// The context size.
  28. /// </summary>
  29. public int ContextSize { get; }
  30. /// <summary>
  31. /// The model params set for this model.
  32. /// </summary>
  33. public IModelParams Params { get; set; }
  34. /// <summary>
  35. /// The native handle, which is used to be passed to the native APIs. Please avoid using it
  36. /// unless you know what is the usage of the Native API.
  37. /// </summary>
  38. public SafeLLamaContextHandle NativeHandle => _ctx;
  39. /// <summary>
  40. /// The encoding set for this model to deal with text input.
  41. /// </summary>
  42. public Encoding Encoding => _encoding;
  43. /// <summary>
  44. ///
  45. /// </summary>
  46. /// <param name="Params">Model params.</param>
  47. /// <param name="encoding">Encoding to deal with text input.</param>
  48. /// <param name="logger">The logger.</param>
  49. public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
  50. {
  51. _logger = logger;
  52. this.Params = Params;
  53. _encoding = Encoding.GetEncoding(encoding);
  54. _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
  55. _ctx = Utils.InitLLamaContextFromModelParams(this.Params);
  56. ContextSize = NativeApi.llama_n_ctx(_ctx);
  57. }
  58. /// <summary>
  59. /// Tokenize a string.
  60. /// </summary>
  61. /// <param name="text"></param>
  62. /// <param name="addBos">Whether to add a bos to the text.</param>
  63. /// <returns></returns>
  64. public llama_token[] Tokenize(string text, bool addBos = true)
  65. {
  66. return _ctx.Tokenize(text, addBos, _encoding);
  67. }
  68. /// <summary>
  69. /// Detokenize the tokens to text.
  70. /// </summary>
  71. /// <param name="tokens"></param>
  72. /// <returns></returns>
  73. public string DeTokenize(IEnumerable<llama_token> tokens)
  74. {
  75. StringBuilder sb = new();
  76. foreach(var token in tokens)
  77. sb.Append(_ctx.TokenToString(token, _encoding));
  78. return sb.ToString();
  79. }
  80. /// <summary>
  81. /// Save the state to specified path.
  82. /// </summary>
  83. /// <param name="filename"></param>
  84. public void SaveState(string filename)
  85. {
  86. // Delete that file before overwriting it
  87. if (File.Exists(filename))
  88. File.Delete(filename);
  89. // Estimate size of state to write to disk, this is always equal to or greater than the actual size
  90. var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx);
  91. // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
  92. long writtenBytes;
  93. using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize))
  94. using (var view = file.CreateViewAccessor(0, estimatedStateSize))
  95. {
  96. unsafe
  97. {
  98. byte* ptr = null;
  99. view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
  100. writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr);
  101. view.SafeMemoryMappedViewHandle.ReleasePointer();
  102. }
  103. }
  104. // Truncate the file to the actual size of data that was written
  105. using (var fileStream = new FileStream(filename, FileMode.Open))
  106. fileStream.SetLength(writtenBytes);
  107. }
  108. /// <summary>
  109. /// Get the state data as a byte array.
  110. /// </summary>
  111. /// <returns></returns>
  112. [Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
  113. public byte[] GetStateData()
  114. {
  115. var stateSize = NativeApi.llama_get_state_size(_ctx);
  116. byte[] stateMemory = new byte[stateSize];
  117. NativeApi.llama_copy_state_data(_ctx, stateMemory);
  118. return stateMemory;
  119. }
  120. /// <summary>
  121. /// Get the state data as an opaque handle
  122. /// </summary>
  123. /// <returns></returns>
  124. public State GetState()
  125. {
  126. var stateSize = NativeApi.llama_get_state_size(_ctx);
  127. unsafe
  128. {
  129. var bigMemory = Marshal.AllocHGlobal((nint)stateSize);
  130. var smallMemory = IntPtr.Zero;
  131. try
  132. {
  133. // Copy the state data into "big memory", discover the actual size required
  134. var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory);
  135. // Allocate a smaller buffer
  136. smallMemory = Marshal.AllocHGlobal((nint)actualSize);
  137. // Copy into the smaller buffer and free the large one to save excess memory usage
  138. Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize);
  139. Marshal.FreeHGlobal(bigMemory);
  140. bigMemory = IntPtr.Zero;
  141. return new State(smallMemory);
  142. }
  143. catch
  144. {
  145. if (bigMemory != IntPtr.Zero)
  146. Marshal.FreeHGlobal(bigMemory);
  147. if (smallMemory != IntPtr.Zero)
  148. Marshal.FreeHGlobal(smallMemory);
  149. throw;
  150. }
  151. }
  152. }
  153. /// <summary>
  154. /// Load the state from specified path.
  155. /// </summary>
  156. /// <param name="filename"></param>
  157. /// <exception cref="RuntimeError"></exception>
  158. public void LoadState(string filename)
  159. {
  160. // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
  161. using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
  162. using (var view = file.CreateViewAccessor())
  163. {
  164. unsafe
  165. {
  166. byte* ptr = null;
  167. view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
  168. NativeApi.llama_set_state_data(_ctx, ptr);
  169. view.SafeMemoryMappedViewHandle.ReleasePointer();
  170. }
  171. }
  172. }
  173. /// <summary>
  174. /// Load the state from memory.
  175. /// </summary>
  176. /// <param name="stateData"></param>
  177. /// <exception cref="RuntimeError"></exception>
  178. public void LoadState(byte[] stateData)
  179. {
  180. int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
  181. if (stateData.Length > stateSize)
  182. {
  183. throw new RuntimeError("Failed to validate state size.");
  184. }
  185. NativeApi.llama_set_state_data(_ctx, stateData);
  186. }
  187. /// <summary>
  188. /// Load the state from memory.
  189. /// </summary>
  190. /// <param name="state"></param>
  191. /// <exception cref="RuntimeError"></exception>
  192. public void LoadState(State state)
  193. {
  194. unsafe
  195. {
  196. NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer());
  197. }
  198. }
  199. /// <summary>
  200. /// Perform the sampling. Please don't use it unless you fully know what it does.
  201. /// </summary>
  202. /// <param name="candidates"></param>
  203. /// <param name="mirostat_mu"></param>
  204. /// <param name="temperature"></param>
  205. /// <param name="mirostat"></param>
  206. /// <param name="mirostatTau"></param>
  207. /// <param name="mirostatEta"></param>
  208. /// <param name="topK"></param>
  209. /// <param name="topP"></param>
  210. /// <param name="tfsZ"></param>
  211. /// <param name="typicalP"></param>
  212. /// <returns></returns>
  213. public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
  214. float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
  215. {
  216. llama_token id;
  217. if (temperature <= 0)
  218. {
  219. // Greedy sampling
  220. id = SamplingApi.llama_sample_token_greedy(_ctx, candidates);
  221. }
  222. else
  223. {
  224. var mu = mirostat_mu ?? (2 * mirostatTau);
  225. {
  226. if (mirostat == MirostatType.Mirostat)
  227. {
  228. const int mirostat_m = 100;
  229. SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
  230. id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
  231. }
  232. else if (mirostat == MirostatType.Mirostat2)
  233. {
  234. SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
  235. id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
  236. }
  237. else
  238. {
  239. // Temperature sampling
  240. SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
  241. SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
  242. SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
  243. SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
  244. SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
  245. id = SamplingApi.llama_sample_token(_ctx, candidates);
  246. }
  247. }
  248. mirostat_mu = mu;
  249. }
  250. return id;
  251. }
  252. /// <summary>
  253. /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
  254. /// </summary>
  255. /// <param name="lastTokens"></param>
  256. /// <param name="logitBias"></param>
  257. /// <param name="repeatLastTokensCount"></param>
  258. /// <param name="repeatPenalty"></param>
  259. /// <param name="alphaFrequency"></param>
  260. /// <param name="alphaPresence"></param>
  261. /// <param name="penalizeNL"></param>
  262. /// <returns></returns>
  263. public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
  264. int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
  265. bool penalizeNL = true)
  266. {
  267. var n_vocab = _ctx.VocabCount;
  268. var logits = _ctx.GetLogits();
  269. // Apply params.logit_bias map
  270. if(logitBias is not null)
  271. {
  272. foreach (var (key, value) in logitBias)
  273. {
  274. logits[key] += value;
  275. }
  276. }
  277. var candidates = new LLamaTokenData[n_vocab];
  278. for (llama_token token_id = 0; token_id < n_vocab; token_id++)
  279. candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
  280. LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);
  281. // Apply penalties
  282. float nl_logit = logits[NativeApi.llama_token_nl()];
  283. int lastTokensCount = lastTokens.Count();
  284. var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize);
  285. SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
  286. lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
  287. (ulong)last_n_repeat, repeatPenalty);
  288. SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
  289. lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
  290. (ulong)last_n_repeat, alphaFrequency, alphaPresence);
  291. if (!penalizeNL)
  292. {
  293. logits[NativeApi.llama_token_nl()] = nl_logit;
  294. }
  295. return candidates_p;
  296. }
  297. /// <summary>
  298. ///
  299. /// </summary>
  300. /// <param name="tokens"></param>
  301. /// <param name="pastTokensCount"></param>
  302. /// <returns>The updated `pastTokensCount`.</returns>
  303. /// <exception cref="RuntimeError"></exception>
  304. public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
  305. {
  306. int total = tokens.Length;
  307. for(int i = 0; i < total; i += Params.BatchSize)
  308. {
  309. int n_eval = total - i;
  310. if(n_eval > Params.BatchSize)
  311. {
  312. n_eval = Params.BatchSize;
  313. }
  314. if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
  315. {
  316. _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
  317. throw new RuntimeError("Failed to eval.");
  318. }
  319. pastTokensCount += n_eval;
  320. }
  321. return pastTokensCount;
  322. }
  323. // TODO: add comment
  324. internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
  325. {
  326. foreach(var id in ids)
  327. yield return _ctx.TokenToString(id, _encoding);
  328. }
  329. /// <inheritdoc />
  330. public virtual void Dispose()
  331. {
  332. _ctx.Dispose();
  333. }
  334. /// <summary>
  335. /// The state of this model, which can be reloaded later
  336. /// </summary>
  337. public class State
  338. : SafeHandleZeroOrMinusOneIsInvalid
  339. {
  340. internal State(IntPtr memory)
  341. : base(true)
  342. {
  343. SetHandle(memory);
  344. }
  345. /// <inheritdoc />
  346. protected override bool ReleaseHandle()
  347. {
  348. Marshal.FreeHGlobal(handle);
  349. return true;
  350. }
  351. }
  352. }
  353. }