Browse Source

Switched to properly typed `Encoding` property

tags/v0.5.1
Martin Evans 2 years ago
parent
commit
93f24f8a51
8 changed files with 67 additions and 20 deletions
  1. +4
    -4
      LLama.Unittest/LLama.Unittest.csproj
  2. +48
    -5
      LLama.Unittest/ModelsParamsTests.cs
  3. +3
    -2
      LLama.Web/Common/ModelOptions.cs
  4. +4
    -2
      LLama/Abstractions/IModelParams.cs
  5. +3
    -2
      LLama/Common/ModelParams.cs
  6. +3
    -3
      LLama/LLamaContext.cs
  7. +1
    -1
      LLama/LLamaStatelessExecutor.cs
  8. +1
    -1
      LLama/LLamaWeights.cs

+ 4
- 4
LLama.Unittest/LLama.Unittest.csproj View File

@@ -11,13 +11,13 @@
</PropertyGroup> </PropertyGroup>


<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
<PackageReference Include="xunit" Version="2.4.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5">
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
<PackageReference Include="xunit" Version="2.5.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
</PackageReference> </PackageReference>
<PackageReference Include="coverlet.collector" Version="3.1.2">
<PackageReference Include="coverlet.collector" Version="6.0.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
</PackageReference> </PackageReference>


+ 48
- 5
LLama.Unittest/ModelsParamsTests.cs View File

@@ -1,4 +1,6 @@
using LLama.Common;
using System.Text;
using LLama.Common;
using Newtonsoft.Json;


namespace LLama.Unittest namespace LLama.Unittest
{ {
@@ -17,12 +19,33 @@ namespace LLama.Unittest
GpuLayerCount = 111 GpuLayerCount = 111
}; };


var json = System.Text.Json.JsonSerializer.Serialize(expected);
var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json);
var options = new System.Text.Json.JsonSerializerOptions();
options.Converters.Add(new SystemTextJsonEncodingConverter());

var json = System.Text.Json.JsonSerializer.Serialize(expected, options);
var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json, options);


Assert.Equal(expected, actual); Assert.Equal(expected, actual);
} }


private class SystemTextJsonEncodingConverter
: System.Text.Json.Serialization.JsonConverter<Encoding>

{
public override Encoding? Read(ref System.Text.Json.Utf8JsonReader reader, Type typeToConvert, System.Text.Json.JsonSerializerOptions options)
{
var name = reader.GetString();
if (name == null)
return null;
return Encoding.GetEncoding(name);
}

public override void Write(System.Text.Json.Utf8JsonWriter writer, Encoding value, System.Text.Json.JsonSerializerOptions options)
{
writer.WriteStringValue(value.WebName);
}
}

[Fact] [Fact]
public void SerializeRoundTripNewtonsoft() public void SerializeRoundTripNewtonsoft()
{ {
@@ -36,10 +59,30 @@ namespace LLama.Unittest
GpuLayerCount = 111 GpuLayerCount = 111
}; };


var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected);
var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json);
var settings = new Newtonsoft.Json.JsonSerializerSettings();
settings.Converters.Add(new NewtsonsoftEncodingConverter());

var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings);
var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json, settings);


Assert.Equal(expected, actual); Assert.Equal(expected, actual);
} }

private class NewtsonsoftEncodingConverter
: Newtonsoft.Json.JsonConverter<Encoding>
{
public override void WriteJson(JsonWriter writer, Encoding? value, JsonSerializer serializer)
{
writer.WriteValue((string?)value?.WebName);
}

public override Encoding? ReadJson(JsonReader reader, Type objectType, Encoding? existingValue, bool hasExistingValue, JsonSerializer serializer)
{
var name = (string?)reader.Value;
if (name == null)
return null;
return Encoding.GetEncoding(name);
}
}
} }
} }

+ 3
- 2
LLama.Web/Common/ModelOptions.cs View File

@@ -1,4 +1,5 @@
using LLama.Abstractions;
using System.Text;
using LLama.Abstractions;


namespace LLama.Web.Common namespace LLama.Web.Common
{ {
@@ -115,6 +116,6 @@ namespace LLama.Web.Common
/// <summary> /// <summary>
/// The encoding to use for models /// The encoding to use for models
/// </summary> /// </summary>
public string Encoding { get; set; } = "UTF-8";
public Encoding Encoding { get; set; } = Encoding.UTF8;
} }
} }

+ 4
- 2
LLama/Abstractions/IModelParams.cs View File

@@ -1,4 +1,6 @@
namespace LLama.Abstractions
using System.Text;

namespace LLama.Abstractions
{ {
public interface IModelParams public interface IModelParams
{ {
@@ -121,6 +123,6 @@
/// <summary> /// <summary>
/// The encoding to use for models /// The encoding to use for models
/// </summary> /// </summary>
string Encoding { get; set; }
Encoding Encoding { get; set; }
} }
} }

+ 3
- 2
LLama/Common/ModelParams.cs View File

@@ -1,5 +1,6 @@
using LLama.Abstractions; using LLama.Abstractions;
using System; using System;
using System.Text;


namespace LLama.Common namespace LLama.Common
{ {
@@ -114,7 +115,7 @@ namespace LLama.Common
/// <summary> /// <summary>
/// The encoding to use to convert text for the model /// The encoding to use to convert text for the model
/// </summary> /// </summary>
public string Encoding { get; set; } = "UTF-8";
public Encoding Encoding { get; set; } = Encoding.UTF8;


/// <summary> /// <summary>
/// ///
@@ -183,7 +184,7 @@ namespace LLama.Common
RopeFrequencyBase = ropeFrequencyBase; RopeFrequencyBase = ropeFrequencyBase;
RopeFrequencyScale = ropeFrequencyScale; RopeFrequencyScale = ropeFrequencyScale;
MulMatQ = mulMatQ; MulMatQ = mulMatQ;
Encoding = encoding;
Encoding = Encoding.GetEncoding(encoding);
} }
} }
} }

+ 3
- 3
LLama/LLamaContext.cs View File

@@ -68,7 +68,7 @@ namespace LLama
Params = @params; Params = @params;


_logger = logger; _logger = logger;
_encoding = Encoding.GetEncoding(@params.Encoding);
_encoding = @params.Encoding;


_logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
_ctx = Utils.InitLLamaContextFromModelParams(Params); _ctx = Utils.InitLLamaContextFromModelParams(Params);
@@ -79,7 +79,7 @@ namespace LLama
Params = @params; Params = @params;


_logger = logger; _logger = logger;
_encoding = Encoding.GetEncoding(@params.Encoding);
_encoding = @params.Encoding;
_ctx = nativeContext; _ctx = nativeContext;
} }


@@ -98,7 +98,7 @@ namespace LLama
Params = @params; Params = @params;


_logger = logger; _logger = logger;
_encoding = Encoding.GetEncoding(@params.Encoding);
_encoding = @params.Encoding;


using var pin = @params.ToLlamaContextParams(out var lparams); using var pin = @params.ToLlamaContextParams(out var lparams);
_ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);


+ 1
- 1
LLama/LLamaStatelessExecutor.cs View File

@@ -47,7 +47,7 @@ namespace LLama
[Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
public StatelessExecutor(LLamaContext context) public StatelessExecutor(LLamaContext context)
{ {
_weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding));
_weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
_params = context.Params; _params = context.Params;


Context = _weights.CreateContext(_params); Context = _weights.CreateContext(_params);


+ 1
- 1
LLama/LLamaWeights.cs View File

@@ -59,7 +59,7 @@ namespace LLama
if (!string.IsNullOrEmpty(@params.LoraAdapter)) if (!string.IsNullOrEmpty(@params.LoraAdapter))
weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);


return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding));
return new LLamaWeights(weights, @params.Encoding);
} }


/// <inheritdoc /> /// <inheritdoc />


Loading…
Cancel
Save