| @@ -1,5 +1,6 @@ | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using LLama.Abstractions; | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| @@ -30,6 +31,7 @@ public class BatchedDecoding | |||
| // Load model | |||
| var parameters = new ModelParams(modelPath); | |||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||
| // Tokenize prompt | |||
| @@ -2,7 +2,7 @@ | |||
| <Import Project="..\LLama\LLamaSharp.Runtime.targets" /> | |||
| <PropertyGroup> | |||
| <OutputType>Exe</OutputType> | |||
| <TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks> | |||
| <TargetFrameworks>net6.0;net8.0</TargetFrameworks> | |||
| <ImplicitUsings>enable</ImplicitUsings> | |||
| <Nullable>enable</Nullable> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| @@ -10,8 +10,7 @@ Console.WriteLine("============================================================= | |||
| NativeLibraryConfig | |||
| .Instance | |||
| .WithCuda() | |||
| .WithLogs() | |||
| .WithAvx(NativeLibraryConfig.AvxLevel.Avx512); | |||
| .WithLogs(); | |||
| NativeApi.llama_empty_call(); | |||
| Console.WriteLine(); | |||
| @@ -1,5 +1,6 @@ | |||
| using LLama.Common; | |||
| using System.Text.Json; | |||
| using LLama.Abstractions; | |||
| namespace LLama.Unittest | |||
| { | |||
| @@ -14,7 +15,12 @@ namespace LLama.Unittest | |||
| ContextSize = 42, | |||
| Seed = 42, | |||
| GpuLayerCount = 111, | |||
| TensorSplits = { [0] = 3 } | |||
| TensorSplits = { [0] = 3 }, | |||
| MetadataOverrides = | |||
| { | |||
| MetadataOverride.Create("hello", true), | |||
| MetadataOverride.Create("world", 17), | |||
| } | |||
| }; | |||
| var json = JsonSerializer.Serialize(expected); | |||
| @@ -59,6 +59,9 @@ namespace LLama.Web.Common | |||
| /// <inheritdoc /> | |||
| public TensorSplitsCollection TensorSplits { get; set; } = new(); | |||
| /// <inheritdoc /> | |||
| public List<MetadataOverride> MetadataOverrides { get; } = new(); | |||
| /// <inheritdoc /> | |||
| public float? RopeFrequencyBase { get; set; } | |||
| @@ -59,6 +59,11 @@ namespace LLama.Abstractions | |||
| /// base model path for the lora adapter (lora_base) | |||
| /// </summary> | |||
| string LoraBase { get; set; } | |||
| /// <summary> | |||
| /// Override specific metadata items in the model | |||
| /// </summary> | |||
| List<MetadataOverride> MetadataOverrides { get; } | |||
| } | |||
| /// <summary> | |||
| @@ -186,7 +191,7 @@ namespace LLama.Abstractions | |||
| : JsonConverter<TensorSplitsCollection> | |||
| { | |||
| /// <inheritdoc/> | |||
| public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) | |||
| public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) | |||
| { | |||
| var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>(); | |||
| return new TensorSplitsCollection(arr); | |||
| @@ -198,4 +203,97 @@ namespace LLama.Abstractions | |||
| JsonSerializer.Serialize(writer, value.Splits, options); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// An override for a single key/value pair in model metadata | |||
| /// </summary> | |||
| [JsonConverter(typeof(MetadataOverrideConverter))] | |||
| public abstract record MetadataOverride | |||
| { | |||
| /// <summary> | |||
| /// Create a new override for an int key | |||
| /// </summary> | |||
| /// <param name="key"></param> | |||
| /// <param name="value"></param> | |||
| /// <returns></returns> | |||
| public static MetadataOverride Create(string key, int value) | |||
| { | |||
| return new IntOverride(key, value); | |||
| } | |||
| /// <summary> | |||
| /// Create a new override for a float key | |||
| /// </summary> | |||
| /// <param name="key"></param> | |||
| /// <param name="value"></param> | |||
| /// <returns></returns> | |||
| public static MetadataOverride Create(string key, float value) | |||
| { | |||
| return new FloatOverride(key, value); | |||
| } | |||
| /// <summary> | |||
| /// Create a new override for a boolean key | |||
| /// </summary> | |||
| /// <param name="key"></param> | |||
| /// <param name="value"></param> | |||
| /// <returns></returns> | |||
| public static MetadataOverride Create(string key, bool value) | |||
| { | |||
| return new BoolOverride(key, value); | |||
| } | |||
| internal abstract void Write(ref LLamaModelMetadataOverride dest); | |||
| /// <summary> | |||
| /// Get the key being overriden by this override | |||
| /// </summary> | |||
| public abstract string Key { get; init; } | |||
| private record IntOverride(string Key, int Value) : MetadataOverride | |||
| { | |||
| internal override void Write(ref LLamaModelMetadataOverride dest) | |||
| { | |||
| dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT; | |||
| dest.IntValue = Value; | |||
| } | |||
| } | |||
| private record FloatOverride(string Key, float Value) : MetadataOverride | |||
| { | |||
| internal override void Write(ref LLamaModelMetadataOverride dest) | |||
| { | |||
| dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT; | |||
| dest.FloatValue = Value; | |||
| } | |||
| } | |||
| private record BoolOverride(string Key, bool Value) : MetadataOverride | |||
| { | |||
| internal override void Write(ref LLamaModelMetadataOverride dest) | |||
| { | |||
| dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL; | |||
| dest.BoolValue = Value ? -1 : 0; | |||
| } | |||
| } | |||
| } | |||
| public class MetadataOverrideConverter | |||
| : JsonConverter<MetadataOverride> | |||
| { | |||
| /// <inheritdoc/> | |||
| public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>(); | |||
| //return new TensorSplitsCollection(arr); | |||
| } | |||
| /// <inheritdoc/> | |||
| public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //JsonSerializer.Serialize(writer, value.Splits, options); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,9 +1,8 @@ | |||
| using LLama.Abstractions; | |||
| using System; | |||
| using System.Text; | |||
| using System.Text.Json; | |||
| using System.Text.Json.Serialization; | |||
| using LLama.Native; | |||
| using System.Collections.Generic; | |||
| namespace LLama.Common | |||
| { | |||
| @@ -55,6 +54,9 @@ namespace LLama.Common | |||
| /// <inheritdoc /> | |||
| public TensorSplitsCollection TensorSplits { get; set; } = new(); | |||
| /// <inheritdoc /> | |||
| public List<MetadataOverride> MetadataOverrides { get; } = new(); | |||
| /// <inheritdoc /> | |||
| public float? RopeFrequencyBase { get; set; } | |||
| @@ -1,6 +1,6 @@ | |||
| using System.IO; | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Text; | |||
| using LLama.Abstractions; | |||
| using LLama.Native; | |||
| @@ -36,18 +36,44 @@ public static class IModelParamsExtensions | |||
| result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer; | |||
| } | |||
| //todo: MetadataOverrides | |||
| //if (@params.MetadataOverrides.Count == 0) | |||
| //{ | |||
| // unsafe | |||
| // { | |||
| // result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero; | |||
| // } | |||
| //} | |||
| //else | |||
| //{ | |||
| // throw new NotImplementedException("MetadataOverrides"); | |||
| //} | |||
| if (@params.MetadataOverrides.Count == 0) | |||
| { | |||
| unsafe | |||
| { | |||
| result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // Allocate enough space for all the override items | |||
| var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1]; | |||
| var overridesPin = overrides.AsMemory().Pin(); | |||
| unsafe | |||
| { | |||
| result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer; | |||
| } | |||
| // Convert each item | |||
| for (var i = 0; i < @params.MetadataOverrides.Count; i++) | |||
| { | |||
| var item = @params.MetadataOverrides[i]; | |||
| var native = new LLamaModelMetadataOverride(); | |||
| // Init value and tag | |||
| item.Write(ref native); | |||
| // Convert key to bytes | |||
| unsafe | |||
| { | |||
| fixed (char* srcKey = item.Key) | |||
| { | |||
| Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128); | |||
| } | |||
| } | |||
| overrides[i] = native; | |||
| } | |||
| } | |||
| return disposer; | |||
| } | |||
| @@ -12,7 +12,7 @@ public unsafe struct LLamaModelMetadataOverride | |||
| /// Key to override | |||
| /// </summary> | |||
| [FieldOffset(0)] | |||
| public fixed char key[128]; | |||
| public fixed byte key[128]; | |||
| /// <summary> | |||
| /// Type of value | |||