Browse Source

Implemented serialization for `MetadataOverride`. Deserialization is broken (converter is never called)

tags/0.9.1
Martin Evans 1 year ago
parent
commit
2f0deeadcd
2 changed files with 55 additions and 17 deletions
  1. +50
    -14
      LLama/Abstractions/IModelParams.cs
  2. +5
    -3
      LLama/Extensions/IModelParamsExtensions.cs

+ 50
- 14
LLama/Abstractions/IModelParams.cs View File

@@ -5,7 +5,6 @@ using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Common;
using LLama.Native;

namespace LLama.Abstractions
@@ -110,6 +109,7 @@ namespace LLama.Abstractions
}
}


/// <summary>
/// A fixed size array to set the tensor splits across multiple GPUs
/// </summary>
@@ -204,6 +204,7 @@ namespace LLama.Abstractions
}
}


/// <summary>
/// An override for a single key/value pair in model metadata
/// </summary>
@@ -243,57 +244,92 @@ namespace LLama.Abstractions
return new BoolOverride(key, value);
}

internal abstract void Write(ref LLamaModelMetadataOverride dest);

/// <summary>
/// Get the key being overriden by this override
/// </summary>
public abstract string Key { get; init; }

internal abstract LLamaModelKvOverrideType Type { get; }

internal abstract void WriteValue(ref LLamaModelMetadataOverride dest);

internal abstract void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options);

private record IntOverride(string Key, int Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;

internal override void WriteValue(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
dest.IntValue = Value;
}

internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
{
writer.WriteNumberValue(Value);
}
}

private record FloatOverride(string Key, float Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;

internal override void WriteValue(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
dest.FloatValue = Value;
}

internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
{
writer.WriteNumberValue(Value);
}
}

private record BoolOverride(string Key, bool Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;

internal override void WriteValue(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
dest.BoolValue = Value ? -1 : 0;
}

internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
{
writer.WriteBooleanValue(Value);
}
}
}

/// <summary>
/// A JSON converter for <see cref="MetadataOverride"/>
/// </summary>
public class MetadataOverrideConverter
: JsonConverter<MetadataOverride>
{
/// <inheritdoc/>
public override bool CanConvert(Type typeToConvert)
{
return typeof(MetadataOverride).IsAssignableFrom(typeToConvert);
}

/// <inheritdoc/>
public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
throw new NotImplementedException();
//var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
//return new TensorSplitsCollection(arr);
throw new NotImplementedException("for some reason this is never called!");
}

/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
{
throw new NotImplementedException();
//JsonSerializer.Serialize(writer, value.Splits, options);
writer.WriteStartObject();
{
writer.WriteString("Key", value.Key);
writer.WriteNumber("Type", (int)value.Type);
writer.WritePropertyName("Value");
value.WriteValue(writer, options);
}
writer.WriteEndObject();
}
}
}

+ 5
- 3
LLama/Extensions/IModelParamsExtensions.cs View File

@@ -57,10 +57,12 @@ public static class IModelParamsExtensions
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
{
var item = @params.MetadataOverrides[i];
var native = new LLamaModelMetadataOverride();
var native = new LLamaModelMetadataOverride
{
Tag = item.Type
};

// Init value and tag
item.Write(ref native);
item.WriteValue(ref native);

// Convert key to bytes
unsafe


Loading…
Cancel
Save