You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

IModelParams.cs 11 kB

2 years ago
1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. using System;
  2. using System.Buffers;
  3. using System.Collections;
  4. using System.Collections.Generic;
  5. using System.ComponentModel;
  6. using System.Linq;
  7. using System.Text.Json;
  8. using System.Text.Json.Serialization;
  9. using LLama.Native;
  10. namespace LLama.Abstractions
  11. {
  /// <summary>
  /// Describes the settings used when initializing a LLama model.
  /// </summary>
  public interface IModelParams
  {
      /// <summary>
      /// Selects the main GPU (main_gpu). Its meaning depends on the split mode (split_mode):
      /// <list type="bullet">
      ///     <item>
      ///         <term>None</term>
      ///         <description>The GPU that is used for the entire model.</description>
      ///     </item>
      ///     <item>
      ///         <term>Row</term>
      ///         <description>The GPU that is used for small tensors and intermediate results.</description>
      ///     </item>
      ///     <item>
      ///         <term>Layer</term>
      ///         <description>Ignored.</description>
      ///     </item>
      /// </list>
      /// </summary>
      int MainGpu { get; set; }

      /// <summary>
      /// Determines how the model is split across multiple GPUs
      /// </summary>
      GPUSplitMode SplitMode { get; }

      /// <summary>
      /// How many layers to run in VRAM / GPU memory (n_gpu_layers)
      /// </summary>
      int GpuLayerCount { get; }

      /// <summary>
      /// Whether to use mmap for faster loads (use_mmap)
      /// </summary>
      bool UseMemorymap { get; }

      /// <summary>
      /// Whether to use mlock to keep the model in memory (use_mlock)
      /// </summary>
      bool UseMemoryLock { get; }

      /// <summary>
      /// Path of the model (model)
      /// </summary>
      string ModelPath { get; }

      /// <summary>
      /// Describes how split tensors should be distributed across GPUs
      /// </summary>
      TensorSplitsCollection TensorSplits { get; }

      /// <summary>
      /// Whether to load only the vocab (no weights)
      /// </summary>
      bool VocabOnly { get; }

      /// <summary>
      /// The LoRA adapters to apply
      /// </summary>
      AdapterCollection LoraAdapters { get; }

      /// <summary>
      /// Path of the base model for the LoRA adapter (lora_base)
      /// </summary>
      string LoraBase { get; }

      /// <summary>
      /// Overrides for specific metadata items in the model
      /// </summary>
      List<MetadataOverride> MetadataOverrides { get; }
  }
  /// <summary>
  /// Describes a single LoRA adapter that can be applied to a model
  /// </summary>
  /// <param name="Path">Location of the LoRA file</param>
  /// <param name="Scale">How strongly this LoRA is applied</param>
  public readonly record struct LoraAdapter(
      string Path,
      float Scale
  );
  /// <summary>
  /// A list of <see cref="LoraAdapter"/> values which compares equal to another
  /// collection when both hold equal adapters in the same order
  /// </summary>
  public sealed class AdapterCollection
      : List<LoraAdapter>, IEquatable<AdapterCollection>
  {
      /// <inheritdoc />
      public bool Equals(AdapterCollection? other)
      {
          // A missing collection is never equal; otherwise compare element by element, in order
          return other is not null && this.SequenceEqual(other);
      }

      /// <inheritdoc/>
      public override bool Equals(object? obj)
      {
          // Anything that is not an AdapterCollection (including null) is unequal
          return obj is AdapterCollection collection && Equals(collection);
      }

      /// <inheritdoc/>
      public override int GetHashCode()
      {
          // Overflow is expected while mixing, so it must wrap instead of throwing
          unchecked
          {
              var result = 17;
              foreach (var adapter in this)
              {
                  result = (result + adapter.GetHashCode()) * 7823;
              }

              return result;
          }
      }
  }
  /// <summary>
  /// A fixed-length set of values describing how tensors are split across multiple GPUs
  /// </summary>
  [JsonConverter(typeof(TensorSplitsCollectionConverter))]
  public sealed class TensorSplitsCollection
      : IEnumerable<float>
  {
      // Backing storage; its length is whatever NativeApi.llama_max_devices() reports
      internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

      /// <summary>
      /// Number of entries held by this collection
      /// </summary>
      public int Length
      {
          get { return Splits.Length; }
      }

      /// <summary>
      /// Read or write the proportion of work assigned to the given device.
      /// </summary>
      /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
      /// <param name="index">Position of the entry to access</param>
      /// <returns>The value stored at <paramref name="index"/></returns>
      public float this[int index]
      {
          get { return Splits[index]; }
          set { Splits[index] = value; }
      }

      /// <summary>
      /// Create a new tensor splits collection holding a copy of the given values
      /// </summary>
      /// <param name="splits">Values to copy into the start of this collection</param>
      /// <exception cref="ArgumentException">Thrown when more values are supplied than this collection can hold</exception>
      public TensorSplitsCollection(float[] splits)
      {
          var capacity = Splits.Length;
          if (splits.Length > capacity)
          {
              throw new ArgumentException($"Must supply at most {capacity} tensor splits", nameof(splits));
          }

          // Entries beyond the supplied values keep their default
          splits.AsSpan().CopyTo(Splits);
      }

      /// <summary>
      /// Create a new tensor splits collection where every value is left at the default
      /// </summary>
      public TensorSplitsCollection()
      {
      }

      /// <summary>
      /// Reset every value to zero
      /// </summary>
      public void Clear()
      {
          Splits.AsSpan().Clear();
      }

      // Pins the backing array and returns the handle for that pin
      internal MemoryHandle Pin()
      {
          Memory<float> memory = Splits;
          return memory.Pin();
      }

      #region IEnumerator
      /// <inheritdoc />
      public IEnumerator<float> GetEnumerator()
      {
          IEnumerable<float> values = Splits;
          return values.GetEnumerator();
      }

      /// <inheritdoc />
      IEnumerator IEnumerable.GetEnumerator()
      {
          return Splits.GetEnumerator();
      }
      #endregion
  }
  /// <summary>
  /// Converts <see cref="TensorSplitsCollection"/> to and from a JSON array of numbers
  /// </summary>
  public class TensorSplitsCollectionConverter
      : JsonConverter<TensorSplitsCollection>
  {
      /// <inheritdoc/>
      public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
      {
          var values = JsonSerializer.Deserialize<float[]>(ref reader, options);

          // A null result is treated the same as an empty array
          values ??= Array.Empty<float>();

          return new TensorSplitsCollection(values);
      }

      /// <inheritdoc/>
      public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
          => JsonSerializer.Serialize(writer, value.Splits, options);
  }
  /// <summary>
  /// Replaces the value of one key in the model metadata
  /// </summary>
  [JsonConverter(typeof(MetadataOverrideConverter))]
  public sealed record MetadataOverride
  {
      /// <summary>
      /// The metadata key which this override replaces
      /// </summary>
      public string Key { get; }

      // Records which of the value fields below holds the override
      internal LLamaModelKvOverrideType Type { get; }

      private readonly int _valueInt;
      private readonly float _valueFloat;
      private readonly bool _valueBool;

      /// <summary>
      /// Create a new override which replaces an int key
      /// </summary>
      /// <param name="key">Key to override</param>
      /// <param name="value">Replacement value</param>
      public MetadataOverride(string key, int value)
      {
          Type = LLamaModelKvOverrideType.Int;
          Key = key;
          _valueInt = value;
      }

      /// <summary>
      /// Create a new override which replaces a float key
      /// </summary>
      /// <param name="key">Key to override</param>
      /// <param name="value">Replacement value</param>
      public MetadataOverride(string key, float value)
      {
          Type = LLamaModelKvOverrideType.Float;
          Key = key;
          _valueFloat = value;
      }

      /// <summary>
      /// Create a new override which replaces a boolean key
      /// </summary>
      /// <param name="key">Key to override</param>
      /// <param name="value">Replacement value</param>
      public MetadataOverride(string key, bool value)
      {
          Type = LLamaModelKvOverrideType.Bool;
          Key = key;
          _valueBool = value;
      }

      // Stores the value matching Type into the corresponding field of dest
      internal void WriteValue(ref LLamaModelMetadataOverride dest)
      {
          if (Type == LLamaModelKvOverrideType.Int)
          {
              dest.IntValue = _valueInt;
          }
          else if (Type == LLamaModelKvOverrideType.Float)
          {
              dest.FloatValue = _valueFloat;
          }
          else if (Type == LLamaModelKvOverrideType.Bool)
          {
              // true is written as -1L and false as 0
              dest.BoolValue = _valueBool ? -1L : 0;
          }
          else
          {
              throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
          }
      }

      // Emits the value matching Type as a JSON number or boolean
      internal void WriteValue(Utf8JsonWriter writer)
      {
          if (Type == LLamaModelKvOverrideType.Int)
          {
              writer.WriteNumberValue(_valueInt);
          }
          else if (Type == LLamaModelKvOverrideType.Float)
          {
              writer.WriteNumberValue(_valueFloat);
          }
          else if (Type == LLamaModelKvOverrideType.Bool)
          {
              writer.WriteBooleanValue(_valueBool);
          }
          else
          {
              throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
          }
      }
  }
  /// <summary>
  /// A JSON converter for <see cref="MetadataOverride"/>. Each override is stored as an
  /// object with the properties "Type", "Key" and "Value".
  /// </summary>
  public class MetadataOverrideConverter
      : JsonConverter<MetadataOverride>
  {
      // Write always emits the property names "Type", "Key" and "Value" verbatim, ignoring
      // any naming policy on the caller's options. Reading the intermediate record with the
      // caller's options would therefore fail to bind those properties whenever a naming
      // policy (e.g. camelCase) is combined with case-sensitive matching. These private
      // options carry no naming policy and match names case-insensitively, so both the
      // output of Write and policy-cased input can be read.
      private static readonly JsonSerializerOptions KeyTypeValueOptions = new JsonSerializerOptions
      {
          PropertyNameCaseInsensitive = true,
      };

      /// <inheritdoc/>
      /// <exception cref="JsonException">
      /// Thrown when the JSON value is null or its "Type" is not a known override type
      /// </exception>
      public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
      {
          var ktv = JsonSerializer.Deserialize<KeyTypeValue>(ref reader, KeyTypeValueOptions)
                 ?? throw new JsonException($"Expected a {nameof(MetadataOverride)} object but found null");

          // "Type" selects which constructor is used and how "Value" is interpreted
          return ((LLamaModelKvOverrideType)ktv.Type) switch
          {
              LLamaModelKvOverrideType.Int => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
              LLamaModelKvOverrideType.Float => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
              LLamaModelKvOverrideType.Bool => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
              _ => throw new JsonException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {ktv.Type}"),
          };
      }

      /// <inheritdoc/>
      public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
      {
          writer.WriteStartObject();
          {
              writer.WriteNumber("Type", (int)value.Type);
              writer.WriteString("Key", value.Key);
              writer.WritePropertyName("Value");
              value.WriteValue(writer);
          }
          writer.WriteEndObject();
      }

      // Intermediate shape used by Read; Value stays a raw element until Type is known
      private record KeyTypeValue(int Type, string Key, JsonElement Value);
  }
  311. }