You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

IModelParams.cs 11 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. using System;
  2. using System.Buffers;
  3. using System.Collections;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text.Json;
  7. using System.Text.Json.Serialization;
  8. using LLama.Native;
  9. namespace LLama.Abstractions
  10. {
  /// <summary>
  /// Settings used when loading a LLama model: the model file itself, GPU placement,
  /// memory mapping/locking, LoRA adapters and metadata overrides.
  /// </summary>
  public interface IModelParams
  {
      /// <summary>
      /// The GPU used for scratch buffers and small tensors.
      /// </summary>
      int MainGpu { get; set; }

      /// <summary>
      /// How many layers are placed in VRAM / GPU memory (<c>n_gpu_layers</c>).
      /// </summary>
      int GpuLayerCount { get; set; }

      /// <summary>
      /// Whether to memory-map the model file for faster loading (<c>use_mmap</c>).
      /// </summary>
      bool UseMemorymap { get; set; }

      /// <summary>
      /// Whether to mlock the model so it is kept in memory (<c>use_mlock</c>).
      /// </summary>
      bool UseMemoryLock { get; set; }

      /// <summary>
      /// Path of the model file to load (<c>model</c>).
      /// </summary>
      string ModelPath { get; set; }

      /// <summary>
      /// How split tensors are distributed across multiple GPUs.
      /// </summary>
      TensorSplitsCollection TensorSplits { get; set; }

      /// <summary>
      /// When true, only the vocabulary is loaded and the weights are skipped.
      /// </summary>
      bool VocabOnly { get; set; }

      /// <summary>
      /// LoRA adapters to apply to the model. The property itself is get-only; add or remove
      /// adapters on the returned collection.
      /// </summary>
      AdapterCollection LoraAdapters { get; }

      /// <summary>
      /// Path of the base model used together with the LoRA adapters (<c>lora_base</c>).
      /// </summary>
      string LoraBase { get; set; }

      /// <summary>
      /// Metadata items in the model whose values should be overridden. The property itself is
      /// get-only; edit the returned list in place.
      /// </summary>
      List<MetadataOverride> MetadataOverrides { get; }
  }
  /// <summary>
  /// Describes one LoRA adapter: which file to load and how strongly to apply it.
  /// Being a record struct, two adapters are equal when both their path and scale are equal.
  /// </summary>
  /// <param name="Path">Path to the LoRA adapter file</param>
  /// <param name="Scale">Strength with which this LoRA is applied</param>
  public readonly record struct LoraAdapter(string Path, float Scale)
  {
  }
  /// <summary>
  /// A mutable, ordered list of <see cref="LoraAdapter"/> values. Two collections are equal
  /// when they hold equal adapters in the same order.
  /// </summary>
  public sealed class AdapterCollection
      : List<LoraAdapter>, IEquatable<AdapterCollection>
  {
      /// <inheritdoc />
      public bool Equals(AdapterCollection? other)
          => other is not null && this.SequenceEqual(other);

      /// <inheritdoc/>
      public override bool Equals(object? obj)
          => obj is AdapterCollection adapters && Equals(adapters);

      /// <inheritdoc/>
      /// <remarks>
      /// Derived from the current contents, so the value changes whenever the list is modified.
      /// </remarks>
      public override int GetHashCode()
      {
          var acc = 17;
          foreach (var adapter in this)
          {
              // Overflow is expected and harmless when mixing hash codes.
              acc = unchecked((acc + adapter.GetHashCode()) * 7823);
          }
          return acc;
      }
  }
  /// <summary>
  /// A fixed size array to set the tensor splits across multiple GPUs.
  /// Its length is the value returned by <c>NativeApi.llama_max_devices()</c>; entries that are
  /// never set stay at zero.
  /// </summary>
  [JsonConverter(typeof(TensorSplitsCollectionConverter))]
  public sealed class TensorSplitsCollection
      : IEnumerable<float>
  {
      // Field initializers run before any constructor body, so the constructors below can rely
      // on this array already being allocated at its final size.
      internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

      /// <summary>
      /// The size of this array (the value returned by <c>NativeApi.llama_max_devices()</c>).
      /// </summary>
      public int Length => Splits.Length;

      /// <summary>
      /// Get or set the proportion of work to do on the given device.
      /// </summary>
      /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
      /// <param name="index">Device index, in the range [0, <see cref="Length"/>).</param>
      /// <returns>The proportion of work assigned to that device.</returns>
      /// <exception cref="IndexOutOfRangeException"><paramref name="index"/> is outside the array.</exception>
      public float this[int index]
      {
          get => Splits[index];
          set => Splits[index] = value;
      }

      /// <summary>
      /// Create a new tensor splits collection, copying the given values. Devices past the end of
      /// <paramref name="splits"/> are left at zero.
      /// </summary>
      /// <param name="splits">Per-device proportions; at most <see cref="Length"/> values.</param>
      /// <exception cref="ArgumentNullException"><paramref name="splits"/> is null.</exception>
      /// <exception cref="ArgumentException">More values were supplied than <see cref="Length"/>.</exception>
      public TensorSplitsCollection(float[] splits)
      {
          // Fail with a proper argument exception instead of a NullReferenceException from .Length.
          // ArgumentNullException derives from ArgumentException, so existing handlers still apply.
          if (splits is null)
              throw new ArgumentNullException(nameof(splits));
          if (splits.Length > Splits.Length)
              throw new ArgumentException($"Must supply at most {Splits.Length} tensor splits", nameof(splits));

          splits.CopyTo(Splits.AsSpan());
      }

      /// <summary>
      /// Create a new tensor splits collection with all values initialised to the default
      /// </summary>
      public TensorSplitsCollection()
      {
      }

      /// <summary>
      /// Set all values to zero
      /// </summary>
      public void Clear()
      {
          Array.Clear(Splits, 0, Splits.Length);
      }

      /// <summary>
      /// Pin the backing array so its address stays fixed; dispose the returned handle to unpin it.
      /// </summary>
      internal MemoryHandle Pin()
      {
          return Splits.AsMemory().Pin();
      }

      #region IEnumerable<float>
      /// <inheritdoc />
      public IEnumerator<float> GetEnumerator()
      {
          return ((IEnumerable<float>)Splits).GetEnumerator();
      }

      /// <inheritdoc />
      IEnumerator IEnumerable.GetEnumerator()
      {
          return Splits.GetEnumerator();
      }
      #endregion
  }
  /// <summary>
  /// A JSON converter for <see cref="TensorSplitsCollection"/>, represented in JSON as a plain
  /// array of numbers.
  /// </summary>
  public class TensorSplitsCollectionConverter
      : JsonConverter<TensorSplitsCollection>
  {
      /// <inheritdoc/>
      /// <remarks>
      /// If the array deserializes as null, an all-zero collection is returned.
      /// </remarks>
      /// <exception cref="JsonException">
      /// The JSON array holds more values than this build's device count.
      /// </exception>
      public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
      {
          var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
          try
          {
              return new TensorSplitsCollection(arr);
          }
          catch (ArgumentException ex)
          {
              // Write() emits one value per device slot, so JSON produced by a build with a larger
              // llama_max_devices() can be too long here. Report that as a JSON error, which is what
              // deserialization callers handle, rather than leaking an ArgumentException.
              throw new JsonException(ex.Message, ex);
          }
      }

      /// <inheritdoc/>
      /// <remarks>Writes every slot, including trailing zeros.</remarks>
      public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
      {
          JsonSerializer.Serialize(writer, value.Splits, options);
      }
  }
  /// <summary>
  /// An override for a single key/value pair in model metadata. Each instance carries exactly one
  /// value (int, float or bool); the internal <c>Type</c> property records which one.
  /// </summary>
  [JsonConverter(typeof(MetadataOverrideConverter))]
  public sealed record MetadataOverride
  {
      /// <summary>
      /// Get the key being overridden by this override
      /// </summary>
      public string Key { get; }

      /// <summary>
      /// The kind of value held by this override; selects which backing field is meaningful.
      /// </summary>
      internal LLamaModelKvOverrideType Type { get; }

      // Only the field matching Type is assigned by a constructor; the others keep their defaults.
      private readonly int _valueInt;
      private readonly float _valueFloat;
      private readonly bool _valueBool;

      /// <summary>
      /// Create a new override for an int key
      /// </summary>
      /// <param name="key">Metadata key to override</param>
      /// <param name="value">Integer value to use for that key</param>
      public MetadataOverride(string key, int value)
      {
          Key = key;
          _valueInt = value;
          Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
      }

      /// <summary>
      /// Create a new override for a float key
      /// </summary>
      /// <param name="key">Metadata key to override</param>
      /// <param name="value">Floating point value to use for that key</param>
      public MetadataOverride(string key, float value)
      {
          Key = key;
          _valueFloat = value;
          Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
      }

      /// <summary>
      /// Create a new override for a boolean key
      /// </summary>
      /// <param name="key">Metadata key to override</param>
      /// <param name="value">Boolean value to use for that key</param>
      public MetadataOverride(string key, bool value)
      {
          Key = key;
          _valueBool = value;
          Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
      }

      /// <summary>
      /// Copy this override's value into a native override struct.
      /// </summary>
      /// <param name="dest">Native struct; only the value field matching <c>Type</c> is written.</param>
      /// <exception cref="ArgumentOutOfRangeException"><c>Type</c> is not a supported kind.</exception>
      internal void WriteValue(ref LLamaModelMetadataOverride dest)
      {
          switch (Type)
          {
              case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
                  dest.IntValue = _valueInt;
                  break;
              case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
                  dest.FloatValue = _valueFloat;
                  break;
              case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
                  // NOTE(review): assumes the native llama_model_kv_override reads this slot as a
                  // C++ `bool` at the start of its value union (as in llama.h) - confirm against the
                  // bundled header. A C++ bool may only hold 0 or 1, so the previous -1 (all bits set,
                  // i.e. 0xFF) was an invalid representation. Put the 1 in whichever byte of the
                  // 64-bit field sits at the lowest address.
                  dest.BoolValue = _valueBool ? (BitConverter.IsLittleEndian ? 1L : 1L << 56) : 0L;
                  break;
              default:
                  throw new ArgumentOutOfRangeException(nameof(Type), Type, "Unsupported metadata override type");
          }
      }

      /// <summary>
      /// Write this override's value as a single JSON value (number or boolean).
      /// </summary>
      /// <param name="writer">Writer positioned where the value belongs.</param>
      /// <param name="options">Serializer options (currently unused).</param>
      /// <exception cref="ArgumentOutOfRangeException"><c>Type</c> is not a supported kind.</exception>
      internal void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
      {
          switch (Type)
          {
              case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
                  writer.WriteNumberValue(_valueInt);
                  break;
              case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
                  writer.WriteNumberValue(_valueFloat);
                  break;
              case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
                  writer.WriteBooleanValue(_valueBool);
                  break;
              default:
                  throw new ArgumentOutOfRangeException(nameof(Type), Type, "Unsupported metadata override type");
          }
      }
  }
  /// <summary>
  /// A JSON converter for <see cref="MetadataOverride"/>. Each override is an object with
  /// <c>Type</c>, <c>Key</c> and <c>Value</c> properties. Those names are passed through
  /// <see cref="JsonSerializerOptions.PropertyNamingPolicy"/> when one is configured, so reading
  /// and writing stay symmetric.
  /// </summary>
  public class MetadataOverrideConverter
      : JsonConverter<MetadataOverride>
  {
      /// <inheritdoc/>
      /// <exception cref="JsonException">
      /// The object is null, has no key, has an unknown type, or its value does not match the type.
      /// </exception>
      public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
      {
          // Deserializing via the KeyTypeValue record applies options.PropertyNamingPolicy to the
          // property names; Write() applies the same policy so the two agree.
          var ktv = JsonSerializer.Deserialize<KeyTypeValue>(ref reader, options)
                    ?? throw new JsonException("Expected a metadata override object but found null.");

          if (ktv.Key is null)
              throw new JsonException("Metadata override is missing its key.");

          try
          {
              return ((LLamaModelKvOverrideType)ktv.Type) switch
              {
                  LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
                  LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
                  LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
                  _ => throw new JsonException($"Unknown metadata override type: {ktv.Type}"),
              };
          }
          catch (Exception ex) when (ex is InvalidOperationException || ex is FormatException)
          {
              // JsonElement getters throw these when the value is missing, has the wrong JSON kind,
              // or does not fit the target type; report them as the serializer's own error type.
              throw new JsonException($"Invalid value for metadata override '{ktv.Key}'.", ex);
          }
      }

      /// <inheritdoc/>
      public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
      {
          // Without a naming policy these are exactly "Type", "Key" and "Value".
          var policy = options.PropertyNamingPolicy;

          writer.WriteStartObject();
          writer.WriteNumber(policy?.ConvertName("Type") ?? "Type", (int)value.Type);
          writer.WriteString(policy?.ConvertName("Key") ?? "Key", value.Key);
          writer.WritePropertyName(policy?.ConvertName("Value") ?? "Value");
          value.WriteValue(writer, options);
          writer.WriteEndObject();
      }

      /// <summary>
      /// Shape of the JSON object read by <see cref="Read"/>; the value is kept raw until the type is known.
      /// </summary>
      private record KeyTypeValue(int Type, string Key, JsonElement Value);
  }
  292. }