From 93f24f8a51161ede447f1f6f312ea83f009ff695 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 24 Aug 2023 00:09:00 +0100 Subject: [PATCH] Switched to properly typed `Encoding` property --- LLama.Unittest/LLama.Unittest.csproj | 8 ++--- LLama.Unittest/ModelsParamsTests.cs | 53 +++++++++++++++++++++++++--- LLama.Web/Common/ModelOptions.cs | 5 +-- LLama/Abstractions/IModelParams.cs | 6 ++-- LLama/Common/ModelParams.cs | 5 +-- LLama/LLamaContext.cs | 6 ++-- LLama/LLamaStatelessExecutor.cs | 2 +- LLama/LLamaWeights.cs | 2 +- 8 files changed, 67 insertions(+), 20 deletions(-) diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 81e71a88..03b865aa 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -11,13 +11,13 @@ - - - + + + runtime; build; native; contentfiles; analyzers; buildtransitive all - + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs index 3296ca09..0d657e1c 100644 --- a/LLama.Unittest/ModelsParamsTests.cs +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -1,4 +1,6 @@ -using LLama.Common; +using System.Text; +using LLama.Common; +using Newtonsoft.Json; namespace LLama.Unittest { @@ -17,12 +19,33 @@ namespace LLama.Unittest GpuLayerCount = 111 }; - var json = System.Text.Json.JsonSerializer.Serialize(expected); - var actual = System.Text.Json.JsonSerializer.Deserialize(json); + var options = new System.Text.Json.JsonSerializerOptions(); + options.Converters.Add(new SystemTextJsonEncodingConverter()); + + var json = System.Text.Json.JsonSerializer.Serialize(expected, options); + var actual = System.Text.Json.JsonSerializer.Deserialize(json, options); Assert.Equal(expected, actual); } + private class SystemTextJsonEncodingConverter + : System.Text.Json.Serialization.JsonConverter + + { + public override Encoding? Read(ref System.Text.Json.Utf8JsonReader reader, Type typeToConvert, System.Text.Json.JsonSerializerOptions options) + { + var name = reader.GetString(); + if (name == null) + return null; + return Encoding.GetEncoding(name); + } + + public override void Write(System.Text.Json.Utf8JsonWriter writer, Encoding value, System.Text.Json.JsonSerializerOptions options) + { + writer.WriteStringValue(value.WebName); + } + } + [Fact] public void SerializeRoundTripNewtonsoft() { @@ -36,10 +59,30 @@ namespace LLama.Unittest GpuLayerCount = 111 }; - var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected); - var actual = Newtonsoft.Json.JsonConvert.DeserializeObject(json); + var settings = new Newtonsoft.Json.JsonSerializerSettings(); + settings.Converters.Add(new NewtsonsoftEncodingConverter()); + + var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings); + var actual = Newtonsoft.Json.JsonConvert.DeserializeObject(json, settings); Assert.Equal(expected, actual); } + + private class NewtsonsoftEncodingConverter + : Newtonsoft.Json.JsonConverter + { + public override void WriteJson(JsonWriter writer, Encoding? value, JsonSerializer serializer) + { + writer.WriteValue((string?)value?.WebName); + } + + public override Encoding? ReadJson(JsonReader reader, Type objectType, Encoding? existingValue, bool hasExistingValue, JsonSerializer serializer) + { + var name = (string?)reader.Value; + if (name == null) + return null; + return Encoding.GetEncoding(name); + } + } } } diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index 3f5a3f0c..9a432858 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -1,4 +1,5 @@ -using LLama.Abstractions; +using System.Text; +using LLama.Abstractions; namespace LLama.Web.Common { @@ -115,6 +116,6 @@ namespace LLama.Web.Common /// /// The encoding to use for models /// - public string Encoding { get; set; } = "UTF-8"; + public Encoding Encoding { get; set; } = Encoding.UTF8; } } diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index 64a0125b..3ed4a84f 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -1,4 +1,6 @@ -namespace LLama.Abstractions +using System.Text; + +namespace LLama.Abstractions { public interface IModelParams { @@ -121,6 +123,6 @@ /// /// The encoding to use for models /// - string Encoding { get; set; } + Encoding Encoding { get; set; } } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 2b23a37e..9606feb3 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -1,5 +1,6 @@ using LLama.Abstractions; using System; +using System.Text; namespace LLama.Common { @@ -114,7 +115,7 @@ namespace LLama.Common /// /// The encoding to use to convert text for the model /// - public string Encoding { get; set; } = "UTF-8"; + public Encoding Encoding { get; set; } = Encoding.UTF8; /// /// @@ -183,7 +184,7 @@ namespace LLama.Common RopeFrequencyBase = ropeFrequencyBase; RopeFrequencyScale = ropeFrequencyScale; MulMatQ = mulMatQ; - Encoding = encoding; + Encoding = Encoding.GetEncoding(encoding); } } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 9501d570..624b6964 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -68,7 +68,7 @@ namespace LLama Params = @params; _logger = logger; - _encoding = Encoding.GetEncoding(@params.Encoding); + _encoding = @params.Encoding; _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); _ctx = Utils.InitLLamaContextFromModelParams(Params); @@ -79,7 +79,7 @@ namespace LLama Params = @params; _logger = logger; - _encoding = Encoding.GetEncoding(@params.Encoding); + _encoding = @params.Encoding; _ctx = nativeContext; } @@ -98,7 +98,7 @@ namespace LLama Params = @params; _logger = logger; - _encoding = Encoding.GetEncoding(@params.Encoding); + _encoding = @params.Encoding; using var pin = @params.ToLlamaContextParams(out var lparams); _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 446571b0..d3f0c0e2 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -47,7 +47,7 @@ namespace LLama [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] public StatelessExecutor(LLamaContext context) { - _weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding)); + _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding); _params = context.Params; Context = _weights.CreateContext(_params); diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 8997e9c4..1b067f1b 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -59,7 +59,7 @@ namespace LLama if (!string.IsNullOrEmpty(@params.LoraAdapter)) weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); - return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding)); + return new LLamaWeights(weights, @params.Encoding); } ///