using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Common;
using LLama.Native;
namespace LLama.Abstractions
{
/// <summary>
/// The parameters for initializing a LLama model.
/// </summary>
public interface IModelParams
{
    /// <summary>
    /// The GPU that is used for scratch and small tensors
    /// </summary>
    int MainGpu { get; set; }

    /// <summary>
    /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
    /// </summary>
    int GpuLayerCount { get; set; }

    /// <summary>
    /// Use mmap for faster loads (use_mmap)
    /// </summary>
    bool UseMemorymap { get; set; }

    /// <summary>
    /// Use mlock to keep model in memory (use_mlock)
    /// </summary>
    bool UseMemoryLock { get; set; }

    /// <summary>
    /// Model path (model)
    /// </summary>
    string ModelPath { get; set; }

    /// <summary>
    /// How split tensors should be distributed across GPUs
    /// </summary>
    TensorSplitsCollection TensorSplits { get; set; }

    /// <summary>
    /// Load vocab only (no weights)
    /// </summary>
    bool VocabOnly { get; set; }

    /// <summary>
    /// List of LoRA adapters to apply
    /// </summary>
    AdapterCollection LoraAdapters { get; }

    /// <summary>
    /// Base model path for the lora adapter (lora_base)
    /// </summary>
    string LoraBase { get; set; }

    /// <summary>
    /// Override specific metadata items in the model
    /// </summary>
    List<MetadataOverride> MetadataOverrides { get; }
}
/// <summary>
/// Describes one LoRA adapter that can be applied to a model.
/// </summary>
/// <param name="Path">Path to the LoRA file</param>
/// <param name="Scale">Strength of this LoRA</param>
public readonly record struct LoraAdapter(
    string Path,
    float Scale
);
/// <summary>
/// A list of <see cref="LoraAdapter"/> objects, compared by content: two collections are
/// equal when they hold equal adapters in the same order.
/// </summary>
public sealed class AdapterCollection
    : List<LoraAdapter>, IEquatable<AdapterCollection>
{
    /// <summary>
    /// Check whether another collection contains the same adapters in the same order.
    /// </summary>
    /// <param name="other">The collection to compare against (may be null)</param>
    /// <returns>true if both collections hold equal items in the same order, otherwise false</returns>
    public bool Equals(AdapterCollection? other)
    {
        if (other is null)
        {
            return false;
        }

        // A collection is always sequence-equal to itself; skip the element walk.
        if (ReferenceEquals(this, other))
        {
            return true;
        }

        return this.SequenceEqual(other);
    }

    /// <inheritdoc />
    public override bool Equals(object? obj)
    {
        // Anything that is not an AdapterCollection becomes null and compares unequal.
        return Equals(obj as AdapterCollection);
    }

    /// <summary>
    /// Compute a hash from the current items, in order. The value depends on the
    /// contents, so it changes whenever the list is modified.
    /// </summary>
    /// <returns>An order-sensitive hash of the contained adapters</returns>
    public override int GetHashCode()
    {
        // Overflow is expected and harmless for hash mixing.
        unchecked
        {
            var hash = 17;
            for (var i = 0; i < Count; i++)
            {
                hash += this[i].GetHashCode();
                hash *= 7823;
            }
            return hash;
        }
    }
}
/// <summary>
/// A fixed size array to set the tensor splits across multiple GPUs
/// </summary>
[JsonConverter(typeof(TensorSplitsCollectionConverter))]
public sealed class TensorSplitsCollection
    : IEnumerable<float>
{
    // Backing storage; its size is fixed at construction from NativeApi.llama_max_devices().
    internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

    /// <summary>
    /// The size of this array
    /// </summary>
    public int Length => Splits.Length;

    /// <summary>
    /// Get or set the proportion of work to do on the given device.
    /// </summary>
    /// <example>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</example>
    /// <param name="index">Index of the device</param>
    /// <returns>The value stored for that device</returns>
    public float this[int index]
    {
        get => Splits[index];
        set => Splits[index] = value;
    }

    /// <summary>
    /// Create a new tensor splits collection, copying the given values
    /// </summary>
    /// <param name="splits">Values to copy into the start of this collection; any remaining entries keep their default value</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="splits"/> is null</exception>
    /// <exception cref="ArgumentException">Thrown if <paramref name="splits"/> has more items than this collection can hold</exception>
    public TensorSplitsCollection(float[] splits)
    {
        // Fail with a descriptive argument exception instead of a NullReferenceException.
        if (splits is null)
        {
            throw new ArgumentNullException(nameof(splits));
        }

        if (splits.Length > Splits.Length)
        {
            throw new ArgumentException($"Must supply at most {Splits.Length} tensor splits", nameof(splits));
        }

        splits.CopyTo(Splits.AsSpan());
    }

    /// <summary>
    /// Create a new tensor splits collection with all values initialised to the default
    /// </summary>
    public TensorSplitsCollection()
    {
    }

    /// <summary>
    /// Set all values to zero
    /// </summary>
    public void Clear()
    {
        Array.Clear(Splits, 0, Splits.Length);
    }

    // Pins the backing array and returns the resulting handle to the caller.
    internal MemoryHandle Pin()
    {
        return Splits.AsMemory().Pin();
    }

    #region IEnumerator
    /// <inheritdoc />
    public IEnumerator<float> GetEnumerator()
    {
        // Arrays only expose the generic enumerator through their IEnumerable<T> interface.
        return ((IEnumerable<float>)Splits).GetEnumerator();
    }

    /// <inheritdoc />
    IEnumerator IEnumerable.GetEnumerator()
    {
        return Splits.GetEnumerator();
    }
    #endregion
}
/// <summary>
/// A JSON converter for <see cref="TensorSplitsCollection"/>. The collection is
/// written to and read from JSON as a plain array of floats.
/// </summary>
public class TensorSplitsCollectionConverter
    : JsonConverter<TensorSplitsCollection>
{
    /// <summary>
    /// Read a float array from JSON and copy it into a new <see cref="TensorSplitsCollection"/>.
    /// A JSON null is treated as an empty array.
    /// </summary>
    /// <inheritdoc />
    public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
        return new TensorSplitsCollection(arr);
    }

    /// <summary>
    /// Write the backing float array of <paramref name="value"/> as JSON.
    /// </summary>
    /// <inheritdoc />
    public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
    {
        JsonSerializer.Serialize(writer, value.Splits, options);
    }
}
/// <summary>
/// Represents a replacement value for one key/value pair in model metadata.
/// Instances are obtained through the <c>Create</c> factory overloads.
/// </summary>
[JsonConverter(typeof(MetadataOverrideConverter))]
public abstract record MetadataOverride
{
    /// <summary>
    /// Build an override that stores an int value for the given key
    /// </summary>
    /// <param name="key">The metadata key to override</param>
    /// <param name="value">The int value to use</param>
    /// <returns>A new override for <paramref name="key"/></returns>
    public static MetadataOverride Create(string key, int value)
        => new IntOverride(key, value);

    /// <summary>
    /// Build an override that stores a float value for the given key
    /// </summary>
    /// <param name="key">The metadata key to override</param>
    /// <param name="value">The float value to use</param>
    /// <returns>A new override for <paramref name="key"/></returns>
    public static MetadataOverride Create(string key, float value)
        => new FloatOverride(key, value);

    /// <summary>
    /// Build an override that stores a boolean value for the given key
    /// </summary>
    /// <param name="key">The metadata key to override</param>
    /// <param name="value">The boolean value to use</param>
    /// <returns>A new override for <paramref name="key"/></returns>
    public static MetadataOverride Create(string key, bool value)
        => new BoolOverride(key, value);

    // Stores this override's type tag and value into the given native structure.
    internal abstract void Write(ref LLamaModelMetadataOverride dest);

    /// <summary>
    /// The key that this override replaces
    /// </summary>
    public abstract string Key { get; init; }

    // Override carrying an int payload.
    private record IntOverride(string Key, int Value)
        : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
            dest.IntValue = Value;
        }
    }

    // Override carrying a float payload.
    private record FloatOverride(string Key, float Value)
        : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
            dest.FloatValue = Value;
        }
    }

    // Override carrying a boolean payload, stored as -1 for true and 0 for false.
    private record BoolOverride(string Key, bool Value)
        : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            var encoded = Value ? -1 : 0;

            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
            dest.BoolValue = encoded;
        }
    }
}
/// <summary>
/// A JSON converter for <see cref="MetadataOverride"/>.
/// Neither reading nor writing is implemented yet; both methods always throw.
/// </summary>
public class MetadataOverrideConverter
    : JsonConverter<MetadataOverride>
{
    /// <summary>
    /// Not implemented.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown</exception>
    /// <inheritdoc />
    public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Not implemented.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown</exception>
    /// <inheritdoc />
    public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
    {
        throw new NotImplementedException();
    }
}
}