From 3fc0f34cbe4614a36085278ec36e37762de232bd Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sun, 24 Dec 2023 21:23:37 +0000 Subject: [PATCH] Fixed some issues which were causing metadata overrides not to work (most importantly, converting the key was failing so all keys were null bytes and thus ignored). --- LLama/Extensions/EncodingExtensions.cs | 20 ++++++++++++ LLama/Extensions/IModelParamsExtensions.cs | 38 ++++++++++++---------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/LLama/Extensions/EncodingExtensions.cs b/LLama/Extensions/EncodingExtensions.cs index 7df4fdc6..00b46f21 100644 --- a/LLama/Extensions/EncodingExtensions.cs +++ b/LLama/Extensions/EncodingExtensions.cs @@ -6,6 +6,11 @@ namespace LLama.Extensions; internal static class EncodingExtensions { #if NETSTANDARD2_0 + public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Span<byte> output) + { + return GetBytesImpl(encoding, chars, output); + } + public static int GetChars(this Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output) { return GetCharsImpl(encoding, bytes, output); @@ -19,6 +24,21 @@ internal static class EncodingExtensions #error Target framework not supported!
#endif + internal static int GetBytesImpl(Encoding encoding, ReadOnlySpan<char> chars, Span<byte> output) + { + if (chars.Length == 0) + return 0; + + unsafe + { + fixed (char* charPtr = chars) + fixed (byte* bytePtr = output) + { + return encoding.GetBytes(charPtr, chars.Length, bytePtr, output.Length); + } + } + } + internal static int GetCharsImpl(Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output) { if (bytes.Length == 0) diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index 08805d32..36558e72 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -45,35 +45,39 @@ public static class IModelParamsExtensions } else { - // Allocate enough space for all the override items + // Allocate enough space for all the override items. Pin it in place so we can safely pass it to llama.cpp + // This is one larger than necessary. The last item indicates the end of the overrides. var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1]; - var overridesPin = overrides.AsMemory().Pin(); unsafe { - result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer; + result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overrides.AsMemory().Pin()).Pointer; } // Convert each item for (var i = 0; i < @params.MetadataOverrides.Count; i++) { - var item = @params.MetadataOverrides[i]; - var native = new LLamaModelMetadataOverride - { - Tag = item.Type - }; - - item.WriteValue(ref native); - - // Convert key to bytes unsafe { - fixed (char* srcKey = item.Key) + // Get the item to convert + var item = @params.MetadataOverrides[i]; + + // Create the "native" representation to fill in + var native = new LLamaModelMetadataOverride { - Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128); - } - } + Tag = item.Type + }; + + // Write the value into the native struct + item.WriteValue(ref native); - overrides[i] = native; + // Convert key chars to
bytes + var srcSpan = item.Key.AsSpan(); + var dstSpan = new Span<byte>(native.key, 128); + Encoding.UTF8.GetBytes(srcSpan, dstSpan); + + // Store it in the array + overrides[i] = native; + } } }