From 669ae47ef744be5c4cb9b78d317bce641d40fa3a Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 30 Sep 2023 16:21:18 +0100 Subject: [PATCH] - Split parameters into two interfaces - params contains a list of loras, instead of just one --- .../NewVersion/LoadAndSaveSession.cs | 4 +- LLama.Unittest/BasicTest.cs | 1 - LLama.Unittest/LLamaContextTests.cs | 7 +- LLama.Unittest/ModelsParamsTests.cs | 9 +- LLama.Web/Common/ModelOptions.cs | 13 +- LLama/Abstractions/IContextParams.cs | 60 +++++++++ LLama/Abstractions/ILLamaParams.cs | 11 ++ LLama/Abstractions/IModelParams.cs | 114 ++++++++---------- LLama/Common/ModelParams.cs | 18 ++- LLama/Extensions/IModelParamsExtensions.cs | 2 +- LLama/LLamaContext.cs | 25 +--- LLama/LLamaEmbedder.cs | 21 ++-- LLama/LLamaStatelessExecutor.cs | 18 +-- LLama/LLamaWeights.cs | 24 ++-- LLama/Utils.cs | 108 ----------------- 15 files changed, 178 insertions(+), 257 deletions(-) create mode 100644 LLama/Abstractions/IContextParams.cs create mode 100644 LLama/Abstractions/ILLamaParams.cs delete mode 100644 LLama/Utils.cs diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index 33774b13..9e6116ce 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -8,7 +8,7 @@ namespace LLama.Examples.NewVersion { Console.Write("Please input your model path: "); var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim(); var parameters = new ModelParams(modelPath) { @@ -50,7 +50,7 @@ namespace LLama.Examples.NewVersion Console.ForegroundColor = ConsoleColor.White; ex.Context.Dispose(); - ex = new(new LLamaContext(parameters)); + ex = new(new LLamaContext(model, parameters)); session = new ChatSession(ex); session.LoadSession(statePath); diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 93d82192..2cd1806f 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -29,7 +29,6 @@ namespace LLama.Unittest Assert.Equal(32000, _model.VocabCount); Assert.Equal(4096, _model.ContextSize); Assert.Equal(4096, _model.EmbeddingSize); - Assert.Equal(Encoding.UTF8, _model.Encoding); } } } \ No newline at end of file diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index 6a181734..2edf3a62 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -10,7 +10,10 @@ namespace LLama.Unittest public LLamaContextTests() { - var @params = new ModelParams(Constants.ModelPath); + var @params = new ModelParams(Constants.ModelPath) + { + ContextSize = 768, + }; _weights = LLamaWeights.LoadFromFile(@params); _context = _weights.CreateContext(@params); } @@ -24,7 +27,7 @@ namespace LLama.Unittest [Fact] public void CheckProperties() { - Assert.Equal(4096, _context.ContextSize); + Assert.Equal(768, _context.ContextSize); Assert.Equal(4096, _context.EmbeddingSize); Assert.Equal(32000, _context.VocabCount); } diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs index 413bda83..000f5853 100644 --- a/LLama.Unittest/ModelsParamsTests.cs +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -13,7 +13,6 @@ namespace LLama.Unittest { BatchSize = 17, ContextSize = 42, - LoraAdapter = "adapter", Seed = 42, GpuLayerCount = 111 }; @@ -31,9 +30,13 @@ namespace LLama.Unittest { BatchSize = 17, ContextSize = 42, - LoraAdapter = "adapter", Seed = 42, - GpuLayerCount = 111 + GpuLayerCount = 111, + LoraAdapters = + { + new("abc", 1), + new("def", 0) + } }; var settings = new Newtonsoft.Json.JsonSerializerSettings(); diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index c5b5c54b..2fd8558c 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -4,7 +4,7 @@ using LLama.Abstractions; namespace LLama.Web.Common { public class ModelOptions - : IModelParams + : ILLamaParams { public string Name { get; set; } @@ -51,16 +51,11 @@ namespace LLama.Web.Common /// Model path (model) /// public string ModelPath { get; set; } + /// - /// model alias - /// - public string ModelAlias { get; set; } = "unknown"; - /// - /// lora adapter path (lora_adapter) + /// List of LoRAs to apply /// - public string LoraAdapter { get; set; } = string.Empty; - - public float LoraAdapterScale { get; set; } = 1; + public AdapterCollection LoraAdapters { get; set; } = new(); /// /// base model path for the lora adapter (lora_base) diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs new file mode 100644 index 00000000..a59512d3 --- /dev/null +++ b/LLama/Abstractions/IContextParams.cs @@ -0,0 +1,60 @@ +using System.Text; + +namespace LLama.Abstractions; + +/// +/// The parameters for initializing a LLama context from a model. +/// +public interface IContextParams +{ + /// + /// Model context size (n_ctx) + /// + uint ContextSize { get; set; } + + /// + /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// + uint BatchSize { get; set; } + + /// + /// Seed for the random number generator (seed) + /// + uint Seed { get; set; } + + /// + /// Use f16 instead of f32 for memory kv (memory_f16) + /// + bool UseFp16Memory { get; set; } + + /// + /// Compute perplexity over the prompt (perplexity) + /// + bool Perplexity { get; set; } + + /// + /// Whether to use embedding mode. (embedding) Note that if this is set to true, + /// The LLamaModel won't produce text response anymore. + /// + bool EmbeddingMode { get; set; } + + /// + /// RoPE base frequency + /// + float RopeFrequencyBase { get; set; } + + /// + /// RoPE frequency scaling factor + /// + float RopeFrequencyScale { get; set; } + + /// + /// Use experimental mul_mat_q kernels + /// + bool MulMatQ { get; set; } + + /// + /// The encoding to use for models + /// + Encoding Encoding { get; set; } +} \ No newline at end of file diff --git a/LLama/Abstractions/ILLamaParams.cs b/LLama/Abstractions/ILLamaParams.cs new file mode 100644 index 00000000..636ba199 --- /dev/null +++ b/LLama/Abstractions/ILLamaParams.cs @@ -0,0 +1,11 @@ +namespace LLama.Abstractions +{ + /// + /// Convenience interface for implementing both type of parameters. + /// + /// Mostly exists for backwards compatibility reasons, when these two were not split. + public interface ILLamaParams + : IModelParams, IContextParams + { + } +} diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index 168654c4..31304acb 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -1,4 +1,6 @@ -using System.Text; +using System; +using System.Collections.Generic; +using System.Linq; namespace LLama.Abstractions { @@ -7,36 +9,16 @@ namespace LLama.Abstractions /// public interface IModelParams { - /// - /// Model context size (n_ctx) - /// - uint ContextSize { get; set; } - /// /// the GPU that is used for scratch and small tensors /// int MainGpu { get; set; } - /// - /// if true, reduce VRAM usage at the cost of performance - /// - bool LowVram { get; set; } - /// /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) /// int GpuLayerCount { get; set; } - /// - /// Seed for the random number generator (seed) - /// - uint Seed { get; set; } - - /// - /// Use f16 instead of f32 for memory kv (memory_f16) - /// - bool UseFp16Memory { get; set; } - /// /// Use mmap for faster loads (use_mmap) /// @@ -47,72 +29,78 @@ namespace LLama.Abstractions /// bool UseMemoryLock { get; set; } - /// - /// Compute perplexity over the prompt (perplexity) - /// - bool Perplexity { get; set; } - /// /// Model path (model) /// string ModelPath { get; set; } - /// - /// lora adapter path (lora_adapter) - /// - string LoraAdapter { get; set; } - - float LoraAdapterScale { get; set; } - - /// - /// base model path for the lora adapter (lora_base) - /// - string LoraBase { get; set; } - /// /// Number of threads (-1 = autodetect) (n_threads) /// int Threads { get; set; } - /// - /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) - /// - uint BatchSize { get; set; } - - /// - /// Whether to use embedding mode. (embedding) Note that if this is set to true, - /// The LLamaModel won't produce text response anymore. - /// - bool EmbeddingMode { get; set; } - /// /// how split tensors should be distributed across GPUs /// float[]? TensorSplits { get; set; } /// - /// RoPE base frequency + /// Load vocab only (no weights) /// - float RopeFrequencyBase { get; set; } + bool VocabOnly { get; set; } /// - /// RoPE frequency scaling factor + /// List of LoRA adapters to apply /// - float RopeFrequencyScale { get; set; } + AdapterCollection LoraAdapters { get; } /// - /// Use experimental mul_mat_q kernels + /// base model path for the lora adapter (lora_base) /// - bool MulMatQ { get; set; } + string LoraBase { get; set; } + } - /// - /// The encoding to use for models - /// - Encoding Encoding { get; set; } + /// + /// A LoRA adapter to apply to a model + /// + /// Path to the LoRA file + /// Strength of this LoRA + public readonly record struct LoraAdapter(string Path, float Scale); - /// - /// Load vocab only (no weights) - /// - bool VocabOnly { get; set; } + /// + /// A list of LoraAdapter objects + /// + public sealed class AdapterCollection + : List, IEquatable + { + /// + public bool Equals(AdapterCollection? other) + { + if (other == null) + return false; + + return this.SequenceEqual(other); + } + + /// + public override bool Equals(object? obj) + { + return Equals(obj as AdapterCollection); + } + + /// + public override int GetHashCode() + { + unchecked + { + var hash = 17; + for (var i = 0; i < Count; i++) + { + hash += this[i].GetHashCode(); + hash *= 7823; + } + return hash; + } + } } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index a0d1688a..09b5e4af 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -1,5 +1,6 @@ using LLama.Abstractions; using System; +using System.Collections.Generic; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -10,7 +11,7 @@ namespace LLama.Common /// The parameters for initializing a LLama model. /// public record ModelParams - : IModelParams + : ILLamaParams { /// /// Model context size (n_ctx) @@ -20,10 +21,7 @@ namespace LLama.Common /// the GPU that is used for scratch and small tensors /// public int MainGpu { get; set; } = 0; - /// - /// if true, reduce VRAM usage at the cost of performance - /// - public bool LowVram { get; set; } = false; + /// /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) /// @@ -52,17 +50,17 @@ namespace LLama.Common /// Model path (model) /// public string ModelPath { get; set; } + /// - /// lora adapter path (lora_adapter) + /// List of LoRAs to apply /// - public string LoraAdapter { get; set; } = string.Empty; - - public float LoraAdapterScale { get; set; } = 1; + public AdapterCollection LoraAdapters { get; set; } = new(); /// /// base model path for the lora adapter (lora_base) /// public string LoraBase { get; set; } = string.Empty; + /// /// Number of threads (-1 = autodetect) (n_threads) /// @@ -162,7 +160,6 @@ namespace LLama.Common UseMemoryLock = useMemoryLock; Perplexity = perplexity; ModelPath = modelPath; - LoraAdapter = loraAdapter; LoraBase = loraBase; Threads = threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : threads; BatchSize = batchSize; @@ -171,6 +168,7 @@ namespace LLama.Common RopeFrequencyScale = ropeFrequencyScale; MulMatQ = mulMatQ; Encoding = Encoding.GetEncoding(encoding); + LoraAdapters.Add(new LoraAdapter(loraAdapter, 1)); } } diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index 1bf19958..9be239df 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -19,7 +19,7 @@ namespace LLama.Extensions /// /// /// - public static void ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result) + public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result) { result = NativeApi.llama_context_default_params(); result.n_ctx = @params.ContextSize; diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 9fef6af5..e6222dcf 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -42,9 +42,9 @@ namespace LLama public int EmbeddingSize => _ctx.EmbeddingSize; /// - /// The model params set for this model. + /// The context params set for this context /// - public IModelParams Params { get; set; } + public IContextParams Params { get; set; } /// /// The native handle, which is used to be passed to the native APIs @@ -57,24 +57,7 @@ namespace LLama /// public Encoding Encoding => _encoding; - /// - /// - /// - /// Model params. - /// The logger. - [Obsolete("Use the LLamaWeights.CreateContext instead")] - public LLamaContext(IModelParams @params, ILogger? logger = null) - { - Params = @params; - - _logger = logger; - _encoding = @params.Encoding; - - _logger?.LogInformation($"[LLamaContext] Initializing LLama model with params: {this.Params}"); - _ctx = Utils.InitLLamaContextFromModelParams(Params); - } - - internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILogger? logger = null) + internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null) { Params = @params; @@ -90,7 +73,7 @@ namespace LLama /// /// /// - public LLamaContext(LLamaWeights model, IModelParams @params, ILogger? logger = null) + public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null) { if (model.NativeHandle.IsClosed) throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 64c17539..54ef07b0 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -18,19 +18,22 @@ namespace LLama /// public int EmbeddingSize => _ctx.EmbeddingSize; - /// - /// - /// - /// - public LLamaEmbedder(IModelParams @params) + public LLamaEmbedder(ILLamaParams allParams) + : this(allParams, allParams) { - @params.EmbeddingMode = true; - using var weights = LLamaWeights.LoadFromFile(@params); - _ctx = weights.CreateContext(@params); } - public LLamaEmbedder(LLamaWeights weights, IModelParams @params) + public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams) { + using var weights = LLamaWeights.LoadFromFile(modelParams); + + contextParams.EmbeddingMode = true; + _ctx = weights.CreateContext(contextParams); + } + + public LLamaEmbedder(LLamaWeights weights, IContextParams @params) + { + @params.EmbeddingMode = true; _ctx = weights.CreateContext(@params); } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 6854a1f6..ad47541e 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -20,7 +20,7 @@ namespace LLama : ILLamaExecutor { private readonly LLamaWeights _weights; - private readonly IModelParams _params; + private readonly IContextParams _params; /// /// The context used by the executor when running the inference. @@ -32,7 +32,7 @@ namespace LLama /// /// /// - public StatelessExecutor(LLamaWeights weights, IModelParams @params) + public StatelessExecutor(LLamaWeights weights, IContextParams @params) { _weights = weights; _params = @params; @@ -41,20 +41,6 @@ namespace LLama Context.Dispose(); } - /// - /// Create a new stateless executor which will use the model used to create the given context - /// - /// - [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] - public StatelessExecutor(LLamaContext context) - { - _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding); - _params = context.Params; - - Context = _weights.CreateContext(_params); - Context.Dispose(); - } - /// public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index e59f2990..bcc41afb 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,5 +1,4 @@ using System; -using System.Text; using LLama.Abstractions; using LLama.Extensions; using LLama.Native; @@ -20,11 +19,6 @@ namespace LLama /// Be careful how you use this! public SafeLlamaModelHandle NativeHandle => _weights; - /// - /// Encoding to use to convert text into bytes for the model - /// - public Encoding Encoding { get; } - /// /// Total number of tokens in vocabulary of this model /// @@ -50,10 +44,9 @@ namespace LLama /// public int EmbeddingSize => NativeHandle.EmbeddingSize; - internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) + internal LLamaWeights(SafeLlamaModelHandle weights) { _weights = weights; - Encoding = encoding; } /// @@ -66,10 +59,17 @@ namespace LLama using var pin = @params.ToLlamaModelParams(out var lparams); var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); - if (!string.IsNullOrEmpty(@params.LoraAdapter)) - weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraAdapterScale, @params.LoraBase, @params.Threads); + foreach (var adapter in @params.LoraAdapters) + { + if (string.IsNullOrEmpty(adapter.Path)) + continue; + if (adapter.Scale <= 0) + continue; + + weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads); + } - return new LLamaWeights(weights, @params.Encoding); + return new LLamaWeights(weights); } /// @@ -83,7 +83,7 @@ namespace LLama /// /// /// - public LLamaContext CreateContext(IModelParams @params) + public LLamaContext CreateContext(IContextParams @params) { return new LLamaContext(this, @params); } diff --git a/LLama/Utils.cs b/LLama/Utils.cs deleted file mode 100644 index d08501c0..00000000 --- a/LLama/Utils.cs +++ /dev/null @@ -1,108 +0,0 @@ -using LLama.Abstractions; -using LLama.Native; -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; -using LLama.Extensions; - -namespace LLama -{ - using llama_token = Int32; - - /// - /// Assorted llama utilities - /// - public static class Utils - { - [Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - using var weights = LLamaWeights.LoadFromFile(@params); - - @params.ToLlamaContextParams(out var lparams); - return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams); - } - - [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static IEnumerable Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - return ctx.Tokenize(text, add_bos, encoding); - } - - [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static Span GetLogits(SafeLLamaContextHandle ctx, int length) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - if (length != ctx.VocabCount) - throw new ArgumentException("length must be the VocabSize"); - - return ctx.GetLogits(); - } - - [Obsolete("Use SafeLLamaContextHandle Eval method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - var slice = tokens.AsSpan().Slice(startIndex, n_tokens); - return ctx.Eval(slice, n_past) ? 0 : 1; - } - - [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - return ctx.TokenToString(token, encoding); - } - - [Obsolete("No longer used internally by LlamaSharp")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static string PtrToString(IntPtr ptr, Encoding encoding) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { -#if NET6_0_OR_GREATER - // ReSharper disable once PossibleUnintendedReferenceComparison - if(encoding == Encoding.UTF8) - { - return Marshal.PtrToStringUTF8(ptr)!; - } - // ReSharper disable once PossibleUnintendedReferenceComparison - else if(encoding == Encoding.Unicode) - { - return Marshal.PtrToStringUni(ptr)!; - } - else - { - return Marshal.PtrToStringAuto(ptr)!; - } -#else - unsafe - { - byte* tp = (byte*)ptr.ToPointer(); - List bytes = new(); - while (true) - { - byte c = *tp++; - if (c == '\0') - { - break; - } - else - { - bytes.Add(c); - } - } - return encoding.GetString(bytes.ToArray()); - } -#endif - } - - } -}