You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

SafeLlamaModelHandle.cs 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. using System;
  2. using System.Diagnostics;
  3. using System.Text;
  4. using LLama.Exceptions;
  5. namespace LLama.Native
  6. {
  7. /// <summary>
  8. /// A reference to a set of llama model weights
  9. /// </summary>
  10. public sealed class SafeLlamaModelHandle
  11. : SafeLLamaHandleBase
  12. {
  13. /// <summary>
  14. /// Total number of tokens in vocabulary of this model
  15. /// </summary>
  16. public int VocabCount { get; }
  17. /// <summary>
  18. /// Total number of tokens in the context
  19. /// </summary>
  20. public int ContextSize { get; }
  21. /// <summary>
  22. /// Dimension of embedding vectors
  23. /// </summary>
  24. public int EmbeddingSize { get; }
  25. internal SafeLlamaModelHandle(IntPtr handle)
  26. : base(handle)
  27. {
  28. VocabCount = NativeApi.llama_n_vocab_from_model(this);
  29. ContextSize = NativeApi.llama_n_ctx_from_model(this);
  30. EmbeddingSize = NativeApi.llama_n_embd_from_model(this);
  31. }
  32. /// <inheritdoc />
  33. protected override bool ReleaseHandle()
  34. {
  35. NativeApi.llama_free_model(handle);
  36. SetHandle(IntPtr.Zero);
  37. return true;
  38. }
  39. /// <summary>
  40. /// Load a model from the given file path into memory
  41. /// </summary>
  42. /// <param name="modelPath"></param>
  43. /// <param name="lparams"></param>
  44. /// <returns></returns>
  45. /// <exception cref="RuntimeError"></exception>
  46. public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
  47. {
  48. var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
  49. if (model_ptr == IntPtr.Zero)
  50. throw new RuntimeError($"Failed to load model {modelPath}.");
  51. return new SafeLlamaModelHandle(model_ptr);
  52. }
  53. #region LoRA
  54. /// <summary>
  55. /// Apply a LoRA adapter to a loaded model
  56. /// </summary>
  57. /// <param name="lora"></param>
  58. /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
  59. /// adapter. Can be NULL to use the current loaded model.</param>
  60. /// <param name="threads"></param>
  61. /// <exception cref="RuntimeError"></exception>
  62. public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
  63. {
  64. var err = NativeApi.llama_model_apply_lora_from_file(
  65. this,
  66. lora,
  67. string.IsNullOrEmpty(modelBase) ? null : modelBase,
  68. threads
  69. );
  70. if (err != 0)
  71. throw new RuntimeError("Failed to apply lora adapter.");
  72. }
  73. #endregion
  74. #region tokenize
  75. /// <summary>
  76. /// Convert a single llama token into string bytes
  77. /// </summary>
  78. /// <param name="llama_token"></param>
  79. /// <returns></returns>
  80. public ReadOnlySpan<byte> TokenToSpan(int llama_token)
  81. {
  82. unsafe
  83. {
  84. var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
  85. var terminator = bytes.IndexOf((byte)0);
  86. return bytes.Slice(0, terminator);
  87. }
  88. }
  89. /// <summary>
  90. /// Convert a single llama token into a string
  91. /// </summary>
  92. /// <param name="llama_token"></param>
  93. /// <param name="encoding">Encoding to use to decode the bytes into a string</param>
  94. /// <returns></returns>
  95. public string TokenToString(int llama_token, Encoding encoding)
  96. {
  97. var span = TokenToSpan(llama_token);
  98. if (span.Length == 0)
  99. return "";
  100. unsafe
  101. {
  102. fixed (byte* ptr = &span[0])
  103. {
  104. return encoding.GetString(ptr, span.Length);
  105. }
  106. }
  107. }
  108. /// <summary>
  109. /// Convert a string of text into tokens
  110. /// </summary>
  111. /// <param name="text"></param>
  112. /// <param name="add_bos"></param>
  113. /// <param name="encoding"></param>
  114. /// <returns></returns>
  115. public int[] Tokenize(string text, bool add_bos, Encoding encoding)
  116. {
  117. // Convert string to bytes, adding one extra byte to the end (null terminator)
  118. var bytesCount = encoding.GetByteCount(text);
  119. var bytes = new byte[bytesCount + 1];
  120. unsafe
  121. {
  122. fixed (char* charPtr = text)
  123. fixed (byte* bytePtr = &bytes[0])
  124. {
  125. encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
  126. }
  127. }
  128. unsafe
  129. {
  130. fixed (byte* bytesPtr = &bytes[0])
  131. {
  132. // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
  133. var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
  134. // Tokenize again, this time outputting into an array of exactly the right size
  135. var tokens = new int[count];
  136. fixed (int* tokensPtr = &tokens[0])
  137. {
  138. NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
  139. return tokens;
  140. }
  141. }
  142. }
  143. }
  144. #endregion
  145. #region context
  146. /// <summary>
  147. /// Create a new context for this model
  148. /// </summary>
  149. /// <param name="params"></param>
  150. /// <returns></returns>
  151. public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
  152. {
  153. return SafeLLamaContextHandle.Create(this, @params);
  154. }
  155. #endregion
  156. }
  157. }