using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;

namespace LLama.Abstractions
{
    /// <summary>
    /// The parameters for initializing a LLama model.
    /// </summary>
    public interface IModelParams
    {
        /// <summary>
        /// main_gpu interpretation depends on split_mode:
        /// <list type="bullet">
        ///     <item>
        ///         <term>None</term>
        ///         <description>The GPU that is used for the entire model.</description>
        ///     </item>
        ///     <item>
        ///         <term>Row</term>
        ///         <description>The GPU that is used for small tensors and intermediate results.</description>
        ///     </item>
        ///     <item>
        ///         <term>Layer</term>
        ///         <description>Ignored.</description>
        ///     </item>
        /// </list>
        /// </summary>
        int MainGpu { get; set; }

        /// <summary>
        /// How to split the model across multiple GPUs
        /// </summary>
        GPUSplitMode SplitMode { get; }

        /// <summary>
        /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
        /// </summary>
        int GpuLayerCount { get; }

        /// <summary>
        /// Use mmap for faster loads (use_mmap)
        /// </summary>
        bool UseMemorymap { get; }

        /// <summary>
        /// Use mlock to keep model in memory (use_mlock)
        /// </summary>
        bool UseMemoryLock { get; }

        /// <summary>
        /// Model path (model)
        /// </summary>
        string ModelPath { get; }

        /// <summary>
        /// How split tensors should be distributed across GPUs
        /// </summary>
        TensorSplitsCollection TensorSplits { get; }

        /// <summary>
        /// Load vocab only (no weights)
        /// </summary>
        bool VocabOnly { get; }

        /// <summary>
        /// List of LoRA adapters to apply
        /// </summary>
        AdapterCollection LoraAdapters { get; }

        /// <summary>
        /// Base model path for the LoRA adapter (lora_base)
        /// </summary>
        string LoraBase { get; }

        /// <summary>
        /// Override specific metadata items in the model
        /// </summary>
        List<MetadataOverride> MetadataOverrides { get; }
    }

    /// <summary>
    /// A LoRA adapter to apply to a model
    /// </summary>
    /// <param name="Path">Path to the LoRA file</param>
    /// <param name="Scale">Strength of this LoRA</param>
    public readonly record struct LoraAdapter(string Path, float Scale);

    /// <summary>
    /// A list of <see cref="LoraAdapter"/> objects. Two collections are equal when they contain
    /// equal adapters in the same order.
    /// </summary>
    public sealed class AdapterCollection : List<LoraAdapter>, IEquatable<AdapterCollection>
    {
        /// <inheritdoc />
        public bool Equals(AdapterCollection? other)
        {
            if (other is null)
                return false;
            if (ReferenceEquals(this, other))
                return true;

            return this.SequenceEqual(other);
        }

        /// <inheritdoc />
        public override bool Equals(object? obj)
        {
            return Equals(obj as AdapterCollection);
        }

        /// <inheritdoc />
        /// <remarks>
        /// The hash is computed from the current contents, so it changes if the list is mutated.
        /// Do not mutate a collection while it is used as a dictionary/set key.
        /// </remarks>
        public override int GetHashCode()
        {
            unchecked
            {
                var hash = 17;
                for (var i = 0; i < Count; i++)
                {
                    // Order-sensitive combination, consistent with SequenceEqual in Equals.
                    hash += this[i].GetHashCode();
                    hash *= 7823;
                }
                return hash;
            }
        }
    }

    /// <summary>
    /// A fixed size array to set the tensor splits across multiple GPUs
    /// </summary>
    [JsonConverter(typeof(TensorSplitsCollectionConverter))]
    public sealed class TensorSplitsCollection : IEnumerable<float>
    {
        // One slot per device reported by NativeApi.llama_max_devices(); a new float[] starts as all zeros.
        internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

        /// <summary>
        /// The size of this array
        /// </summary>
        public int Length => Splits.Length;

        /// <summary>
        /// Get or set the proportion of work to do on the given device.
        /// </summary>
        /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
        /// <param name="index">Device index</param>
        /// <returns>The proportion assigned to that device</returns>
        public float this[int index]
        {
            get => Splits[index];
            set => Splits[index] = value;
        }

        /// <summary>
        /// Create a new tensor splits collection, copying the given values
        /// </summary>
        /// <param name="splits">Values to copy; any remaining slots stay zero</param>
        /// <exception cref="ArgumentNullException"><paramref name="splits"/> is null</exception>
        /// <exception cref="ArgumentException">More values were supplied than there are device slots</exception>
        public TensorSplitsCollection(float[] splits)
        {
            if (splits is null)
                throw new ArgumentNullException(nameof(splits));
            if (splits.Length > Splits.Length)
                throw new ArgumentException($"Must supply at most {Splits.Length} tensor splits", nameof(splits));

            splits.CopyTo(Splits.AsSpan());
        }

        /// <summary>
        /// Create a new tensor splits collection with all values initialised to zero
        /// </summary>
        public TensorSplitsCollection()
        {
        }

        /// <summary>
        /// Set all values to zero
        /// </summary>
        public void Clear()
        {
            Array.Clear(Splits, 0, Splits.Length);
        }

        // Pins the backing array so its address can be handed to native code; dispose the handle to unpin.
        internal MemoryHandle Pin()
        {
            return Splits.AsMemory().Pin();
        }

        #region IEnumerable
        /// <inheritdoc />
        public IEnumerator<float> GetEnumerator()
        {
            return ((IEnumerable<float>)Splits).GetEnumerator();
        }

        /// <inheritdoc />
        IEnumerator IEnumerable.GetEnumerator()
        {
            return Splits.GetEnumerator();
        }
        #endregion
    }

    /// <summary>
    /// A JSON converter for <see cref="TensorSplitsCollection"/>, stored as a plain array of numbers
    /// </summary>
    public class TensorSplitsCollectionConverter : JsonConverter<TensorSplitsCollection>
    {
        /// <inheritdoc />
        public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
            return new TensorSplitsCollection(arr);
        }

        /// <inheritdoc />
        public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
        {
            JsonSerializer.Serialize(writer, value.Splits, options);
        }
    }

    /// <summary>
    /// An override for a single key/value pair in model metadata
    /// </summary>
    [JsonConverter(typeof(MetadataOverrideConverter))]
    public sealed record MetadataOverride
    {
        /// <summary>
        /// Get the key being overridden by this override
        /// </summary>
        public string Key { get; }

        internal LLamaModelKvOverrideType Type { get; }

        // Only the field matching Type is meaningful; the others keep their default values.
        private readonly int _valueInt;
        private readonly float _valueFloat;
        private readonly bool _valueBool;

        /// <summary>
        /// Create a new override for an int key
        /// </summary>
        /// <param name="key">Metadata key to override</param>
        /// <param name="value">Replacement value</param>
        public MetadataOverride(string key, int value)
        {
            Key = key;
            _valueInt = value;
            Type = LLamaModelKvOverrideType.Int;
        }

        /// <summary>
        /// Create a new override for a float key
        /// </summary>
        /// <param name="key">Metadata key to override</param>
        /// <param name="value">Replacement value</param>
        public MetadataOverride(string key, float value)
        {
            Key = key;
            _valueFloat = value;
            Type = LLamaModelKvOverrideType.Float;
        }

        /// <summary>
        /// Create a new override for a boolean key
        /// </summary>
        /// <param name="key">Metadata key to override</param>
        /// <param name="value">Replacement value</param>
        public MetadataOverride(string key, bool value)
        {
            Key = key;
            _valueBool = value;
            Type = LLamaModelKvOverrideType.Bool;
        }

        internal void WriteValue(ref LLamaModelMetadataOverride dest)
        {
            switch (Type)
            {
                case LLamaModelKvOverrideType.Int:
                    dest.IntValue = _valueInt;
                    break;
                case LLamaModelKvOverrideType.Float:
                    dest.FloatValue = _valueFloat;
                    break;
                case LLamaModelKvOverrideType.Bool:
                    // NOTE(review): -1L sets every byte non-zero, presumably so the native bool reads as
                    // true regardless of byte order. A C++ bool is strictly 0/1, so confirm this is intended.
                    dest.BoolValue = _valueBool ? -1L : 0;
                    break;
                default:
                    throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
            }
        }

        internal void WriteValue(Utf8JsonWriter writer)
        {
            switch (Type)
            {
                case LLamaModelKvOverrideType.Int:
                    writer.WriteNumberValue(_valueInt);
                    break;
                case LLamaModelKvOverrideType.Float:
                    writer.WriteNumberValue(_valueFloat);
                    break;
                case LLamaModelKvOverrideType.Bool:
                    writer.WriteBooleanValue(_valueBool);
                    break;
                default:
                    throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
            }
        }
    }

    /// <summary>
    /// A JSON converter for <see cref="MetadataOverride"/>, stored as an object with
    /// "Type" (int), "Key" (string) and "Value" properties
    /// </summary>
    public class MetadataOverrideConverter : JsonConverter<MetadataOverride>
    {
        /// <inheritdoc />
        /// <exception cref="JsonException">The object is null, has no Key, or has an unknown Type</exception>
        public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            // Deserialize is annotated as possibly returning null; fail with a JsonException instead of a NullReferenceException.
            var ktv = JsonSerializer.Deserialize<KeyTypeValue>(ref reader, options)
                   ?? throw new JsonException($"Expected a {nameof(MetadataOverride)} object but found null");

            // A missing "Key" property deserializes to null, which would produce an override with a null key.
            if (ktv.Key is null)
                throw new JsonException($"{nameof(MetadataOverride)} is missing its Key");

            return ((LLamaModelKvOverrideType)ktv.Type) switch
            {
                LLamaModelKvOverrideType.Int => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
                LLamaModelKvOverrideType.Float => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
                LLamaModelKvOverrideType.Bool => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
                _ => throw new JsonException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {ktv.Type}"),
            };
        }

        /// <inheritdoc />
        public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
        {
            writer.WriteStartObject();
            {
                writer.WriteNumber("Type", (int)value.Type);
                writer.WriteString("Key", value.Key);
                writer.WritePropertyName("Value");
                value.WriteValue(writer);
            }
            writer.WriteEndObject();
        }

        // Wire shape of a serialized override; property names match those written by Write.
        private record KeyTypeValue(int Type, string Key, JsonElement Value);
    }
}