
SafeLlamaModelHandle.cs 8.0 kB

using System;
using System.Diagnostics;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
    /// <summary>
    /// A reference to a set of llama model weights
    /// </summary>
    public sealed class SafeLlamaModelHandle
        : SafeLLamaHandleBase
    {
        /// <summary>
        /// Total number of tokens in the vocabulary of this model
        /// </summary>
        public int VocabCount { get; }

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public int ContextSize { get; }

        /// <summary>
        /// Dimension of the embedding vectors produced by this model
        /// </summary>
        public int EmbeddingSize { get; }

        internal SafeLlamaModelHandle(IntPtr handle)
            : base(handle)
        {
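            // Cache the model's basic metadata once at construction time, so these
            // values are cheap property reads rather than repeated native calls.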
            VocabCount = NativeApi.llama_model_n_vocab(this);
            ContextSize = NativeApi.llama_model_n_ctx(this);
            EmbeddingSize = NativeApi.llama_model_n_embd(this);
        }

        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
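            // Free the native weights, then zero the handle so the SafeHandle
            // machinery knows the resource has been released.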
            NativeApi.llama_free_model(handle);
            SetHandle(IntPtr.Zero);
            return true;
        }

        /// <summary>
        /// Load a model from the given file path into memory
        /// </summary>
        /// <param name="modelPath">Path of the model file to load</param>
        /// <param name="lparams">Parameters to use when loading the model</param>
        /// <returns>A handle to the loaded model weights</returns>
        /// <exception cref="RuntimeError">Thrown if the model fails to load</exception>
        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
        {
            var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
            if (model_ptr == IntPtr.Zero)
                throw new RuntimeError($"Failed to load model {modelPath}.");

            return new SafeLlamaModelHandle(model_ptr);
        }

        #region LoRA
        /// <summary>
        /// Apply a LoRA adapter to a loaded model
        /// </summary>
        /// <param name="lora">Path of the LoRA adapter file</param>
        /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
        /// adapter. Can be NULL to use the currently loaded model.</param>
        /// <param name="threads">Number of threads to use (-1 to let the native library choose)</param>
        /// <exception cref="RuntimeError">Thrown if the adapter fails to apply</exception>
        public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
        {
            var err = NativeApi.llama_model_apply_lora_from_file(
                this,
                lora,
                string.IsNullOrEmpty(modelBase) ? null : modelBase,
                threads
            );

            if (err != 0)
                throw new RuntimeError($"Failed to apply LoRA adapter '{lora}'.");
        }
        #endregion

        #region tokenize
        /// <summary>
        /// Convert a single llama token into bytes
        /// </summary>
        /// <param name="llama_token">Token to decode</param>
        /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
        /// <returns>The size of this token. **Nothing will be written** if this is larger than `dest`</returns>
        public int TokenToSpan(int llama_token, Span<byte> dest)
        {
            unsafe
            {
                fixed (byte* destPtr = dest)
                {
                    var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length);
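                    // The native call returns a negative value when `dest` is too small
                    // to hold the token, so the magnitude is the token's size either way.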
                    return Math.Abs(length);
                }
            }
        }

        /// <summary>
        /// Convert a single llama token into a string
        /// </summary>
        /// <param name="llama_token">Token to decode</param>
        /// <param name="encoding">Encoding to use to decode the bytes into a string</param>
        /// <returns>The decoded text of this token</returns>
        public string TokenToString(int llama_token, Encoding encoding)
        {
            unsafe
            {
                var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
                if (length == 0)
                    return "";
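                // Calling the native function with a null buffer reports the required
                // size as a negative number, hence the negation when allocating below.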
                Span<byte> bytes = stackalloc byte[-length];
                fixed (byte* bytePtr = bytes)
                {
                    var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
                    Debug.Assert(written == bytes.Length);

                    return encoding.GetString(bytePtr, bytes.Length);
                }
            }
        }

        /// <summary>
        /// Append a single llama token to a string builder
        /// </summary>
        /// <param name="llama_token">Token to decode</param>
        /// <param name="encoding">Encoding to use to decode the bytes into characters</param>
        /// <param name="dest">String builder to append the result to</param>
        public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest)
        {
            unsafe
            {
                var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
                if (length == 0)
                    return;
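                // As above, the probe call with a null buffer returns the negative of
                // the required buffer size.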
                Span<byte> bytes = stackalloc byte[-length];
                fixed (byte* bytePtr = bytes)
                {
                    // Decode into bytes
                    var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
                    Debug.Assert(written == bytes.Length);

                    // Decode into chars
                    var charCount = encoding.GetCharCount(bytePtr, bytes.Length);
                    Span<char> chars = stackalloc char[charCount];
                    fixed (char* charPtr = chars)
                        encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length);

                    // Write it to the output
                    for (var i = 0; i < chars.Length; i++)
                        dest.Append(chars[i]);
                }
            }
        }

        /// <summary>
        /// Convert a string of text into tokens
        /// </summary>
        /// <param name="text">Text to tokenize</param>
        /// <param name="add_bos">Whether to prepend the BOS (beginning-of-sequence) token</param>
        /// <param name="encoding">Encoding to use to convert the text into bytes</param>
        /// <returns>The tokens representing the given text</returns>
        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
        {
            // Convert string to bytes, adding one extra byte to the end (null terminator)
            var bytesCount = encoding.GetByteCount(text);
            var bytes = new byte[bytesCount + 1];
            unsafe
            {
                fixed (char* charPtr = text)
                fixed (byte* bytePtr = &bytes[0])
                {
                    encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
                }
            }
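            // Note: the final byte of `bytes` is left as zero, providing the null
            // terminator that the native tokenizer expects.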

            unsafe
            {
                fixed (byte* bytesPtr = &bytes[0])
                {
                    // Tokenize once with no output, to get the token count. The result is negative (indicating that there was insufficient space)
                    var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);

                    // An empty input with no BOS token produces no tokens at all
                    if (count == 0)
                        return Array.Empty<int>();

                    // Tokenize again, this time outputting into an array of exactly the right size
                    var tokens = new int[count];
                    fixed (int* tokensPtr = &tokens[0])
                    {
                        NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
                        return tokens;
                    }
                }
            }
        }
        #endregion

        #region context
        /// <summary>
        /// Create a new context for this model
        /// </summary>
        /// <param name="params">Parameters to use when creating the context</param>
        /// <returns>A handle to the newly created context</returns>
        public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
        {
            return SafeLLamaContextHandle.Create(this, @params);
        }
        #endregion
    }
}
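
A minimal usage sketch of this handle (the model path is a placeholder, and NativeApi.llama_context_default_params is assumed to be exposed by these bindings, mirroring the llama.cpp function of the same name):

using System;
using System.Text;
using LLama.Native;

class Example
{
    static void Main()
    {
        // Default context/model parameters (assumed binding, see note above)
        var lparams = NativeApi.llama_context_default_params();

        // Load the weights (throws RuntimeError on failure) and create a context over them
        using var model = SafeLlamaModelHandle.LoadFromFile("/path/to/model.bin", lparams);
        using var context = model.CreateContext(lparams);

        // Round-trip some text through the tokenizer and back into a string
        var tokens = model.Tokenize("Hello, world!", add_bos: true, Encoding.UTF8);
        var text = new StringBuilder();
        foreach (var token in tokens)
            model.TokenToString(token, Encoding.UTF8, text);
        Console.WriteLine(text.ToString());
    }
}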