diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 3d0e2cab..42f2be3f 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -13,65 +13,101 @@ namespace LLama.Native
         /// RNG seed, -1 for random
         /// </summary>
         public int seed;
+
         /// <summary>
         /// text context
         /// </summary>
         public int n_ctx;
+
         /// <summary>
         /// prompt processing batch size
         /// </summary>
         public int n_batch;
+
+        /// <summary>
+        /// grouped-query attention (TEMP - will be moved to model hparams)
+        /// </summary>
+        public int n_gqa;
+
+        /// <summary>
+        /// rms norm epsilon (TEMP - will be moved to model hparams)
+        /// </summary>
+        public float rms_norm_eps;
+
         /// <summary>
         /// number of layers to store in VRAM
         /// </summary>
         public int n_gpu_layers;
+
         /// <summary>
         /// the GPU that is used for scratch and small tensors
         /// </summary>
         public int main_gpu;
+
         /// <summary>
         /// how to split layers across multiple GPUs
         /// </summary>
-        public TensorSplits tensor_split;
+        public float[] tensor_split;
+
+        /// <summary>
+        /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
+        /// RoPE base frequency
+        /// </summary>
+        public float rope_freq_base;
+
+        /// <summary>
+        /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
+        /// RoPE frequency scaling factor
+        /// </summary>
+        public float rope_freq_scale;
+
         /// <summary>
         /// called with a progress value between 0 and 1, pass NULL to disable
         /// </summary>
         public IntPtr progress_callback;
+
         /// <summary>
         /// context pointer passed to the progress callback
         /// </summary>
         public IntPtr progress_callback_user_data;
+
         /// <summary>
         /// if true, reduce VRAM usage at the cost of performance
         /// </summary>
         [MarshalAs(UnmanagedType.I1)]
         public bool low_vram;
+
         /// <summary>
         /// use fp16 for KV cache
         /// </summary>
         [MarshalAs(UnmanagedType.I1)]
         public bool f16_kv;
+
         /// <summary>
         /// the llama_eval() call computes all logits, not just the last one
         /// </summary>
         [MarshalAs(UnmanagedType.I1)]
         public bool logits_all;
+
         /// <summary>
         /// only load the vocabulary, no weights
         /// </summary>
         [MarshalAs(UnmanagedType.I1)]
         public bool vocab_only;
+
         /// <summary>
         /// use mmap if possible
         /// </summary>
         [MarshalAs(UnmanagedType.I1)]
         public bool use_mmap;
+
         /// <summary>
         /// force system to keep model in RAM
         /// </summary>
         [MarshalAs(UnmanagedType.I1)]
         public bool use_mlock;
+
         /// <summary>
         /// embedding mode only
         /// </summary>
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index b4d23007..527bea52 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -1,6 +1,4 @@
 using System;
-using System.Collections.Generic;
-using System.IO;
 using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;
@@ -29,7 +27,7 @@ namespace LLama.Native
         }
         private const string libraryName = "libllama";
 
-        [DllImport("libllama", EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
+        [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
         public static extern bool llama_empty_call();
 
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -56,7 +54,10 @@ namespace LLama.Native
         /// <param name="params_"></param>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern IntPtr llama_init_from_file(string path_model, LLamaContextParams params_);
+        public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_);
 
         /// <summary>
         /// not great API - very likely to change.
@@ -65,6 +66,7 @@ namespace LLama.Native
         /// <param name="numa"></param>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_backend_init(bool numa);
+
         /// <summary>
         /// Frees all allocated memory
         /// </summary>
@@ -72,6 +74,13 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_free(IntPtr ctx);
 
+        /// <summary>
+        /// Frees all allocated memory associated with a model
+        /// </summary>
+        /// <param name="model"></param>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern void llama_free_model(IntPtr model);
+
         /// <summary>
         /// Apply a LoRA adapter to a loaded model
         /// path_base_model is the path to a higher quality model to use as a base for
@@ -79,13 +88,13 @@ namespace LLama.Native
         /// The model needs to be reloaded before applying a new adapter, otherwise the adapter
         /// will be applied on top of the previous one
         /// </summary>
-        /// <param name="ctx"></param>
+        /// <param name="model_ptr"></param>
         /// <param name="path_lora"></param>
         /// <param name="path_base_model"></param>
         /// <param name="n_threads"></param>
         /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_apply_lora_from_file(SafeLLamaContextHandle ctx, string path_lora, string path_base_model, int n_threads);
+        public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);
 
         /// <summary>
         /// Returns the number of tokens in the KV cache
@@ -294,5 +303,20 @@ namespace LLama.Native
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern IntPtr llama_print_system_info();
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
     }
 }
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 5c26cb13..ab102228 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,26 +1,61 @@
 using System;
-using System.Collections.Generic;
-using System.Runtime.InteropServices;
-using System.Text;
+using LLama.Exceptions;
 
 namespace LLama.Native
 {
-    public class SafeLLamaContextHandle: SafeLLamaHandleBase
+    /// <summary>
+    /// A safe wrapper around a llama_context
+    /// </summary>
+    public class SafeLLamaContextHandle
+        : SafeLLamaHandleBase
     {
-        protected SafeLLamaContextHandle()
-        {
-        }
+        /// <summary>
+        /// This field guarantees that a reference to the model is held for as long as this handle is held
+        /// </summary>
+        private SafeLlamaModelHandle? _model;
 
-        public SafeLLamaContextHandle(IntPtr handle)
+        /// <summary>
+        /// Create a new SafeLLamaContextHandle
+        /// </summary>
+        /// <param name="handle">pointer to an allocated llama_context</param>
+        /// <param name="model">the model which this context was created from</param>
+        public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
             : base(handle)
         {
+            // Increment the model reference count while this context exists
+            _model = model;
+            var success = false;
+            _model.DangerousAddRef(ref success);
+            if (!success)
+                throw new RuntimeError("Failed to increment model refcount");
         }
 
+        /// <inheritdoc />
         protected override bool ReleaseHandle()
         {
+            // Decrement refcount on model
+            _model?.DangerousRelease();
+            _model = null;
+
             NativeApi.llama_free(handle);
             SetHandle(IntPtr.Zero);
             return true;
         }
+
+        /// <summary>
+        /// Create a new llama_state for the given model
+        /// </summary>
+        /// <param name="model"></param>
+        /// <param name="lparams"></param>
+        /// <returns></returns>
+        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
+        {
+            var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
+            if (ctx_ptr == IntPtr.Zero)
+                throw new RuntimeError("Failed to create context from model");
+
+            return new(ctx_ptr, model);
+        }
     }
 }
diff --git a/LLama/Native/SafeLLamaHandleBase.cs b/LLama/Native/SafeLLamaHandleBase.cs
index 023f8cdd..6371b327 100644
--- a/LLama/Native/SafeLLamaHandleBase.cs
+++ b/LLama/Native/SafeLLamaHandleBase.cs
@@ -1,11 +1,13 @@
 using System;
-using System.Collections.Generic;
 using System.Runtime.InteropServices;
-using System.Text;
 
 namespace LLama.Native
 {
-    public abstract class SafeLLamaHandleBase: SafeHandle
+    /// <summary>
+    /// Base class for all llama handles to native resources
+    /// </summary>
+    public abstract class SafeLLamaHandleBase
+        : SafeHandle
     {
         private protected SafeLLamaHandleBase()
             : base(IntPtr.Zero, ownsHandle: true)
@@ -24,8 +26,10 @@ namespace LLama.Native
             SetHandle(handle);
         }
 
+        /// <inheritdoc />
         public override bool IsInvalid => handle == IntPtr.Zero;
 
+        /// <inheritdoc />
         public override string ToString()
             => $"0x{handle.ToString("x16")}";
     }
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
new file mode 100644
index 00000000..79714fea
--- /dev/null
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -0,0 +1,161 @@
+using System;
+using System.Text;
+using LLama.Exceptions;
+
+namespace LLama.Native
+{
+    /// <summary>
+    /// A reference to a set of llama model weights
+    /// </summary>
+    public class SafeLlamaModelHandle
+        : SafeLLamaHandleBase
+    {
+        /// <summary>
+        /// Total number of tokens in vocabulary of this model
+        /// </summary>
+        public int VocabCount { get; set; }
+
+        /// <summary>
+        /// Total number of tokens in the context
+        /// </summary>
+        public int ContextSize { get; set; }
+
+        /// <summary>
+        /// Dimension of embedding vectors
+        /// </summary>
+        public int EmbeddingCount { get; set; }
+
+        internal SafeLlamaModelHandle(IntPtr handle)
+            : base(handle)
+        {
+            VocabCount = NativeApi.llama_n_vocab_from_model(this);
+            ContextSize = NativeApi.llama_n_ctx_from_model(this);
+            EmbeddingCount = NativeApi.llama_n_embd_from_model(this);
+        }
+
+        /// <inheritdoc />
+        protected override bool ReleaseHandle()
+        {
+            NativeApi.llama_free_model(handle);
+            SetHandle(IntPtr.Zero);
+            return true;
+        }
+
+        /// <summary>
+        /// Load a model from the given file path into memory
+        /// </summary>
+        /// <param name="modelPath"></param>
+        /// <param name="lparams"></param>
+        /// <returns></returns>
+        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
+        {
+            var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
+            if (model_ptr == IntPtr.Zero)
+                throw new RuntimeError($"Failed to load model {modelPath}.");
+
+            return new SafeLlamaModelHandle(model_ptr);
+        }
+
+        #region LoRA
+        /// <summary>
+        /// Apply a LoRA adapter to a loaded model
+        /// </summary>
+        /// <param name="lora"></param>
+        /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
+        /// adapter. Can be NULL to use the current loaded model.
+        /// </param>
+        /// <param name="threads"></param>
+        public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
+        {
+            var err = NativeApi.llama_model_apply_lora_from_file(
+                this,
+                lora,
+                string.IsNullOrEmpty(modelBase) ? null : modelBase,
+                threads
+            );
+
+            if (err != 0)
+                throw new RuntimeError("Failed to apply lora adapter.");
+        }
+        #endregion
+
+        #region tokenize
+        /// <summary>
+        /// Convert a single llama token into string bytes
+        /// </summary>
+        /// <param name="llama_token"></param>
+        /// <returns></returns>
+        public ReadOnlySpan<byte> TokenToSpan(int llama_token)
+        {
+            unsafe
+            {
+                var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
+                var terminator = bytes.IndexOf((byte)0);
+                return bytes.Slice(0, terminator);
+            }
+        }
+
+        /// <summary>
+        /// Convert a single llama token into a string
+        /// </summary>
+        /// <param name="llama_token"></param>
+        /// <param name="encoding">Encoding to use to decode the bytes into a string</param>
+        /// <returns></returns>
+        public string TokenToString(int llama_token, Encoding encoding)
+        {
+            var span = TokenToSpan(llama_token);
+
+            if (span.Length == 0)
+                return "";
+
+            unsafe
+            {
+                fixed (byte* ptr = &span[0])
+                {
+                    return encoding.GetString(ptr, span.Length);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Convert a string of text into tokens
+        /// </summary>
+        /// <param name="text"></param>
+        /// <param name="add_bos"></param>
+        /// <param name="encoding"></param>
+        /// <returns></returns>
+        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+        {
+            // Convert string to bytes, adding one extra byte to the end (null terminator)
+            var bytesCount = encoding.GetByteCount(text);
+            var bytes = new byte[bytesCount + 1];
+            unsafe
+            {
+                fixed (char* charPtr = text)
+                fixed (byte* bytePtr = &bytes[0])
+                {
+                    encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
+                }
+            }
+
+            unsafe
+            {
+                fixed (byte* bytesPtr = &bytes[0])
+                {
+                    // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
+                    var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
+
+                    // Tokenize again, this time outputting into an array of exactly the right size
+                    var tokens = new int[count];
+                    fixed (int* tokensPtr = &tokens[0])
+                    {
+                        count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
+                        return tokens;
+                    }
+                }
+            }
+        }
+        #endregion
+    }
+}
diff --git a/LLama/OldVersion/Utils.cs b/LLama/OldVersion/Utils.cs
index 4916a20d..df8adddd 100644
--- a/LLama/OldVersion/Utils.cs
+++ b/LLama/OldVersion/Utils.cs
@@ -31,24 +31,12 @@ namespace LLama.OldVersion
                 throw new FileNotFoundException($"The model file does not exist: {@params.model}");
             }
 
-            var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);
-
-            if (ctx_ptr == IntPtr.Zero)
-            {
-                throw new RuntimeError($"Failed to load model {@params.model}.");
-            }
-
-            SafeLLamaContextHandle ctx = new(ctx_ptr);
+            var model = SafeLlamaModelHandle.LoadFromFile(@params.model, lparams);
+            var ctx = SafeLLamaContextHandle.Create(model, lparams);
 
             if (!string.IsNullOrEmpty(@params.lora_adapter))
-            {
-                int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
-                    string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
-                if (err != 0)
-                {
-                    throw new RuntimeError("Failed to apply lora adapter.");
-                }
-            }
+                model.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
+
             return ctx;
         }
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index b6f1b7b4..c08912cf 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -28,40 +28,25 @@ namespace LLama
             lparams.logits_all = @params.Perplexity;
             lparams.embedding = @params.EmbeddingMode;
             lparams.low_vram = @params.LowVram;
-
-            if(@params.TensorSplits.Length != 1)
+
+            if (@params.TensorSplits.Length != 1)
             {
                 throw new ArgumentException("Currently multi-gpu support is not supported by "
                     + "both llama.cpp and LLamaSharp.");
             }
 
-            lparams.tensor_split = new TensorSplits()
-            {
-                Item1 = @params.TensorSplits[0]
-            };
+            lparams.tensor_split = @params.TensorSplits;
 
             if (!File.Exists(@params.ModelPath))
            {
                 throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
             }
 
-            var ctx_ptr = NativeApi.llama_init_from_file(@params.ModelPath, lparams);
-
-            if (ctx_ptr == IntPtr.Zero)
-            {
-                throw new RuntimeError($"Failed to load model {@params.ModelPath}.");
-            }
-
-            SafeLLamaContextHandle ctx = new(ctx_ptr);
+            var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+            var ctx = SafeLLamaContextHandle.Create(model, lparams);
 
             if (!string.IsNullOrEmpty(@params.LoraAdapter))
-            {
-                int err = NativeApi.llama_apply_lora_from_file(ctx, @params.LoraAdapter,
-                    string.IsNullOrEmpty(@params.LoraBase) ? null : @params.LoraBase, @params.Threads);
-                if (err != 0)
-                {
-                    throw new RuntimeError("Failed to apply lora adapter.");
-                }
-            }
+                model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
+
             return ctx;
         }
 
@@ -78,7 +63,7 @@ namespace LLama
             return res.Take(n);
         }
 
-        public unsafe static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
+        public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
         {
             var logits = NativeApi.llama_get_logits(ctx);
             return new Span<float>(logits, length);
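
Taken together, this patch splits loading the model weights (SafeLlamaModelHandle) from creating an inference context (SafeLLamaContextHandle), and the context takes a reference on the model so llama_free_model is not called while any context is still alive. The sketch below (not part of the patch) shows how the new handles are intended to fit together; the file path and parameter values are placeholders, and in the library LLamaContextParams is normally populated from the user's model parameters by the helper in LLama/Utils.cs shown above:

    using System.Text;
    using LLama.Native;

    // Placeholder parameter values; real code fills the remaining fields
    // (tensor_split, rope settings, etc.) from the user's model params.
    var lparams = new LLamaContextParams
    {
        seed = 1337,
        n_ctx = 2048,
        n_batch = 512,
    };

    // Load the weights once...
    using var model = SafeLlamaModelHandle.LoadFromFile("/path/to/model.bin", lparams);

    // ...then create one (or several) contexts over the same weights.
    using var ctx = SafeLLamaContextHandle.Create(model, lparams);

    // Tokenization and detokenization now need only the model handle, not a context.
    int[] tokens = model.Tokenize("Hello llama", add_bos: true, encoding: Encoding.UTF8);
    string piece = model.TokenToString(tokens[0], Encoding.UTF8);

Because the context constructor calls DangerousAddRef on the model handle, disposing the model before the context is intended to be safe: the underlying SafeHandle defers the native llama_free_model call until the last context releases its reference.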