using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Common;
using LLama.Native;

namespace LLama.Abstractions
{
    /// <summary>
    /// The parameters for initializing a LLama model.
    /// </summary>
    public interface IModelParams
    {
        /// <summary>
        /// The GPU that is used for scratch and small tensors
        /// </summary>
        int MainGpu { get; set; }

        /// <summary>
        /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
        /// </summary>
        int GpuLayerCount { get; set; }

        /// <summary>
        /// Use mmap for faster loads (use_mmap)
        /// </summary>
        bool UseMemorymap { get; set; }

        /// <summary>
        /// Use mlock to keep model in memory (use_mlock)
        /// </summary>
        bool UseMemoryLock { get; set; }

        /// <summary>
        /// Model path (model)
        /// </summary>
        string ModelPath { get; set; }

        /// <summary>
        /// How split tensors should be distributed across GPUs
        /// </summary>
        TensorSplitsCollection TensorSplits { get; set; }

        /// <summary>
        /// Load vocab only (no weights)
        /// </summary>
        bool VocabOnly { get; set; }

        /// <summary>
        /// List of LoRA adapters to apply
        /// </summary>
        AdapterCollection LoraAdapters { get; }

        /// <summary>
        /// Base model path for the LoRA adapter (lora_base)
        /// </summary>
        string LoraBase { get; set; }

        /// <summary>
        /// Override specific metadata items in the model
        /// </summary>
        List<MetadataOverride> MetadataOverrides { get; }
    }

    /// <summary>
    /// A LoRA adapter to apply to a model
    /// </summary>
    /// <param name="Path">Path to the LoRA file</param>
    /// <param name="Scale">Strength of this LoRA</param>
    public readonly record struct LoraAdapter(string Path, float Scale);

    /// <summary>
    /// A list of <see cref="LoraAdapter"/> objects. Two collections are equal when they
    /// contain equal adapters in the same order.
    /// </summary>
    public sealed class AdapterCollection : List<LoraAdapter>, IEquatable<AdapterCollection>
    {
        /// <inheritdoc />
        public bool Equals(AdapterCollection? other)
        {
            return other is not null && this.SequenceEqual(other);
        }

        /// <inheritdoc />
        public override bool Equals(object? obj)
        {
            return Equals(obj as AdapterCollection);
        }

        /// <inheritdoc />
        public override int GetHashCode()
        {
            // Order-sensitive combination of every element's hash, consistent with the
            // element-by-element SequenceEqual used in Equals.
            unchecked
            {
                var hash = 17;
                for (var i = 0; i < Count; i++)
                {
                    hash += this[i].GetHashCode();
                    hash *= 7823;
                }
                return hash;
            }
        }
    }

    /// <summary>
    /// A fixed size array to set the tensor splits across multiple GPUs
    /// </summary>
    [JsonConverter(typeof(TensorSplitsCollectionConverter))]
    public sealed class TensorSplitsCollection : IEnumerable<float>
    {
        // One slot per device, sized by NativeApi.llama_max_devices().
        internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

        /// <summary>
        /// The size of this array
        /// </summary>
        public int Length => Splits.Length;

        /// <summary>
        /// Get or set the proportion of work to do on the given device.
        /// </summary>
        /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
        /// <param name="index">Device index, in the range [0, <see cref="Length"/>).</param>
        public float this[int index]
        {
            get => Splits[index];
            set => Splits[index] = value;
        }

        /// <summary>
        /// Create a new tensor splits collection, copying the given values
        /// </summary>
        /// <param name="splits">Values copied into the start of the collection; any remaining slots stay zero.</param>
        /// <exception cref="ArgumentNullException"><paramref name="splits"/> is null.</exception>
        /// <exception cref="ArgumentException">More values were supplied than there are slots.</exception>
        public TensorSplitsCollection(float[] splits)
        {
            if (splits == null)
                throw new ArgumentNullException(nameof(splits));
            if (splits.Length > Splits.Length)
                throw new ArgumentException($"Must supply at most {Splits.Length} tensor splits", nameof(splits));

            splits.CopyTo(Splits.AsSpan());
        }

        /// <summary>
        /// Create a new tensor splits collection with all values initialised to zero
        /// </summary>
        public TensorSplitsCollection()
        {
        }

        /// <summary>
        /// Set all values to zero
        /// </summary>
        public void Clear()
        {
            Array.Clear(Splits, 0, Splits.Length);
        }

        /// <summary>
        /// Pin the underlying array so it can be passed to native code. The returned handle
        /// must be disposed to release the pin.
        /// </summary>
        internal MemoryHandle Pin()
        {
            return Splits.AsMemory().Pin();
        }

        #region IEnumerator
        /// <inheritdoc />
        public IEnumerator<float> GetEnumerator()
        {
            return ((IEnumerable<float>)Splits).GetEnumerator();
        }

        /// <inheritdoc />
        IEnumerator IEnumerable.GetEnumerator()
        {
            return Splits.GetEnumerator();
        }
        #endregion
    }

    /// <summary>
    /// A JSON converter for <see cref="TensorSplitsCollection"/>, which is read and written as a plain array of numbers.
    /// </summary>
    public class TensorSplitsCollectionConverter : JsonConverter<TensorSplitsCollection>
    {
        /// <inheritdoc />
        public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            // Fall back to an empty array (all splits zero) if deserialization yields null.
            var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
            return new TensorSplitsCollection(arr);
        }

        /// <inheritdoc />
        public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
        {
            // Writes every slot, including trailing zeros.
            JsonSerializer.Serialize(writer, value.Splits, options);
        }
    }

    /// <summary>
    /// An override for a single key/value pair in model metadata
    /// </summary>
    [JsonConverter(typeof(MetadataOverrideConverter))]
    public abstract record MetadataOverride
    {
        /// <summary>
        /// Create a new override for an int key
        /// </summary>
        /// <param name="key">Metadata key to override</param>
        /// <param name="value">Replacement value</param>
        /// <returns>An override holding an int value</returns>
        public static MetadataOverride Create(string key, int value)
        {
            return new IntOverride(key, value);
        }

        /// <summary>
        /// Create a new override for a float key
        /// </summary>
        /// <param name="key">Metadata key to override</param>
        /// <param name="value">Replacement value</param>
        /// <returns>An override holding a float value</returns>
        public static MetadataOverride Create(string key, float value)
        {
            return new FloatOverride(key, value);
        }

        /// <summary>
        /// Create a new override for a boolean key
        /// </summary>
        /// <param name="key">Metadata key to override</param>
        /// <param name="value">Replacement value</param>
        /// <returns>An override holding a boolean value</returns>
        public static MetadataOverride Create(string key, bool value)
        {
            return new BoolOverride(key, value);
        }

        /// <summary>
        /// Write the value tag and payload of this override into <paramref name="dest"/>.
        /// </summary>
        /// <remarks>
        /// <see cref="Key"/> is not written here. NOTE(review): presumably the caller copies the key
        /// into <paramref name="dest"/> separately; confirm.
        /// </remarks>
        internal abstract void Write(ref LLamaModelMetadataOverride dest);

        /// <summary>
        /// Write the "Type" and "Value" JSON properties describing this override's value
        /// (used by <see cref="MetadataOverrideConverter"/>).
        /// </summary>
        internal abstract void WriteJsonValue(Utf8JsonWriter writer);

        /// <summary>
        /// Get the key being overridden by this override
        /// </summary>
        public abstract string Key { get; init; }

        private record IntOverride(string Key, int Value) : MetadataOverride
        {
            internal override void Write(ref LLamaModelMetadataOverride dest)
            {
                dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
                dest.IntValue = Value;
            }

            internal override void WriteJsonValue(Utf8JsonWriter writer)
            {
                writer.WriteString(MetadataOverrideConverter.TypePropertyName, MetadataOverrideConverter.IntTypeName);
                writer.WriteNumber(MetadataOverrideConverter.ValuePropertyName, Value);
            }
        }

        private record FloatOverride(string Key, float Value) : MetadataOverride
        {
            internal override void Write(ref LLamaModelMetadataOverride dest)
            {
                dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
                dest.FloatValue = Value;
            }

            internal override void WriteJsonValue(Utf8JsonWriter writer)
            {
                writer.WriteString(MetadataOverrideConverter.TypePropertyName, MetadataOverrideConverter.FloatTypeName);
                writer.WriteNumber(MetadataOverrideConverter.ValuePropertyName, Value);
            }
        }

        private record BoolOverride(string Key, bool Value) : MetadataOverride
        {
            internal override void Write(ref LLamaModelMetadataOverride dest)
            {
                dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
                // NOTE(review): true is written as -1 (all bits set) rather than 1; presumably so the
                // native side reads it as true whatever the width of its bool slot. Confirm against
                // the declared type of LLamaModelMetadataOverride.BoolValue.
                dest.BoolValue = Value ? -1 : 0;
            }

            internal override void WriteJsonValue(Utf8JsonWriter writer)
            {
                writer.WriteString(MetadataOverrideConverter.TypePropertyName, MetadataOverrideConverter.BoolTypeName);
                writer.WriteBoolean(MetadataOverrideConverter.ValuePropertyName, Value);
            }
        }
    }

    /// <summary>
    /// A JSON converter for <see cref="MetadataOverride"/>. Each override is represented as an object
    /// of the form <c>{ "Key": "some.key", "Type": "Int", "Value": 42 }</c>, where "Type" is one of
    /// "Int", "Float" or "Bool". The explicit type is needed because a JSON number alone cannot
    /// distinguish an int override from a float one (e.g. 1.0f is written as <c>1</c>).
    /// </summary>
    public class MetadataOverrideConverter : JsonConverter<MetadataOverride>
    {
        internal const string KeyPropertyName = "Key";
        internal const string TypePropertyName = "Type";
        internal const string ValuePropertyName = "Value";

        internal const string IntTypeName = "Int";
        internal const string FloatTypeName = "Float";
        internal const string BoolTypeName = "Bool";

        /// <inheritdoc />
        /// <exception cref="JsonException">The JSON is not a well-formed metadata override object.</exception>
        public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            // Parse the whole value first so the properties may appear in any order.
            using var document = JsonDocument.ParseValue(ref reader);
            var root = document.RootElement;
            if (root.ValueKind != JsonValueKind.Object)
                throw new JsonException("Expected a JSON object for a metadata override");

            var key = ReadRequiredString(root, KeyPropertyName);
            var type = ReadRequiredString(root, TypePropertyName);
            if (!root.TryGetProperty(ValuePropertyName, out var value))
                throw new JsonException($"Metadata override is missing the '{ValuePropertyName}' property");

            try
            {
                return type switch
                {
                    IntTypeName => MetadataOverride.Create(key, value.GetInt32()),
                    FloatTypeName => MetadataOverride.Create(key, value.GetSingle()),
                    BoolTypeName => MetadataOverride.Create(key, value.GetBoolean()),
                    _ => throw new JsonException($"Unknown metadata override type '{type}'"),
                };
            }
            catch (Exception ex) when (ex is FormatException || ex is InvalidOperationException)
            {
                // JsonElement.GetXxx throws these when the value has the wrong kind or is out of range;
                // converters are expected to report malformed input as JsonException.
                throw new JsonException($"Metadata override '{key}' has a value that does not match its type '{type}'", ex);
            }
        }

        /// <inheritdoc />
        public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
        {
            writer.WriteStartObject();
            writer.WriteString(KeyPropertyName, value.Key);
            value.WriteJsonValue(writer);
            writer.WriteEndObject();
        }

        // Read a mandatory string property from a JSON object, throwing JsonException if absent or not a string.
        private static string ReadRequiredString(JsonElement obj, string name)
        {
            if (!obj.TryGetProperty(name, out var element) || element.ValueKind != JsonValueKind.String)
                throw new JsonException($"Metadata override is missing the string property '{name}'");

            return element.GetString()!;
        }
    }
}