
LLamaWeights.cs 8.7 kB

Latest commit: April 2024 Binary Update (#662), 1 year ago

* Updated binaries, using [this build](https://github.com/SciSharp/LLamaSharp/actions/runs/8654672719/job/23733195669) for llama.cpp commit `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
  - Added all new functions.
  - Moved some functions (e.g. `SafeLlamaModelHandle`-specific functions) into `SafeLlamaModelHandle.cs`.
  - Exposed tokens on `SafeLlamaModelHandle` and `LLamaWeights` through a `Tokens` property; as new special tokens are added in the future they can be added here.
  - Changed all token properties to return nullable tokens, to handle models that lack some tokens.
  - Fixed `DefaultSamplingPipeline` to handle models with no newline token.
* Moved native methods to more specific locations.
  - Context-specific functions have been moved into `SafeLLamaContextHandle.cs` and made private; they're already exposed through C# properties and methods.
  - Added a check that the GPU layer count is zero if GPU offload is not supported.
  - Moved methods for creating default structs (`llama_model_quantize_default_params` and `llama_context_default_params`) into the relevant structs.
* Removed the exception thrown when `GpuLayerCount > 0` and GPU is not supported.
* Per-sequence state load/save:
  - Added low-level wrapper methods for the new per-sequence state load/save in `SafeLLamaContextHandle`.
  - Added high-level wrapper methods (save/load with a `State` object or a memory-mapped file) in `LLamaContext`.
  - Moved native methods for per-sequence state load/save into `SafeLLamaContextHandle`.
* Added update and defrag methods for the KV cache in `SafeLLamaContextHandle`.
* Updated the submodule to `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
* Pass the sequence ID when saving a single sequence state.
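The per-sequence state load/save is the main new user-facing capability in this update. Below is a minimal sketch of how the high-level wrappers in `LLamaContext` might be used, assuming LLamaSharp's `LLama.Common.ModelParams` (which implements both `IModelParams` and `IContextParams`) and a hypothetical model path; the `SaveState`/`LoadState` overload names, and especially the sequence-taking variants, are assumptions inferred from the commit description rather than confirmed signatures for this revision:

```csharp
// Minimal sketch, assuming LLamaContext exposes SaveState/LoadState overloads
// as described in the commit message above. The model path is hypothetical.
using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// ... run some inference so there is KV cache state worth saving ...

// Save/load the whole context state via a file on disk.
context.SaveState("state.bin");
context.LoadState("state.bin");

// Assumed per-sequence overloads, taking the sequence ID mentioned in the
// final bullet ("Pass the sequence ID when saving a single sequence state"):
// context.SaveState("seq0.bin", (LLamaSeqId)0);
// context.LoadState("seq0.bin", (LLamaSeqId)0);
```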
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;

namespace LLama
{
    /// <summary>
    /// A set of model weights, loaded into memory.
    /// </summary>
    public sealed class LLamaWeights
        : IDisposable
    {
        /// <summary>
        /// The native handle, which is used in the native APIs
        /// </summary>
        /// <remarks>Be careful how you use this!</remarks>
        public SafeLlamaModelHandle NativeHandle { get; }

        /// <summary>
        /// Total number of tokens in the vocabulary of this model
        /// </summary>
        public int VocabCount => NativeHandle.VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public int ContextSize => NativeHandle.ContextSize;

        /// <summary>
        /// Get the size of this model in bytes
        /// </summary>
        public ulong SizeInBytes => NativeHandle.SizeInBytes;

        /// <summary>
        /// Get the number of parameters in this model
        /// </summary>
        public ulong ParameterCount => NativeHandle.ParameterCount;

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize => NativeHandle.EmbeddingSize;

        /// <summary>
        /// Get the special tokens of this model
        /// </summary>
        public SafeLlamaModelHandle.ModelTokens Tokens => NativeHandle.Tokens;

        /// <summary>
        /// All metadata keys in this model
        /// </summary>
        public IReadOnlyDictionary<string, string> Metadata { get; set; }

        private LLamaWeights(SafeLlamaModelHandle weights)
        {
            NativeHandle = weights;
            Metadata = weights.ReadMetadata();
        }

        /// <summary>
        /// Load weights into memory
        /// </summary>
        /// <param name="params">Parameters to use to load the model</param>
        /// <returns>The loaded model weights</returns>
        public static LLamaWeights LoadFromFile(IModelParams @params)
        {
            using var pin = @params.ToLlamaModelParams(out var lparams);
            var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

            foreach (var adapter in @params.LoraAdapters)
            {
                // Skip invalid adapters
                if (string.IsNullOrEmpty(adapter.Path))
                    continue;
                if (adapter.Scale <= 0)
                    continue;

                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
            }

            return new LLamaWeights(weights);
        }

        /// <summary>
        /// Load weights into memory
        /// </summary>
        /// <param name="params">Parameters to use to load the model</param>
        /// <param name="token">A cancellation token that can interrupt model loading</param>
        /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
        /// <returns>The loaded model weights</returns>
        /// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason, e.g. an invalid file format or loading being cancelled.</exception>
        /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
        {
            // Don't touch the @params object inside the task, it might be changed
            // externally! Save a copy of everything that we need later.
            var modelPath = @params.ModelPath;
            var loraBase = @params.LoraBase;
            var loraAdapters = @params.LoraAdapters.ToArray();

            // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
            // slightly smaller range to allow some space for reporting LoRA loading too.
            var modelLoadProgressRange = 1f;
            if (loraAdapters.Length > 0)
                modelLoadProgressRange = 0.9f;

            using (@params.ToLlamaModelParams(out var lparams))
            {
#if !NETSTANDARD2_0
                // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
                if (token.CanBeCanceled || progressReporter != null)
                {
                    var internalCallback = lparams.progress_callback;
                    lparams.progress_callback = (progress, ctx) =>
                    {
                        // Update the progress reporter (remapping the value into the smaller range).
                        progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);

                        // If the user set a callback in the model params, call that and see if we should cancel
                        if (internalCallback != null && !internalCallback(progress, ctx))
                            return false;

                        // Check the cancellation token
                        if (token.IsCancellationRequested)
                            return false;

                        return true;
                    };
                }
#endif

                var model = await Task.Run(() =>
                {
                    try
                    {
                        // Load the model
                        var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);

                        // Apply the LoRA adapters
                        for (var i = 0; i < loraAdapters.Length; i++)
                        {
                            // Interrupt applying LoRAs if the token is cancelled
                            if (token.IsCancellationRequested)
                            {
                                weights.Dispose();
                                token.ThrowIfCancellationRequested();
                            }

                            // Don't apply invalid adapters
                            var adapter = loraAdapters[i];
                            if (string.IsNullOrEmpty(adapter.Path))
                                continue;
                            if (adapter.Scale <= 0)
                                continue;

                            weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);

                            // Report progress. Model loading reported progress from 0 -> 0.9, use
                            // the last 0.1 to represent all of the LoRA adapters being applied.
                            progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
                        }

                        // Update progress reporter to indicate completion
                        progressReporter?.Report(1);

                        return new LLamaWeights(weights);
                    }
                    catch (LoadWeightsFailedException)
                    {
                        // Convert a LoadWeightsFailedException into a cancellation exception if possible.
                        token.ThrowIfCancellationRequested();

                        // Ok, the weights failed to load for some reason other than cancellation.
                        throw;
                    }
                }, token);

                return model;
            }
        }

        /// <inheritdoc />
        public void Dispose()
        {
            NativeHandle.Dispose();
        }

        /// <summary>
        /// Create a llama_context using this model
        /// </summary>
        /// <param name="params">Parameters for the new context</param>
        /// <param name="logger">Optional logger</param>
        /// <returns></returns>
        public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null)
        {
            return new LLamaContext(this, @params, logger);
        }

        /// <summary>
        /// Convert a string of text into tokens
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether to prepend the BOS (beginning of sequence) token</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <param name="encoding">Encoding to use when converting the text to bytes</param>
        /// <returns></returns>
        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            return NativeHandle.Tokenize(text, add_bos, special, encoding);
        }
    }
}
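For reference, a short usage sketch of the class above: asynchronous loading with cancellation and progress reporting, then context creation and tokenization. Everything used here appears in the class itself except `ModelParams`, which is assumed to be LLamaSharp's `LLama.Common.ModelParams` (implementing both `IModelParams` and `IContextParams`); the model path is hypothetical.

```csharp
using System;
using System.Text;
using System.Threading;
using LLama;
using LLama.Common;

// Hypothetical model path; ModelParams implements IModelParams and IContextParams.
var parameters = new ModelParams("model.gguf");

// Cancel loading automatically after five minutes.
using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5));

// Receives values in [0, 1]; LoRA application uses the final 0.1 of the range.
var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

using var weights = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token, progress);
Console.WriteLine($"{weights.ParameterCount} parameters, vocab size {weights.VocabCount}");

// Create a context and tokenize some text (prepend BOS, don't parse special tokens).
using var context = weights.CreateContext(parameters);
var tokens = weights.Tokenize("Hello, world!", add_bos: true, special: false, Encoding.UTF8);
Console.WriteLine($"{tokens.Length} tokens");
```

Note that `LoadFromFileAsync` remaps llama.cpp's native 0-1 progress into 0-0.9 when LoRA adapters are configured, so the reporter above will see the final 0.1 of the range arrive as each adapter is applied.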