Low level new loading system (tags/v0.4.2-preview)
@@ -13,65 +13,101 @@ namespace LLama.Native
/// RNG seed, -1 for random
/// </summary>
public int seed;
/// <summary>
/// text context
/// </summary>
public int n_ctx;
/// <summary>
/// prompt processing batch size
/// </summary>
public int n_batch;
/// <summary>
/// grouped-query attention (TEMP - will be moved to model hparams)
/// </summary>
public int n_gqa;
/// <summary>
/// rms norm epsilon (TEMP - will be moved to model hparams)
/// </summary>
public float rms_norm_eps;
/// <summary>
/// number of layers to store in VRAM
/// </summary>
public int n_gpu_layers;
/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
public int main_gpu;
/// <summary>
/// how to split layers across multiple GPUs
/// </summary>
public TensorSplits tensor_split;
public float[] tensor_split;
/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE base frequency
/// </summary>
public float rope_freq_base;
/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE frequency scaling factor
/// </summary>
public float rope_freq_scale;
/// <summary>
/// called with a progress value between 0 and 1, pass NULL to disable
/// </summary>
public IntPtr progress_callback;
/// <summary>
/// context pointer passed to the progress callback
/// </summary>
public IntPtr progress_callback_user_data;
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool low_vram;
/// <summary>
/// use fp16 for KV cache
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool f16_kv;
/// <summary>
/// the llama_eval() call computes all logits, not just the last one
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool logits_all;
/// <summary>
/// only load the vocabulary, no weights
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool vocab_only;
/// <summary>
/// use mmap if possible
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mmap;
/// <summary>
/// force system to keep model in RAM
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mlock;
/// <summary>
/// embedding mode only
/// </summary>
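For orientation, here is a minimal sketch of how calling code might populate the updated struct, including the new float[] tensor_split field. The field names come from the struct above; the specific values, and the omission of the remaining fields, are illustrative assumptions rather than part of this change.

// Illustrative only: values are placeholders, not defaults defined by this diff.
var lparams = new LLamaContextParams
{
    seed = -1,                     // -1 => random seed
    n_ctx = 2048,                  // context length in tokens
    n_batch = 512,                 // prompt processing batch size
    n_gpu_layers = 0,              // keep every layer on the CPU
    main_gpu = 0,
    tensor_split = new[] { 1.0f }, // single entry; multi-GPU splits are rejected elsewhere in this diff
    f16_kv = true,
    use_mmap = true,
    use_mlock = false,
    logits_all = false,
    vocab_only = false,
    low_vram = false,
};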
@@ -1,6 +1,4 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -29,7 +27,7 @@ namespace LLama.Native
}
private const string libraryName = "libllama";
[DllImport("libllama", EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
[DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_empty_call();
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -56,7 +54,10 @@ namespace LLama.Native
/// <param name="params_"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_init_from_file(string path_model, LLamaContextParams params_);
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_);
/// <summary>
/// not great API - very likely to change.
@@ -65,6 +66,7 @@ namespace LLama.Native
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_backend_init(bool numa);
/// <summary>
/// Frees all allocated memory
/// </summary>
@@ -72,6 +74,13 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free(IntPtr ctx);
/// <summary>
/// Frees all allocated memory associated with a model
/// </summary>
/// <param name="model"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free_model(IntPtr model);
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
@@ -79,13 +88,13 @@ namespace LLama.Native
/// The model needs to be reloaded before applying a new adapter, otherwise the adapter
/// will be applied on top of the previous one
/// </summary>
/// <param name="ctx"></param>
/// <param name="model_ptr"></param>
/// <param name="path_lora"></param>
/// <param name="path_base_model"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_apply_lora_from_file(SafeLLamaContextHandle ctx, string path_lora, string path_base_model, int n_threads);
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);
/// <summary>
/// Returns the number of tokens in the KV cache
@@ -294,5 +303,20 @@ namespace LLama.Native
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_print_system_info();
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
}
}
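Taken together, these imports split the old single llama_init_from_file call into separate model and context lifetimes. A minimal sketch of that flow, using the safe wrappers defined later in this change (variable names and the model path are placeholders, and lparams is an LLamaContextParams populated as shown earlier):

// Sketch only: errors surface as RuntimeError from the wrappers below.
var model = SafeLlamaModelHandle.LoadFromFile("ggml-model.bin", lparams); // llama_load_model_from_file
var ctx = SafeLLamaContextHandle.Create(model, lparams);                  // llama_new_context_with_model
// ... evaluate tokens against ctx ...
ctx.Dispose();   // llama_free; also releases the reference the context holds on the model
model.Dispose(); // llama_free_model once nothing else references the weights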
@@ -1,26 +1,61 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
namespace LLama.Native
{
public class SafeLLamaContextHandle: SafeLLamaHandleBase
/// <summary>
/// A safe wrapper around a llama_context
/// </summary>
public class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
protected SafeLLamaContextHandle()
{
}
/// <summary>
/// This field guarantees that a reference to the model is held for as long as this handle is held
/// </summary>
private SafeLlamaModelHandle? _model;
public SafeLLamaContextHandle(IntPtr handle)
/// <summary>
/// Create a new SafeLLamaContextHandle
/// </summary>
/// <param name="handle">pointer to an allocated llama_context</param>
/// <param name="model">the model which this context was created from</param>
public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
: base(handle)
{
// Increment the model reference count while this context exists
_model = model;
var success = false;
_model.DangerousAddRef(ref success);
if (!success)
throw new RuntimeError("Failed to increment model refcount");
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
// Decrement refcount on model
_model?.DangerousRelease();
_model = null;
NativeApi.llama_free(handle);
SetHandle(IntPtr.Zero);
return true;
}
/// <summary>
/// Create a new llama_state for the given model
/// </summary>
/// <param name="model"></param>
/// <param name="lparams"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
{
var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
if (ctx_ptr == IntPtr.Zero)
throw new RuntimeError("Failed to create context from model");
return new(ctx_ptr, model);
}
}
}
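One consequence of the DangerousAddRef call in the constructor is that the model weights cannot be freed out from under a live context. A small sketch of that behaviour (the model path and params are placeholders):

// Sketch: disposing the model handle first does not free the native weights,
// because the context still holds a reference count on them.
var model = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);
model.Dispose(); // release is deferred; ctx keeps the weights alive
// ... ctx is still safe to use here ...
ctx.Dispose();   // llama_free runs, the model refcount drops, then llama_free_model runs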
@@ -1,11 +1,13 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
namespace LLama.Native
{
public abstract class SafeLLamaHandleBase: SafeHandle
/// <summary>
/// Base class for all llama handles to native resources
/// </summary>
public abstract class SafeLLamaHandleBase
: SafeHandle
{
private protected SafeLLamaHandleBase()
: base(IntPtr.Zero, ownsHandle: true)
@@ -24,8 +26,10 @@ namespace LLama.Native
SetHandle(handle);
}
/// <inheritdoc />
public override bool IsInvalid => handle == IntPtr.Zero;
/// <inheritdoc />
public override string ToString()
=> $"0x{handle.ToString("x16")}";
}
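For reference, a derived handle only needs to override ReleaseHandle to free its native resource exactly once; the sketch below shows the shape the concrete handles in this change follow. SafeFooHandle and llama_free_foo are hypothetical names, not part of the API.

// Hypothetical example of the pattern; llama_free_foo does not exist in NativeApi.
public sealed class SafeFooHandle : SafeLLamaHandleBase
{
    public SafeFooHandle(IntPtr handle)
        : base(handle)
    {
    }

    protected override bool ReleaseHandle()
    {
        NativeApi.llama_free_foo(handle); // hypothetical native free call
        SetHandle(IntPtr.Zero);
        return true;
    }
}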
@@ -0,0 +1,161 @@
using System;
using System.Text;
using LLama.Exceptions;
namespace LLama.Native
{
/// <summary>
/// A reference to a set of llama model weights
/// </summary>
public class SafeLlamaModelHandle
: SafeLLamaHandleBase
{
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
public int VocabCount { get; set; }
/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize { get; set; }
/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingCount { get; set; }
internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
{
VocabCount = NativeApi.llama_n_vocab_from_model(this);
ContextSize = NativeApi.llama_n_ctx_from_model(this);
EmbeddingCount = NativeApi.llama_n_embd_from_model(this);
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_free_model(handle);
SetHandle(IntPtr.Zero);
return true;
}
/// <summary>
/// Load a model from the given file path into memory
/// </summary>
/// <param name="modelPath"></param>
/// <param name="lparams"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
{
var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
if (model_ptr == IntPtr.Zero)
throw new RuntimeError($"Failed to load model {modelPath}.");
return new SafeLlamaModelHandle(model_ptr);
}
#region LoRA
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// </summary>
/// <param name="lora"></param>
/// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
/// adapter. Can be NULL to use the current loaded model.</param>
/// <param name="threads"></param>
/// <exception cref="RuntimeError"></exception>
public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
{
var err = NativeApi.llama_model_apply_lora_from_file(
this,
lora,
string.IsNullOrEmpty(modelBase) ? null : modelBase,
threads
);
if (err != 0)
throw new RuntimeError("Failed to apply lora adapter.");
}
#endregion
#region tokenize
/// <summary>
/// Convert a single llama token into string bytes
/// </summary>
/// <param name="llama_token"></param>
/// <returns></returns>
public ReadOnlySpan<byte> TokenToSpan(int llama_token)
{
unsafe
{
var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
var terminator = bytes.IndexOf((byte)0);
return bytes.Slice(0, terminator);
}
}
/// <summary>
/// Convert a single llama token into a string
/// </summary>
/// <param name="llama_token"></param>
/// <param name="encoding">Encoding to use to decode the bytes into a string</param>
/// <returns></returns>
public string TokenToString(int llama_token, Encoding encoding)
{
var span = TokenToSpan(llama_token);
if (span.Length == 0)
return "";
unsafe
{
fixed (byte* ptr = &span[0])
{
return encoding.GetString(ptr, span.Length);
}
}
}
/// <summary>
/// Convert a string of text into tokens
/// </summary>
/// <param name="text"></param>
/// <param name="add_bos"></param>
/// <param name="encoding"></param>
/// <returns></returns>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
var bytes = new byte[bytesCount + 1];
unsafe
{
fixed (char* charPtr = text)
fixed (byte* bytePtr = &bytes[0])
{
encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
}
}
unsafe
{
fixed (byte* bytesPtr = &bytes[0])
{
// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new int[count];
fixed (int* tokensPtr = &tokens[0])
{
count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
return tokens;
}
}
}
}
#endregion
}
}
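To see the tokenizer helpers in action, here is a minimal round-trip sketch built only on the members above (the model path, prompt, and encoding choice are placeholders, and the BOS token may not decode to visible text):

// Sketch: tokenize a prompt and decode it back to a string.
var model = SafeLlamaModelHandle.LoadFromFile("ggml-model.bin", lparams);
int[] tokens = model.Tokenize("Hello, world!", add_bos: true, Encoding.UTF8);

var text = new StringBuilder();
foreach (var token in tokens)
    text.Append(model.TokenToString(token, Encoding.UTF8));
// text now approximately reconstructs the prompt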
@@ -31,24 +31,12 @@ namespace LLama.OldVersion
throw new FileNotFoundException($"The model file does not exist: {@params.model}");
}
var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);
if (ctx_ptr == IntPtr.Zero)
{
throw new RuntimeError($"Failed to load model {@params.model}.");
}
SafeLLamaContextHandle ctx = new(ctx_ptr);
var model = SafeLlamaModelHandle.LoadFromFile(@params.model, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);
if (!string.IsNullOrEmpty(@params.lora_adapter))
{
int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
if (err != 0)
{
throw new RuntimeError("Failed to apply lora adapter.");
}
}
model.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
return ctx;
}
@@ -28,40 +28,25 @@ namespace LLama
lparams.logits_all = @params.Perplexity;
lparams.embedding = @params.EmbeddingMode;
lparams.low_vram = @params.LowVram;
if(@params.TensorSplits.Length != 1)
if (@params.TensorSplits.Length != 1)
{
throw new ArgumentException("Currently multi-gpu support is not supported by " +
"both llama.cpp and LLamaSharp.");
}
lparams.tensor_split = new TensorSplits()
{
Item1 = @params.TensorSplits[0]
};
lparams.tensor_split = @params.TensorSplits;
if (!File.Exists(@params.ModelPath))
{
throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
}
var ctx_ptr = NativeApi.llama_init_from_file(@params.ModelPath, lparams);
if (ctx_ptr == IntPtr.Zero)
{
throw new RuntimeError($"Failed to load model {@params.ModelPath}.");
}
SafeLLamaContextHandle ctx = new(ctx_ptr);
var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);
if (!string.IsNullOrEmpty(@params.LoraAdapter))
{
int err = NativeApi.llama_apply_lora_from_file(ctx, @params.LoraAdapter,
string.IsNullOrEmpty(@params.LoraBase) ? null : @params.LoraBase, @params.Threads);
if (err != 0)
{
throw new RuntimeError("Failed to apply lora adapter.");
}
}
model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
return ctx;
}
@@ -78,7 +63,7 @@ namespace LLama
return res.Take(n);
}
public unsafe static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
{
var logits = NativeApi.llama_get_logits(ctx);
return new Span<float>(logits, length);
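Finally, a short sketch of reading logits through the helper above; it assumes the enclosing helper class is LLama.Utils (the class name is not visible in this hunk) and that the context has already evaluated at least one token:

// Sketch: one float per vocabulary entry for the last evaluated token.
int n_vocab = model.VocabCount;                      // from SafeLlamaModelHandle
Span<float> logits = Utils.GetLogits(ctx, n_vocab);  // assumed class name
float bestScore = float.NegativeInfinity;
int bestToken = 0;
for (int i = 0; i < logits.Length; i++)
{
    if (logits[i] > bestScore)
    {
        bestScore = logits[i];
        bestToken = i;                               // greedy argmax over the vocabulary
    }
}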