From 2f0deeadcd72acf4b17366ee3d7504251f7aabdf Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 14 Dec 2023 16:11:31 +0000 Subject: [PATCH] Implemented serialization for `MetadataOverride`. Deserialization is broken (converter is never called) --- LLama/Abstractions/IModelParams.cs | 64 +++++++++++++++++----- LLama/Extensions/IModelParamsExtensions.cs | 8 ++- 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index 56b7baee..fada91a1 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; -using LLama.Common; using LLama.Native; namespace LLama.Abstractions @@ -110,6 +109,7 @@ namespace LLama.Abstractions } } + /// /// A fixed size array to set the tensor splits across multiple GPUs /// @@ -204,6 +204,7 @@ namespace LLama.Abstractions } } + /// /// An override for a single key/value pair in model metadata /// @@ -243,57 +244,92 @@ namespace LLama.Abstractions return new BoolOverride(key, value); } - internal abstract void Write(ref LLamaModelMetadataOverride dest); - /// /// Get the key being overriden by this override /// public abstract string Key { get; init; } + internal abstract LLamaModelKvOverrideType Type { get; } + + internal abstract void WriteValue(ref LLamaModelMetadataOverride dest); + + internal abstract void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options); + private record IntOverride(string Key, int Value) : MetadataOverride { - internal override void Write(ref LLamaModelMetadataOverride dest) + internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT; + + internal override void WriteValue(ref LLamaModelMetadataOverride dest) { - dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT; dest.IntValue = Value; } + + internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options) + { + writer.WriteNumberValue(Value); + } } private record FloatOverride(string Key, float Value) : MetadataOverride { - internal override void Write(ref LLamaModelMetadataOverride dest) + internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT; + + internal override void WriteValue(ref LLamaModelMetadataOverride dest) { - dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT; dest.FloatValue = Value; } + + internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options) + { + writer.WriteNumberValue(Value); + } } private record BoolOverride(string Key, bool Value) : MetadataOverride { - internal override void Write(ref LLamaModelMetadataOverride dest) + internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL; + + internal override void WriteValue(ref LLamaModelMetadataOverride dest) { - dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL; dest.BoolValue = Value ? -1 : 0; } + + internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options) + { + writer.WriteBooleanValue(Value); + } } } + /// + /// A JSON converter for + /// public class MetadataOverrideConverter : JsonConverter { + /// + public override bool CanConvert(Type typeToConvert) + { + return typeof(MetadataOverride).IsAssignableFrom(typeToConvert); + } + /// public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - throw new NotImplementedException(); - //var arr = JsonSerializer.Deserialize(ref reader, options) ?? Array.Empty(); - //return new TensorSplitsCollection(arr); + throw new NotImplementedException("for some reason this is never called!"); } /// public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options) { - throw new NotImplementedException(); - //JsonSerializer.Serialize(writer, value.Splits, options); + writer.WriteStartObject(); + { + writer.WriteString("Key", value.Key); + writer.WriteNumber("Type", (int)value.Type); + writer.WritePropertyName("Value"); + value.WriteValue(writer, options); + } + writer.WriteEndObject(); } } } \ No newline at end of file diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index f1a9dea9..08805d32 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -57,10 +57,12 @@ public static class IModelParamsExtensions for (var i = 0; i < @params.MetadataOverrides.Count; i++) { var item = @params.MetadataOverrides[i]; - var native = new LLamaModelMetadataOverride(); + var native = new LLamaModelMetadataOverride + { + Tag = item.Type + }; - // Init value and tag - item.Write(ref native); + item.WriteValue(ref native); // Convert key to bytes unsafe