
SafeLlamaModelHandle.cs 6.0 kB

using System;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
    /// <summary>
    /// A reference to a set of llama model weights
    /// </summary>
    public sealed class SafeLlamaModelHandle
        : SafeLLamaHandleBase
    {
        /// <summary>
        /// Total number of tokens in the vocabulary of this model
        /// </summary>
        public int VocabCount { get; }

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public int ContextSize { get; }

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize { get; }

        internal SafeLlamaModelHandle(IntPtr handle)
            : base(handle)
        {
            VocabCount = NativeApi.llama_model_n_vocab(this);
            ContextSize = NativeApi.llama_model_n_ctx(this);
            EmbeddingSize = NativeApi.llama_model_n_embd(this);
        }
        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            NativeApi.llama_free_model(handle);
            SetHandle(IntPtr.Zero);
            return true;
        }
        /// <summary>
        /// Load a model from the given file path into memory
        /// </summary>
        /// <param name="modelPath">Path to the model weights file on disk</param>
        /// <param name="lparams">Parameters controlling how the model is loaded</param>
        /// <returns>A handle to the loaded model</returns>
        /// <exception cref="RuntimeError">Thrown if the model file cannot be loaded</exception>
        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
        {
            var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
            if (model_ptr == IntPtr.Zero)
                throw new RuntimeError($"Failed to load model {modelPath}.");
            return new SafeLlamaModelHandle(model_ptr);
        }
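
        // Usage sketch (illustrative, not part of this file): loading a model through this
        // handle. Assumes NativeApi exposes llama_context_default_params(), mirroring the
        // upstream llama.cpp API, and that "model.bin" is a placeholder path to real weights.
        //
        //     var lparams = NativeApi.llama_context_default_params();
        //     using var model = SafeLlamaModelHandle.LoadFromFile("model.bin", lparams);
        //     Console.WriteLine($"vocab: {model.VocabCount}, ctx: {model.ContextSize}");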
        #region LoRA
        /// <summary>
        /// Apply a LoRA adapter to a loaded model
        /// </summary>
        /// <param name="lora">Path to the LoRA adapter file</param>
        /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
        /// adapter. Can be NULL to use the currently loaded model.</param>
        /// <param name="threads">Number of threads to use (-1 to let the native library decide)</param>
        /// <exception cref="RuntimeError">Thrown if the adapter cannot be applied</exception>
        public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
        {
            var err = NativeApi.llama_model_apply_lora_from_file(
                this,
                lora,
                string.IsNullOrEmpty(modelBase) ? null : modelBase,
                threads
            );

            if (err != 0)
                throw new RuntimeError("Failed to apply LoRA adapter.");
        }
        #endregion
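
        // Usage sketch (illustrative, not part of this file): applying a LoRA adapter to
        // the model handle loaded above. The adapter path is a hypothetical placeholder.
        //
        //     model.ApplyLoraFromFile("adapters/my-lora.bin", modelBase: null, threads: 4);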
        #region tokenize
        /// <summary>
        /// Convert a single llama token into string bytes
        /// </summary>
        /// <param name="llama_token">Token to decode</param>
        /// <returns>A span over the raw bytes of the token, up to (but not including) the null terminator</returns>
        public ReadOnlySpan<byte> TokenToSpan(int llama_token)
        {
            unsafe
            {
                // The native call returns a pointer to a null terminated string. Wrap it in an
                // oversized span, then slice down to the terminator to get the actual bytes.
                var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
                var terminator = bytes.IndexOf((byte)0);
                return bytes.Slice(0, terminator);
            }
        }
        /// <summary>
        /// Convert a single llama token into a string
        /// </summary>
        /// <param name="llama_token">Token to decode</param>
        /// <param name="encoding">Encoding to use to decode the bytes into a string</param>
        /// <returns>The decoded string</returns>
        public string TokenToString(int llama_token, Encoding encoding)
        {
            var span = TokenToSpan(llama_token);

            if (span.Length == 0)
                return "";

            unsafe
            {
                fixed (byte* ptr = &span[0])
                {
                    return encoding.GetString(ptr, span.Length);
                }
            }
        }
        /// <summary>
        /// Convert a string of text into tokens
        /// </summary>
        /// <param name="text">Text to tokenize</param>
        /// <param name="add_bos">Whether to prepend the BOS (beginning-of-sequence) token</param>
        /// <param name="encoding">Encoding to use to convert the text into bytes</param>
        /// <returns>An array of token IDs</returns>
        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
        {
            // Convert string to bytes, adding one extra byte to the end (null terminator)
            var bytesCount = encoding.GetByteCount(text);
            var bytes = new byte[bytesCount + 1];
            unsafe
            {
                fixed (char* charPtr = text)
                fixed (byte* bytePtr = &bytes[0])
                {
                    encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
                }
            }

            unsafe
            {
                fixed (byte* bytesPtr = &bytes[0])
                {
                    // Tokenize once with no output buffer, to get the token count. The result is
                    // negative, indicating that there was insufficient space, so negate it.
                    var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);

                    // Tokenize again, this time outputting into an array of exactly the right size
                    var tokens = new int[count];
                    fixed (int* tokensPtr = &tokens[0])
                    {
                        NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
                        return tokens;
                    }
                }
            }
        }
        #endregion
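
        // Usage sketch (illustrative, not part of this file): a tokenize/detokenize round
        // trip using the methods above. Note that TokenToString decodes one token at a time,
        // so multi-byte characters split across token boundaries may not decode cleanly.
        //
        //     var tokens = model.Tokenize("Hello, world!", add_bos: true, Encoding.UTF8);
        //     foreach (var token in tokens)
        //         Console.Write(model.TokenToString(token, Encoding.UTF8));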
        #region context
        /// <summary>
        /// Create a new context for this model
        /// </summary>
        /// <param name="params">Parameters for the new context</param>
        /// <returns>A handle to the new context</returns>
        public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
        {
            return SafeLLamaContextHandle.Create(this, @params);
        }
        #endregion
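
        // Usage sketch (illustrative, not part of this file): creating a context from the
        // loaded model. Reusing one model handle across several contexts avoids loading the
        // weights more than once; SafeLLamaContextHandle is a SafeHandle, so it is disposable.
        //
        //     using var ctx = model.CreateContext(lparams);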
    }
}