You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

IModelParams.cs 9.2 kB

2 years ago

  1. using System;
  2. using System.Buffers;
  3. using System.Collections;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text.Json;
  7. using System.Text.Json.Serialization;
  8. using LLama.Common;
  9. using LLama.Native;
  10. namespace LLama.Abstractions
  11. {
  /// <summary>
  /// Settings that control how a LLama model is loaded.
  /// </summary>
  public interface IModelParams
  {
      // ---- Device placement ----

      /// <summary>
      /// The GPU used for scratch buffers and small tensors.
      /// </summary>
      int MainGpu { get; set; }

      /// <summary>
      /// How many layers are offloaded to VRAM / GPU memory (n_gpu_layers).
      /// </summary>
      int GpuLayerCount { get; set; }

      // ---- Memory handling ----

      /// <summary>
      /// Whether to load the model through mmap, for faster loading (use_mmap).
      /// </summary>
      bool UseMemorymap { get; set; }

      /// <summary>
      /// Whether to mlock the model so that it stays resident in memory (use_mlock).
      /// </summary>
      bool UseMemoryLock { get; set; }

      // ---- Model source ----

      /// <summary>
      /// Path of the model to load (model).
      /// </summary>
      string ModelPath { get; set; }

      /// <summary>
      /// Controls how split tensors are distributed across multiple GPUs.
      /// </summary>
      TensorSplitsCollection TensorSplits { get; set; }

      /// <summary>
      /// When true, only the vocabulary is loaded and the weights are skipped.
      /// </summary>
      bool VocabOnly { get; set; }

      // ---- Adapters and overrides ----

      /// <summary>
      /// The LoRA adapters to apply to the model.
      /// </summary>
      AdapterCollection LoraAdapters { get; }

      /// <summary>
      /// Path of the base model used by the LoRA adapters (lora_base).
      /// </summary>
      string LoraBase { get; set; }

      /// <summary>
      /// Metadata entries in the model whose values should be overridden.
      /// </summary>
      List<MetadataOverride> MetadataOverrides { get; }
  }
  /// <summary>
  /// Describes one LoRA adapter to be applied on top of a model.
  /// </summary>
  /// <param name="Path">File system path of the LoRA adapter</param>
  /// <param name="Scale">How strongly this LoRA is applied</param>
  public readonly record struct LoraAdapter(
      string Path,
      float Scale
  );
  /// <summary>
  /// An ordered list of <see cref="LoraAdapter"/> values that compares by content
  /// (element by element, in order) rather than by reference.
  /// </summary>
  public sealed class AdapterCollection
      : List<LoraAdapter>, IEquatable<AdapterCollection>
  {
      /// <inheritdoc />
      public bool Equals(AdapterCollection? other)
      {
          // Equal when both hold equal adapters in the same order; null never matches.
          return other is not null && this.SequenceEqual(other);
      }

      /// <inheritdoc/>
      public override bool Equals(object? obj) => obj is AdapterCollection collection && Equals(collection);

      /// <inheritdoc/>
      public override int GetHashCode()
      {
          // Order-sensitive fold so that the hash agrees with the ordered comparison in Equals.
          var combined = 17;
          foreach (var adapter in this)
          {
              combined = unchecked((combined + adapter.GetHashCode()) * 7823);
          }
          return combined;
      }
  }
  /// <summary>
  /// A fixed size array to set the tensor splits across multiple GPUs
  /// </summary>
  [JsonConverter(typeof(TensorSplitsCollectionConverter))]
  public sealed class TensorSplitsCollection
      : IEnumerable<float>
  {
      // Backing storage, one slot per index. Its size comes from NativeApi.llama_max_devices()
      // (presumably the native library's device limit) and never changes afterwards.
      internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

      /// <summary>
      /// The size of this array
      /// </summary>
      public int Length => Splits.Length;

      /// <summary>
      /// Get or set the proportion of work to do on the given device.
      /// </summary>
      /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
      /// <param name="index">Device index, from 0 to <see cref="Length"/> - 1.</param>
      /// <returns>The proportion currently stored for that device.</returns>
      /// <exception cref="IndexOutOfRangeException"><paramref name="index"/> is outside the array.</exception>
      public float this[int index]
      {
          get => Splits[index];
          set => Splits[index] = value;
      }

      /// <summary>
      /// Create a new tensor splits collection, copying the given values into the
      /// leading slots. Any slots past <c>splits.Length</c> stay at zero.
      /// </summary>
      /// <param name="splits">Proportions to copy, starting at device 0.</param>
      /// <exception cref="ArgumentNullException"><paramref name="splits"/> is null.</exception>
      /// <exception cref="ArgumentException"><paramref name="splits"/> has more entries than <see cref="Length"/>.</exception>
      public TensorSplitsCollection(float[] splits)
      {
          // Fail with a descriptive argument error instead of a NullReferenceException.
          if (splits is null)
          {
              throw new ArgumentNullException(nameof(splits));
          }

          if (splits.Length > Splits.Length)
          {
              throw new ArgumentException($"Must supply at most {Splits.Length} tensor splits", nameof(splits));
          }

          splits.CopyTo(Splits.AsSpan());
      }

      /// <summary>
      /// Create a new tensor splits collection with all values initialised to the default (zero).
      /// </summary>
      public TensorSplitsCollection()
      {
      }

      /// <summary>
      /// Set all values to zero
      /// </summary>
      public void Clear()
      {
          Array.Clear(Splits, 0, Splits.Length);
      }

      /// <summary>
      /// Pin the backing array so its address stays fixed.
      /// </summary>
      /// <returns>A handle that must be disposed to release the pin.</returns>
      internal MemoryHandle Pin()
      {
          return Splits.AsMemory().Pin();
      }

      #region IEnumerator
      /// <inheritdoc />
      public IEnumerator<float> GetEnumerator()
      {
          return ((IEnumerable<float>)Splits).GetEnumerator();
      }

      /// <inheritdoc />
      IEnumerator IEnumerable.GetEnumerator()
      {
          return Splits.GetEnumerator();
      }
      #endregion
  }
  /// <summary>
  /// Converts a <see cref="TensorSplitsCollection"/> to and from a plain JSON array of numbers.
  /// </summary>
  public class TensorSplitsCollectionConverter
      : JsonConverter<TensorSplitsCollection>
  {
      /// <inheritdoc/>
      public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
      {
          var values = JsonSerializer.Deserialize<float[]>(ref reader, options);

          // A null result becomes an empty array, which leaves every split at zero.
          values ??= Array.Empty<float>();

          return new TensorSplitsCollection(values);
      }

      /// <inheritdoc/>
      public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
          => JsonSerializer.Serialize(writer, value.Splits, options);
  }
  /// <summary>
  /// Replaces the value of a single key in the model's metadata.
  /// </summary>
  [JsonConverter(typeof(MetadataOverrideConverter))]
  public abstract record MetadataOverride
  {
      /// <summary>
      /// The metadata key that this override replaces.
      /// </summary>
      public abstract string Key { get; init; }

      /// <summary>
      /// Build an override that assigns an integer to <paramref name="key"/>.
      /// </summary>
      /// <param name="key">Name of the metadata key to override.</param>
      /// <param name="value">Integer value to use instead.</param>
      /// <returns>A new override for <paramref name="key"/>.</returns>
      public static MetadataOverride Create(string key, int value) => new IntOverride(key, value);

      /// <summary>
      /// Build an override that assigns a floating point number to <paramref name="key"/>.
      /// </summary>
      /// <param name="key">Name of the metadata key to override.</param>
      /// <param name="value">Float value to use instead.</param>
      /// <returns>A new override for <paramref name="key"/>.</returns>
      public static MetadataOverride Create(string key, float value) => new FloatOverride(key, value);

      /// <summary>
      /// Build an override that assigns a boolean to <paramref name="key"/>.
      /// </summary>
      /// <param name="key">Name of the metadata key to override.</param>
      /// <param name="value">Boolean value to use instead.</param>
      /// <returns>A new override for <paramref name="key"/>.</returns>
      public static MetadataOverride Create(string key, bool value) => new BoolOverride(key, value);

      /// <summary>
      /// Store this override's type tag and value into <paramref name="dest"/>.
      /// The key itself is not written by this method.
      /// </summary>
      internal abstract void Write(ref LLamaModelMetadataOverride dest);

      // Integer-valued override; the positional Key implements the abstract Key above.
      private record IntOverride(string Key, int Value) : MetadataOverride
      {
          internal override void Write(ref LLamaModelMetadataOverride dest)
          {
              dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
              dest.IntValue = Value;
          }
      }

      // Float-valued override.
      private record FloatOverride(string Key, float Value) : MetadataOverride
      {
          internal override void Write(ref LLamaModelMetadataOverride dest)
          {
              dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
              dest.FloatValue = Value;
          }
      }

      // Boolean-valued override.
      private record BoolOverride(string Key, bool Value) : MetadataOverride
      {
          internal override void Write(ref LLamaModelMetadataOverride dest)
          {
              dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
              // NOTE(review): true is stored as -1 (all bits set), presumably so the native
              // side reads a nonzero value whatever the field width/byte order; confirm
              // against the native LLamaModelMetadataOverride layout.
              dest.BoolValue = Value ? -1 : 0;
          }
      }
  }
  /// <summary>
  /// JSON converter registered on <see cref="MetadataOverride"/>.
  /// </summary>
  /// <remarks>
  /// Neither direction is implemented yet: both <see cref="Read"/> and <see cref="Write"/>
  /// throw <see cref="NotImplementedException"/>.
  /// </remarks>
  public class MetadataOverrideConverter
      : JsonConverter<MetadataOverride>
  {
      /// <inheritdoc/>
      /// <exception cref="NotImplementedException">Always thrown; deserialization is not implemented.</exception>
      public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
      {
          throw new NotImplementedException($"JSON deserialization of {nameof(MetadataOverride)} is not implemented.");
      }

      /// <inheritdoc/>
      /// <exception cref="NotImplementedException">Always thrown; serialization is not implemented.</exception>
      public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
      {
          throw new NotImplementedException($"JSON serialization of {nameof(MetadataOverride)} is not implemented.");
      }
  }
  262. }