Browse Source

Fixed some issues which were causing metadata overrides not to work (mostly importantly, converting the key was failing so all keys were null bytes and thus ignored).

tags/0.9.1
Martin Evans 1 year ago
parent
commit
3fc0f34cbe
2 changed files with 41 additions and 17 deletions
  1. +20
    -0
      LLama/Extensions/EncodingExtensions.cs
  2. +21
    -17
      LLama/Extensions/IModelParamsExtensions.cs

+ 20
- 0
LLama/Extensions/EncodingExtensions.cs View File

@@ -6,6 +6,11 @@ namespace LLama.Extensions;
internal static class EncodingExtensions
{
#if NETSTANDARD2_0
public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Span<byte> output)
{
return GetBytesImpl(encoding, chars, output);
}

public static int GetChars(this Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output)
{
return GetCharsImpl(encoding, bytes, output);
@@ -19,6 +24,21 @@ internal static class EncodingExtensions
#error Target framework not supported!
#endif

internal static int GetBytesImpl(Encoding encoding, ReadOnlySpan<char> chars, Span<byte> output)
{
if (chars.Length == 0)
return 0;

unsafe
{
fixed (char* charPtr = chars)
fixed (byte* bytePtr = output)
{
return encoding.GetBytes(charPtr, chars.Length, bytePtr, output.Length);
}
}
}

internal static int GetCharsImpl(Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output)
{
if (bytes.Length == 0)


+ 21
- 17
LLama/Extensions/IModelParamsExtensions.cs View File

@@ -45,35 +45,39 @@ public static class IModelParamsExtensions
}
else
{
// Allocate enough space for all the override items
// Allocate enough space for all the override items. Pin it in place so we can safely pass it to llama.cpp
// This is one larger than necessary. The last item indicates the end of the overrides.
var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
var overridesPin = overrides.AsMemory().Pin();
unsafe
{
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overrides.AsMemory().Pin()).Pointer;
}

// Convert each item
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
{
var item = @params.MetadataOverrides[i];
var native = new LLamaModelMetadataOverride
{
Tag = item.Type
};

item.WriteValue(ref native);

// Convert key to bytes
unsafe
{
fixed (char* srcKey = item.Key)
// Get the item to convert
var item = @params.MetadataOverrides[i];

// Create the "native" representation to fill in
var native = new LLamaModelMetadataOverride
{
Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
}
}
Tag = item.Type
};

// Write the value into the native struct
item.WriteValue(ref native);

overrides[i] = native;
// Convert key chars to bytes
var srcSpan = item.Key.AsSpan();
var dstSpan = new Span<byte>(native.key, 128);
Encoding.UTF8.GetBytes(srcSpan, dstSpan);

// Store it in the array
overrides[i] = native;
}
}
}



Loading…
Cancel
Save