| @@ -11,13 +11,13 @@ | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||||
| <PackageReference Include="xunit" Version="2.4.2" /> | |||||
| <PackageReference Include="xunit.runner.visualstudio" Version="2.4.5"> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" /> | |||||
| <PackageReference Include="xunit" Version="2.5.0" /> | |||||
| <PackageReference Include="xunit.runner.visualstudio" Version="2.5.0"> | |||||
| <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
| <PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
| </PackageReference> | </PackageReference> | ||||
| <PackageReference Include="coverlet.collector" Version="3.1.2"> | |||||
| <PackageReference Include="coverlet.collector" Version="6.0.0"> | |||||
| <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
| <PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
| </PackageReference> | </PackageReference> | ||||
| @@ -1,4 +1,6 @@ | |||||
| using LLama.Common; | |||||
| using System.Text; | |||||
| using LLama.Common; | |||||
| using Newtonsoft.Json; | |||||
| namespace LLama.Unittest | namespace LLama.Unittest | ||||
| { | { | ||||
| @@ -17,12 +19,33 @@ namespace LLama.Unittest | |||||
| GpuLayerCount = 111 | GpuLayerCount = 111 | ||||
| }; | }; | ||||
| var json = System.Text.Json.JsonSerializer.Serialize(expected); | |||||
| var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json); | |||||
| var options = new System.Text.Json.JsonSerializerOptions(); | |||||
| options.Converters.Add(new SystemTextJsonEncodingConverter()); | |||||
| var json = System.Text.Json.JsonSerializer.Serialize(expected, options); | |||||
| var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json, options); | |||||
| Assert.Equal(expected, actual); | Assert.Equal(expected, actual); | ||||
| } | } | ||||
| private class SystemTextJsonEncodingConverter | |||||
| : System.Text.Json.Serialization.JsonConverter<Encoding> | |||||
| { | |||||
| public override Encoding? Read(ref System.Text.Json.Utf8JsonReader reader, Type typeToConvert, System.Text.Json.JsonSerializerOptions options) | |||||
| { | |||||
| var name = reader.GetString(); | |||||
| if (name == null) | |||||
| return null; | |||||
| return Encoding.GetEncoding(name); | |||||
| } | |||||
| public override void Write(System.Text.Json.Utf8JsonWriter writer, Encoding value, System.Text.Json.JsonSerializerOptions options) | |||||
| { | |||||
| writer.WriteStringValue(value.WebName); | |||||
| } | |||||
| } | |||||
| [Fact] | [Fact] | ||||
| public void SerializeRoundTripNewtonsoft() | public void SerializeRoundTripNewtonsoft() | ||||
| { | { | ||||
| @@ -36,10 +59,30 @@ namespace LLama.Unittest | |||||
| GpuLayerCount = 111 | GpuLayerCount = 111 | ||||
| }; | }; | ||||
| var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected); | |||||
| var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json); | |||||
| var settings = new Newtonsoft.Json.JsonSerializerSettings(); | |||||
| settings.Converters.Add(new NewtsonsoftEncodingConverter()); | |||||
| var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings); | |||||
| var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json, settings); | |||||
| Assert.Equal(expected, actual); | Assert.Equal(expected, actual); | ||||
| } | } | ||||
| private class NewtsonsoftEncodingConverter | |||||
| : Newtonsoft.Json.JsonConverter<Encoding> | |||||
| { | |||||
| public override void WriteJson(JsonWriter writer, Encoding? value, JsonSerializer serializer) | |||||
| { | |||||
| writer.WriteValue((string?)value?.WebName); | |||||
| } | |||||
| public override Encoding? ReadJson(JsonReader reader, Type objectType, Encoding? existingValue, bool hasExistingValue, JsonSerializer serializer) | |||||
| { | |||||
| var name = (string?)reader.Value; | |||||
| if (name == null) | |||||
| return null; | |||||
| return Encoding.GetEncoding(name); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using LLama.Abstractions; | |||||
| using System.Text; | |||||
| using LLama.Abstractions; | |||||
| namespace LLama.Web.Common | namespace LLama.Web.Common | ||||
| { | { | ||||
| @@ -115,6 +116,6 @@ namespace LLama.Web.Common | |||||
| /// <summary> | /// <summary> | ||||
| /// The encoding to use for models | /// The encoding to use for models | ||||
| /// </summary> | /// </summary> | ||||
| public string Encoding { get; set; } = "UTF-8"; | |||||
| public Encoding Encoding { get; set; } = Encoding.UTF8; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,6 @@ | |||||
| namespace LLama.Abstractions | |||||
| using System.Text; | |||||
| namespace LLama.Abstractions | |||||
| { | { | ||||
| public interface IModelParams | public interface IModelParams | ||||
| { | { | ||||
| @@ -121,6 +123,6 @@ | |||||
| /// <summary> | /// <summary> | ||||
| /// The encoding to use for models | /// The encoding to use for models | ||||
| /// </summary> | /// </summary> | ||||
| string Encoding { get; set; } | |||||
| Encoding Encoding { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using System; | using System; | ||||
| using System.Text; | |||||
| namespace LLama.Common | namespace LLama.Common | ||||
| { | { | ||||
| @@ -114,7 +115,7 @@ namespace LLama.Common | |||||
| /// <summary> | /// <summary> | ||||
| /// The encoding to use to convert text for the model | /// The encoding to use to convert text for the model | ||||
| /// </summary> | /// </summary> | ||||
| public string Encoding { get; set; } = "UTF-8"; | |||||
| public Encoding Encoding { get; set; } = Encoding.UTF8; | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -183,7 +184,7 @@ namespace LLama.Common | |||||
| RopeFrequencyBase = ropeFrequencyBase; | RopeFrequencyBase = ropeFrequencyBase; | ||||
| RopeFrequencyScale = ropeFrequencyScale; | RopeFrequencyScale = ropeFrequencyScale; | ||||
| MulMatQ = mulMatQ; | MulMatQ = mulMatQ; | ||||
| Encoding = encoding; | |||||
| Encoding = Encoding.GetEncoding(encoding); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -68,7 +68,7 @@ namespace LLama | |||||
| Params = @params; | Params = @params; | ||||
| _logger = logger; | _logger = logger; | ||||
| _encoding = Encoding.GetEncoding(@params.Encoding); | |||||
| _encoding = @params.Encoding; | |||||
| _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); | _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); | ||||
| _ctx = Utils.InitLLamaContextFromModelParams(Params); | _ctx = Utils.InitLLamaContextFromModelParams(Params); | ||||
| @@ -79,7 +79,7 @@ namespace LLama | |||||
| Params = @params; | Params = @params; | ||||
| _logger = logger; | _logger = logger; | ||||
| _encoding = Encoding.GetEncoding(@params.Encoding); | |||||
| _encoding = @params.Encoding; | |||||
| _ctx = nativeContext; | _ctx = nativeContext; | ||||
| } | } | ||||
| @@ -98,7 +98,7 @@ namespace LLama | |||||
| Params = @params; | Params = @params; | ||||
| _logger = logger; | _logger = logger; | ||||
| _encoding = Encoding.GetEncoding(@params.Encoding); | |||||
| _encoding = @params.Encoding; | |||||
| using var pin = @params.ToLlamaContextParams(out var lparams); | using var pin = @params.ToLlamaContextParams(out var lparams); | ||||
| _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); | _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); | ||||
| @@ -47,7 +47,7 @@ namespace LLama | |||||
| [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] | [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] | ||||
| public StatelessExecutor(LLamaContext context) | public StatelessExecutor(LLamaContext context) | ||||
| { | { | ||||
| _weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding)); | |||||
| _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding); | |||||
| _params = context.Params; | _params = context.Params; | ||||
| Context = _weights.CreateContext(_params); | Context = _weights.CreateContext(_params); | ||||
| @@ -59,7 +59,7 @@ namespace LLama | |||||
| if (!string.IsNullOrEmpty(@params.LoraAdapter)) | if (!string.IsNullOrEmpty(@params.LoraAdapter)) | ||||
| weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); | weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); | ||||
| return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding)); | |||||
| return new LLamaWeights(weights, @params.Encoding); | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||