Browse Source

Added metadata overrides to `IModelParams`

tags/0.9.1
Martin Evans 1 year ago
parent
commit
b868b056f7
9 changed files with 157 additions and 21 deletions
  1. +2
    -0
      LLama.Examples/Examples/BatchedDecoding.cs
  2. +1
    -1
      LLama.Examples/LLama.Examples.csproj
  3. +1
    -2
      LLama.Examples/Program.cs
  4. +7
    -1
      LLama.Unittest/ModelsParamsTests.cs
  5. +3
    -0
      LLama.Web/Common/ModelOptions.cs
  6. +99
    -1
      LLama/Abstractions/IModelParams.cs
  7. +4
    -2
      LLama/Common/ModelParams.cs
  8. +39
    -13
      LLama/Extensions/IModelParamsExtensions.cs
  9. +1
    -1
      LLama/Native/LLamaModelMetadataOverride.cs

+ 2
- 0
LLama.Examples/Examples/BatchedDecoding.cs View File

@@ -1,5 +1,6 @@
using System.Diagnostics;
using System.Text;
using LLama.Abstractions;
using LLama.Common;
using LLama.Native;

@@ -30,6 +31,7 @@ public class BatchedDecoding

// Load model
var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);

// Tokenize prompt


+ 1
- 1
LLama.Examples/LLama.Examples.csproj View File

@@ -2,7 +2,7 @@
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<Platforms>AnyCPU;x64</Platforms>


+ 1
- 2
LLama.Examples/Program.cs View File

@@ -10,8 +10,7 @@ Console.WriteLine("=============================================================
NativeLibraryConfig
.Instance
.WithCuda()
.WithLogs()
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
.WithLogs();

NativeApi.llama_empty_call();
Console.WriteLine();


+ 7
- 1
LLama.Unittest/ModelsParamsTests.cs View File

@@ -1,5 +1,6 @@
using LLama.Common;
using System.Text.Json;
using LLama.Abstractions;

namespace LLama.Unittest
{
@@ -14,7 +15,12 @@ namespace LLama.Unittest
ContextSize = 42,
Seed = 42,
GpuLayerCount = 111,
TensorSplits = { [0] = 3 }
TensorSplits = { [0] = 3 },
MetadataOverrides =
{
MetadataOverride.Create("hello", true),
MetadataOverride.Create("world", 17),
}
};

var json = JsonSerializer.Serialize(expected);


+ 3
- 0
LLama.Web/Common/ModelOptions.cs View File

@@ -59,6 +59,9 @@ namespace LLama.Web.Common
/// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new();

/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();

/// <inheritdoc />
public float? RopeFrequencyBase { get; set; }



+ 99
- 1
LLama/Abstractions/IModelParams.cs View File

@@ -59,6 +59,11 @@ namespace LLama.Abstractions
/// base model path for the lora adapter (lora_base)
/// </summary>
string LoraBase { get; set; }

/// <summary>
/// Override specific metadata items in the model
/// </summary>
List<MetadataOverride> MetadataOverrides { get; }
}

/// <summary>
@@ -186,7 +191,7 @@ namespace LLama.Abstractions
: JsonConverter<TensorSplitsCollection>
{
/// <inheritdoc/>
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
return new TensorSplitsCollection(arr);
@@ -198,4 +203,97 @@ namespace LLama.Abstractions
JsonSerializer.Serialize(writer, value.Splits, options);
}
}

/// <summary>
/// An override for a single key/value pair in model metadata
/// </summary>
[JsonConverter(typeof(MetadataOverrideConverter))]
public abstract record MetadataOverride
{
/// <summary>
/// Create a new override for an int key
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <returns></returns>
public static MetadataOverride Create(string key, int value)
{
return new IntOverride(key, value);
}

/// <summary>
/// Create a new override for a float key
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <returns></returns>
public static MetadataOverride Create(string key, float value)
{
return new FloatOverride(key, value);
}

/// <summary>
/// Create a new override for a boolean key
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <returns></returns>
public static MetadataOverride Create(string key, bool value)
{
return new BoolOverride(key, value);
}

internal abstract void Write(ref LLamaModelMetadataOverride dest);

/// <summary>
/// Get the key being overriden by this override
/// </summary>
public abstract string Key { get; init; }

private record IntOverride(string Key, int Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
dest.IntValue = Value;
}
}

private record FloatOverride(string Key, float Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
dest.FloatValue = Value;
}
}

private record BoolOverride(string Key, bool Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
dest.BoolValue = Value ? -1 : 0;
}
}
}

public class MetadataOverrideConverter
: JsonConverter<MetadataOverride>
{
/// <inheritdoc/>
public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
throw new NotImplementedException();
//var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
//return new TensorSplitsCollection(arr);
}

/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
{
throw new NotImplementedException();
//JsonSerializer.Serialize(writer, value.Splits, options);
}
}
}

+ 4
- 2
LLama/Common/ModelParams.cs View File

@@ -1,9 +1,8 @@
using LLama.Abstractions;
using System;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;
using System.Collections.Generic;

namespace LLama.Common
{
@@ -55,6 +54,9 @@ namespace LLama.Common
/// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new();

/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();

/// <inheritdoc />
public float? RopeFrequencyBase { get; set; }



+ 39
- 13
LLama/Extensions/IModelParamsExtensions.cs View File

@@ -1,6 +1,6 @@
using System.IO;
using System;
using System.Buffers;
using System.Text;
using LLama.Abstractions;
using LLama.Native;

@@ -36,18 +36,44 @@ public static class IModelParamsExtensions
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
}

//todo: MetadataOverrides
//if (@params.MetadataOverrides.Count == 0)
//{
// unsafe
// {
// result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
// }
//}
//else
//{
// throw new NotImplementedException("MetadataOverrides");
//}
if (@params.MetadataOverrides.Count == 0)
{
unsafe
{
result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
}
}
else
{
// Allocate enough space for all the override items
var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
var overridesPin = overrides.AsMemory().Pin();
unsafe
{
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
}

// Convert each item
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
{
var item = @params.MetadataOverrides[i];
var native = new LLamaModelMetadataOverride();

// Init value and tag
item.Write(ref native);

// Convert key to bytes
unsafe
{
fixed (char* srcKey = item.Key)
{
Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
}
}

overrides[i] = native;
}
}

return disposer;
}

+ 1
- 1
LLama/Native/LLamaModelMetadataOverride.cs View File

@@ -12,7 +12,7 @@ public unsafe struct LLamaModelMetadataOverride
/// Key to override
/// </summary>
[FieldOffset(0)]
public fixed char key[128];
public fixed byte key[128];

/// <summary>
/// Type of value


Loading…
Cancel
Save