You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

IModelParams.cs 5.9 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. using System;
  2. using System.Buffers;
  3. using System.Collections;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text.Json;
  7. using System.Text.Json.Serialization;
  8. using LLama.Common;
  9. using LLama.Native;
  10. namespace LLama.Abstractions
  11. {
  /// <summary>
  /// Settings used when loading a LLama model: which file to read, how work is placed
  /// across GPUs, how the file is mapped into memory, and which LoRA adapters are applied.
  /// </summary>
  public interface IModelParams
  {
      /// <summary>
      /// Index of the GPU used for scratch buffers and small tensors.
      /// </summary>
      int MainGpu { get; set; }

      /// <summary>
      /// How many layers to run in VRAM / GPU memory (<c>n_gpu_layers</c>).
      /// </summary>
      int GpuLayerCount { get; set; }

      /// <summary>
      /// Whether to use mmap for faster model loading (<c>use_mmap</c>).
      /// </summary>
      bool UseMemorymap { get; set; }

      /// <summary>
      /// Whether to use mlock to keep the model resident in memory (<c>use_mlock</c>).
      /// </summary>
      bool UseMemoryLock { get; set; }

      /// <summary>
      /// Path of the model file to load (<c>model</c>).
      /// </summary>
      string ModelPath { get; set; }

      /// <summary>
      /// Relative proportions that decide how split tensors are distributed across GPUs.
      /// </summary>
      TensorSplitsCollection TensorSplits { get; set; }

      /// <summary>
      /// When true, only the vocabulary is loaded and the weights are skipped.
      /// </summary>
      bool VocabOnly { get; set; }

      /// <summary>
      /// The LoRA adapters to apply. The property itself is read-only; add or remove
      /// adapters by mutating the returned collection.
      /// </summary>
      AdapterCollection LoraAdapters { get; }

      /// <summary>
      /// Base model path used together with the LoRA adapters (<c>lora_base</c>).
      /// </summary>
      string LoraBase { get; set; }
  }
  /// <summary>
  /// Describes one LoRA adapter to be applied on top of a model.
  /// Being a positional record struct, it has value equality, <c>with</c>-expressions
  /// and deconstruction.
  /// </summary>
  /// <param name="Path">File path of the LoRA adapter.</param>
  /// <param name="Scale">How strongly this adapter is applied.</param>
  public readonly record struct LoraAdapter(
      string Path,
      float Scale
  );
  /// <summary>
  /// A mutable, ordered list of <see cref="LoraAdapter"/> values with content-based
  /// equality: two collections are equal when they contain equal adapters in the same order.
  /// </summary>
  public sealed class AdapterCollection
      : List<LoraAdapter>, IEquatable<AdapterCollection>
  {
      /// <inheritdoc />
      public bool Equals(AdapterCollection? other)
          => other is not null && this.SequenceEqual(other);

      /// <inheritdoc/>
      public override bool Equals(object? obj)
          => obj is AdapterCollection adapters && Equals(adapters);

      /// <inheritdoc/>
      /// <remarks>
      /// Order-sensitive and computed from the current contents, so the value changes
      /// whenever the list is modified.
      /// </remarks>
      public override int GetHashCode()
      {
          var acc = 17;
          foreach (var adapter in this)
          {
              // Wrap-around is expected when folding element hashes; keep it explicit.
              acc = unchecked((acc + adapter.GetHashCode()) * 7823);
          }

          return acc;
      }
  }
  /// <summary>
  /// A fixed-size array of per-device proportions controlling how tensors are split
  /// across multiple GPUs. Its length is set once from <c>NativeApi.llama_max_devices()</c>
  /// when the instance is created and never changes.
  /// </summary>
  [JsonConverter(typeof(TensorSplitsCollectionConverter))]
  public sealed class TensorSplitsCollection
      : IEnumerable<float>
  {
      // One slot per device reported by the native library; allocated once, never resized.
      internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

      /// <summary>
      /// The number of slots in this collection, which is the most tensor splits it can hold.
      /// </summary>
      public int Length => Splits.Length;

      /// <summary>
      /// Get or set the proportion of work to do on the given device.
      /// </summary>
      /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
      /// <param name="index">Zero-based device index; must be less than <see cref="Length"/>.</param>
      /// <returns>The proportion currently assigned to device <paramref name="index"/>.</returns>
      /// <exception cref="IndexOutOfRangeException"><paramref name="index"/> is outside the array.</exception>
      public float this[int index]
      {
          get => Splits[index];
          set => Splits[index] = value;
      }

      /// <summary>
      /// Create a new tensor splits collection, copying the given values into the leading
      /// slots. Any slots beyond <c>splits.Length</c> remain zero.
      /// </summary>
      /// <param name="splits">Values to copy; may be shorter than <see cref="Length"/>.</param>
      /// <exception cref="ArgumentNullException"><paramref name="splits"/> is null.</exception>
      /// <exception cref="ArgumentException">More values were supplied than there are slots.</exception>
      public TensorSplitsCollection(float[] splits)
      {
          // Fail with a descriptive argument exception rather than a NullReferenceException.
          if (splits is null)
              throw new ArgumentNullException(nameof(splits));
          if (splits.Length > Splits.Length)
              throw new ArgumentException($"Must supply at most {Splits.Length} tensor splits", nameof(splits));

          splits.CopyTo(Splits.AsSpan());
      }

      /// <summary>
      /// Create a new tensor splits collection with every value initialised to zero.
      /// </summary>
      public TensorSplitsCollection()
      {
      }

      /// <summary>
      /// Set all values to zero.
      /// </summary>
      public void Clear()
      {
          Array.Clear(Splits, 0, Splits.Length);
      }

      /// <summary>
      /// Pin the backing array and return a handle to it. The caller must dispose the
      /// returned <see cref="MemoryHandle"/> to release the pin.
      /// </summary>
      internal MemoryHandle Pin()
      {
          return Splits.AsMemory().Pin();
      }

      #region IEnumerable
      /// <inheritdoc />
      public IEnumerator<float> GetEnumerator()
      {
          return ((IEnumerable<float>)Splits).GetEnumerator();
      }

      /// <inheritdoc />
      IEnumerator IEnumerable.GetEnumerator()
      {
          return Splits.GetEnumerator();
      }
      #endregion
  }
  /// <summary>
  /// Converts a <see cref="TensorSplitsCollection"/> to and from a plain JSON array of numbers.
  /// </summary>
  public class TensorSplitsCollectionConverter
      : JsonConverter<TensorSplitsCollection>
  {
      /// <inheritdoc/>
      /// <remarks>
      /// If the array deserializes to null, an empty array is used instead, so every slot
      /// stays zero. An array longer than the collection's capacity makes the
      /// <see cref="TensorSplitsCollection"/> constructor throw <see cref="ArgumentException"/>.
      /// </remarks>
      public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
      {
          float[]? values = JsonSerializer.Deserialize<float[]>(ref reader, options);
          return new TensorSplitsCollection(values ?? Array.Empty<float>());
      }

      /// <inheritdoc/>
      /// <remarks>Every slot is written, including trailing zeros.</remarks>
      public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
          => JsonSerializer.Serialize(writer, value.Splits, options);
  }
  175. }