using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;
namespace LLama.Abstractions
{
/// <summary>
/// The parameters for initializing a LLama model.
/// </summary>
public interface IModelParams
{
    /// <summary>
    /// main_gpu interpretation depends on split_mode (<see cref="SplitMode"/>):
    /// <list type="bullet">
    ///     <item>
    ///         <term>None</term>
    ///         <description>The GPU that is used for the entire model.</description>
    ///     </item>
    ///     <item>
    ///         <term>Row</term>
    ///         <description>The GPU that is used for small tensors and intermediate results.</description>
    ///     </item>
    ///     <item>
    ///         <term>Layer</term>
    ///         <description>Ignored.</description>
    ///     </item>
    /// </list>
    /// </summary>
    int MainGpu { get; set; }

    /// <summary>
    /// How to split the model across multiple GPUs
    /// </summary>
    GPUSplitMode SplitMode { get; }

    /// <summary>
    /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
    /// </summary>
    int GpuLayerCount { get; }

    /// <summary>
    /// Use mmap for faster loads (use_mmap)
    /// </summary>
    bool UseMemorymap { get; }

    /// <summary>
    /// Use mlock to keep model in memory (use_mlock)
    /// </summary>
    bool UseMemoryLock { get; }

    /// <summary>
    /// Model path (model)
    /// </summary>
    string ModelPath { get; }

    /// <summary>
    /// How split tensors should be distributed across GPUs
    /// </summary>
    TensorSplitsCollection TensorSplits { get; }

    /// <summary>
    /// Load vocab only (no weights)
    /// </summary>
    bool VocabOnly { get; }

    /// <summary>
    /// List of LoRA adapters to apply
    /// </summary>
    AdapterCollection LoraAdapters { get; }

    /// <summary>
    /// Base model path for the LoRA adapters (lora_base)
    /// </summary>
    string LoraBase { get; }

    /// <summary>
    /// Override specific metadata items in the model
    /// </summary>
    List<MetadataOverride> MetadataOverrides { get; }
}
/// <summary>
/// A LoRA adapter to apply to a model
/// </summary>
/// <param name="Path">Path to the LoRA file</param>
/// <param name="Scale">Strength of this LoRA</param>
public readonly record struct LoraAdapter(string Path, float Scale);
/// <summary>
/// A list of <see cref="LoraAdapter"/> objects
/// </summary>
public sealed class AdapterCollection
    : List<LoraAdapter>, IEquatable<AdapterCollection>
{
    /// <inheritdoc />
    public bool Equals(AdapterCollection? other)
    {
        if (other is null)
            return false;

        // Same instance: no need to walk the elements.
        if (ReferenceEquals(this, other))
            return true;

        // Equal when both hold equal adapters in the same order.
        return this.SequenceEqual(other);
    }

    /// <inheritdoc />
    public override bool Equals(object? obj)
    {
        return obj is AdapterCollection other && Equals(other);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        // Order-sensitive combination of the element hashes, matching the ordered comparison in Equals.
        // NOTE: the collection is mutable, so this value changes whenever the contents change.
        unchecked
        {
            var hash = 17;
            foreach (var adapter in this)
            {
                hash += adapter.GetHashCode();
                hash *= 7823;
            }
            return hash;
        }
    }
}
/// <summary>
/// A fixed size array to set the tensor splits across multiple GPUs
/// </summary>
[JsonConverter(typeof(TensorSplitsCollectionConverter))]
public sealed class TensorSplitsCollection
    : IEnumerable<float>
{
    // One slot per device, sized from NativeApi.llama_max_devices(); every slot starts at zero.
    internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

    /// <summary>
    /// The size of this array
    /// </summary>
    public int Length => Splits.Length;

    /// <summary>
    /// Get or set the proportion of work to do on the given device.
    /// </summary>
    /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
    /// <param name="index">Index of the device slot</param>
    /// <returns>The proportion stored for the device at <paramref name="index"/></returns>
    /// <exception cref="IndexOutOfRangeException"><paramref name="index"/> is outside <c>0..Length-1</c></exception>
    public float this[int index]
    {
        get => Splits[index];
        set => Splits[index] = value;
    }

    /// <summary>
    /// Create a new tensor splits collection, copying the given values
    /// </summary>
    /// <param name="splits">Values copied into the first slots; any remaining slots stay zero</param>
    /// <exception cref="ArgumentNullException"><paramref name="splits"/> is null</exception>
    /// <exception cref="ArgumentException"><paramref name="splits"/> has more elements than <see cref="Length"/></exception>
    public TensorSplitsCollection(float[] splits)
    {
        if (splits is null)
            throw new ArgumentNullException(nameof(splits));
        if (splits.Length > Splits.Length)
            throw new ArgumentException($"Must supply at most {Splits.Length} tensor splits", nameof(splits));

        splits.CopyTo(Splits.AsSpan());
    }

    /// <summary>
    /// Create a new tensor splits collection with all values initialised to the default
    /// </summary>
    public TensorSplitsCollection()
    {
    }

    /// <summary>
    /// Set all values to zero
    /// </summary>
    public void Clear()
    {
        Array.Clear(Splits, 0, Splits.Length);
    }

    /// <summary>
    /// Pin the backing array in memory (presumably so it can be handed to native code).
    /// The caller must dispose the returned handle to release the pin.
    /// </summary>
    internal MemoryHandle Pin()
    {
        return Splits.AsMemory().Pin();
    }

    #region IEnumerator
    /// <inheritdoc />
    public IEnumerator<float> GetEnumerator()
    {
        return ((IEnumerable<float>)Splits).GetEnumerator();
    }

    /// <inheritdoc />
    IEnumerator IEnumerable.GetEnumerator()
    {
        return Splits.GetEnumerator();
    }
    #endregion
}
/// <summary>
/// A JSON converter for <see cref="TensorSplitsCollection"/>
/// </summary>
public class TensorSplitsCollectionConverter
    : JsonConverter<TensorSplitsCollection>
{
    /// <inheritdoc />
    /// <remarks>
    /// The value is read as a JSON array of floats. If it deserializes to null an empty array is used,
    /// leaving every slot at zero. Supplying more values than the collection has slots makes the
    /// <see cref="TensorSplitsCollection"/> constructor throw an <see cref="ArgumentException"/>.
    /// </remarks>
    public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
        return new TensorSplitsCollection(arr);
    }

    /// <inheritdoc />
    /// <remarks>Writes every slot of the backing array, including trailing zeros.</remarks>
    public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
    {
        JsonSerializer.Serialize(writer, value.Splits, options);
    }
}
/// <summary>
/// An override for a single key/value pair in model metadata
/// </summary>
[JsonConverter(typeof(MetadataOverrideConverter))]
public sealed record MetadataOverride
{
    /// <summary>
    /// Get the key being overridden by this override
    /// </summary>
    public string Key { get; }

    // Tag telling which of the value fields below was set by the constructor.
    internal LLamaModelKvOverrideType Type { get; }

    // Each constructor assigns exactly one of these; the other two keep their default value.
    private readonly int _valueInt;
    private readonly float _valueFloat;
    private readonly bool _valueBool;

    /// <summary>
    /// Create a new override for an int key
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Integer value to store for the key</param>
    public MetadataOverride(string key, int value)
    {
        Type = LLamaModelKvOverrideType.Int;
        _valueInt = value;
        Key = key;
    }

    /// <summary>
    /// Create a new override for a float key
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Floating point value to store for the key</param>
    public MetadataOverride(string key, float value)
    {
        Type = LLamaModelKvOverrideType.Float;
        _valueFloat = value;
        Key = key;
    }

    /// <summary>
    /// Create a new override for a boolean key
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Boolean value to store for the key</param>
    public MetadataOverride(string key, bool value)
    {
        Type = LLamaModelKvOverrideType.Bool;
        _valueBool = value;
        Key = key;
    }

    /// <summary>
    /// Copy the stored value into the value member of <paramref name="dest"/> that matches <see cref="Type"/>.
    /// </summary>
    /// <exception cref="InvalidEnumArgumentException"><see cref="Type"/> is not Int, Float or Bool</exception>
    internal void WriteValue(ref LLamaModelMetadataOverride dest)
    {
        if (Type == LLamaModelKvOverrideType.Int)
        {
            dest.IntValue = _valueInt;
        }
        else if (Type == LLamaModelKvOverrideType.Float)
        {
            dest.FloatValue = _valueFloat;
        }
        else if (Type == LLamaModelKvOverrideType.Bool)
        {
            // true is written as -1 (every bit set) and false as 0.
            // NOTE(review): presumably so any byte the native side reads is non-zero for true — confirm against the struct layout.
            dest.BoolValue = _valueBool ? -1L : 0;
        }
        else
        {
            throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
        }
    }

    /// <summary>
    /// Write the stored value to <paramref name="writer"/> as a JSON number or boolean, depending on <see cref="Type"/>.
    /// </summary>
    /// <exception cref="InvalidEnumArgumentException"><see cref="Type"/> is not Int, Float or Bool</exception>
    internal void WriteValue(Utf8JsonWriter writer)
    {
        if (Type == LLamaModelKvOverrideType.Int)
        {
            writer.WriteNumberValue(_valueInt);
        }
        else if (Type == LLamaModelKvOverrideType.Float)
        {
            writer.WriteNumberValue(_valueFloat);
        }
        else if (Type == LLamaModelKvOverrideType.Bool)
        {
            writer.WriteBooleanValue(_valueBool);
        }
        else
        {
            throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
        }
    }
}
/// <summary>
/// A JSON converter for <see cref="MetadataOverride"/>
/// </summary>
/// <remarks>
/// The JSON shape is <c>{ "Type": int, "Key": string, "Value": number|bool }</c>. These literal property names
/// are used in both directions, independent of <see cref="JsonSerializerOptions.PropertyNamingPolicy"/>.
/// </remarks>
public class MetadataOverrideConverter
    : JsonConverter<MetadataOverride>
{
    /// <inheritdoc />
    /// <exception cref="JsonException">The value is not an object, a property is missing or has the wrong kind, or the type tag is unknown.</exception>
    public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        // Parse the object directly instead of deserializing a helper type through `options`.
        // Write always emits "Type"/"Key"/"Value". Honouring a naming policy here (e.g. camelCase without
        // case-insensitivity) made every property go unmatched, so written values failed to read back.
        using var document = JsonDocument.ParseValue(ref reader);
        var root = document.RootElement;
        if (root.ValueKind != JsonValueKind.Object)
            throw new JsonException($"Expected a JSON object for {nameof(MetadataOverride)} but found {root.ValueKind}");

        var typeElement = GetRequiredProperty(root, "Type");
        if (typeElement.ValueKind != JsonValueKind.Number || !typeElement.TryGetInt32(out var rawType))
            throw new JsonException($"\"Type\" of {nameof(MetadataOverride)} must be an integer");

        var keyElement = GetRequiredProperty(root, "Key");
        if (keyElement.ValueKind != JsonValueKind.String)
            throw new JsonException($"\"Key\" of {nameof(MetadataOverride)} must be a string");
        var key = keyElement.GetString()!;

        var value = GetRequiredProperty(root, "Value");

        // All values are extracted here, before `document` is disposed.
        return ((LLamaModelKvOverrideType)rawType) switch
        {
            LLamaModelKvOverrideType.Int => new MetadataOverride(key, value.GetInt32()),
            LLamaModelKvOverrideType.Float => new MetadataOverride(key, value.GetSingle()),
            LLamaModelKvOverrideType.Bool => new MetadataOverride(key, value.GetBoolean()),
            var unknown => throw new JsonException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {unknown}"),
        };
    }

    /// <inheritdoc />
    public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
    {
        writer.WriteStartObject();
        {
            writer.WriteNumber("Type", (int)value.Type);
            writer.WriteString("Key", value.Key);
            writer.WritePropertyName("Value");
            value.WriteValue(writer);
        }
        writer.WriteEndObject();
    }

    // Look up `name` exactly as Write emits it, falling back to a case-insensitive match
    // so camelCase input (e.g. "type") is accepted too.
    private static JsonElement GetRequiredProperty(JsonElement obj, string name)
    {
        if (obj.TryGetProperty(name, out var exact))
            return exact;

        foreach (var property in obj.EnumerateObject())
        {
            if (string.Equals(property.Name, name, StringComparison.OrdinalIgnoreCase))
                return property.Value;
        }

        throw new JsonException($"Missing required property \"{name}\" for {nameof(MetadataOverride)}");
    }
}
}