|
|
|
@@ -5,7 +5,6 @@ using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Text.Json; |
|
|
|
using System.Text.Json.Serialization; |
|
|
|
using LLama.Common; |
|
|
|
using LLama.Native; |
|
|
|
|
|
|
|
namespace LLama.Abstractions |
|
|
|
@@ -110,6 +109,7 @@ namespace LLama.Abstractions |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// A fixed size array to set the tensor splits across multiple GPUs |
|
|
|
/// </summary> |
|
|
|
@@ -204,6 +204,7 @@ namespace LLama.Abstractions |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// An override for a single key/value pair in model metadata |
|
|
|
/// </summary> |
|
|
|
@@ -243,57 +244,92 @@ namespace LLama.Abstractions |
|
|
|
return new BoolOverride(key, value); |
|
|
|
} |
|
|
|
|
|
|
|
internal abstract void Write(ref LLamaModelMetadataOverride dest); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Get the key being overriden by this override |
|
|
|
/// </summary> |
|
|
|
public abstract string Key { get; init; } |
|
|
|
|
|
|
|
internal abstract LLamaModelKvOverrideType Type { get; } |
|
|
|
|
|
|
|
internal abstract void WriteValue(ref LLamaModelMetadataOverride dest); |
|
|
|
|
|
|
|
internal abstract void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options); |
|
|
|
|
|
|
|
private record IntOverride(string Key, int Value) : MetadataOverride |
|
|
|
{ |
|
|
|
internal override void Write(ref LLamaModelMetadataOverride dest) |
|
|
|
internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT; |
|
|
|
|
|
|
|
internal override void WriteValue(ref LLamaModelMetadataOverride dest) |
|
|
|
{ |
|
|
|
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT; |
|
|
|
dest.IntValue = Value; |
|
|
|
} |
|
|
|
|
|
|
|
internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options) |
|
|
|
{ |
|
|
|
writer.WriteNumberValue(Value); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
private record FloatOverride(string Key, float Value) : MetadataOverride |
|
|
|
{ |
|
|
|
internal override void Write(ref LLamaModelMetadataOverride dest) |
|
|
|
internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT; |
|
|
|
|
|
|
|
internal override void WriteValue(ref LLamaModelMetadataOverride dest) |
|
|
|
{ |
|
|
|
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT; |
|
|
|
dest.FloatValue = Value; |
|
|
|
} |
|
|
|
|
|
|
|
internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options) |
|
|
|
{ |
|
|
|
writer.WriteNumberValue(Value); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
private record BoolOverride(string Key, bool Value) : MetadataOverride |
|
|
|
{ |
|
|
|
internal override void Write(ref LLamaModelMetadataOverride dest) |
|
|
|
internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL; |
|
|
|
|
|
|
|
internal override void WriteValue(ref LLamaModelMetadataOverride dest) |
|
|
|
{ |
|
|
|
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL; |
|
|
|
dest.BoolValue = Value ? -1 : 0; |
|
|
|
} |
|
|
|
|
|
|
|
internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options) |
|
|
|
{ |
|
|
|
writer.WriteBooleanValue(Value); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// A JSON converter for <see cref="MetadataOverride"/> |
|
|
|
/// </summary> |
|
|
|
public class MetadataOverrideConverter |
|
|
|
: JsonConverter<MetadataOverride> |
|
|
|
{ |
|
|
|
/// <inheritdoc/> |
|
|
|
public override bool CanConvert(Type typeToConvert) |
|
|
|
{ |
|
|
|
return typeof(MetadataOverride).IsAssignableFrom(typeToConvert); |
|
|
|
} |
|
|
|
|
|
|
|
/// <inheritdoc/> |
|
|
|
public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) |
|
|
|
{ |
|
|
|
throw new NotImplementedException(); |
|
|
|
//var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>(); |
|
|
|
//return new TensorSplitsCollection(arr); |
|
|
|
throw new NotImplementedException("for some reason this is never called!"); |
|
|
|
} |
|
|
|
|
|
|
|
/// <inheritdoc/> |
|
|
|
public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options) |
|
|
|
{ |
|
|
|
throw new NotImplementedException(); |
|
|
|
//JsonSerializer.Serialize(writer, value.Splits, options); |
|
|
|
writer.WriteStartObject(); |
|
|
|
{ |
|
|
|
writer.WriteString("Key", value.Key); |
|
|
|
writer.WriteNumber("Type", (int)value.Type); |
|
|
|
writer.WritePropertyName("Value"); |
|
|
|
value.WriteValue(writer, options); |
|
|
|
} |
|
|
|
writer.WriteEndObject(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |