diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 1684e501..2846b2d3 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -1,7 +1,5 @@ using System; -using System.Buffers; using System.Runtime.InteropServices; -using System.Text; #pragma warning disable IDE1006 // Naming Styles @@ -77,22 +75,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern void llama_backend_init(bool numa); - /// - /// Apply a LoRA adapter to a loaded model - /// path_base_model is the path to a higher quality model to use as a base for - /// the layers modified by the adapter. Can be NULL to use the current loaded model. - /// The model needs to be reloaded before applying a new adapter, otherwise the adapter - /// will be applied on top of the previous one - /// - /// - /// - /// - /// - /// - /// Returns 0 on success - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads); - /// /// Sets the current rng seed. /// @@ -166,50 +148,6 @@ namespace LLama.Native [Obsolete("use llama_decode() instead")] public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past); - /// - /// Convert the provided text into tokens. - /// - /// - /// - /// - /// - /// - /// - /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. - /// Returns the number of tokens on success, no more than n_max_tokens. 
- /// Returns a negative number on failure - the number of tokens that would have been returned - /// - public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, LLamaToken[] tokens, int n_max_tokens, bool add_bos, bool special) - { - unsafe - { - // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) - var byteCount = encoding.GetByteCount(text); - var array = ArrayPool.Shared.Rent(byteCount + 1); - try - { - // Convert to bytes - fixed (char* textPtr = text) - fixed (byte* arrayPtr = array) - { - encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length); - } - - // Add a zero byte to the end to terminate the string - array[byteCount] = 0; - - // Do the actual tokenization - fixed (byte* arrayPtr = array) - fixed (LLamaToken* tokensPtr = tokens) - return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special); - } - finally - { - ArrayPool.Shared.Return(array); - } - } - } - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token); @@ -349,125 +287,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern IntPtr llama_print_system_info(); - /// - /// Get the number of tokens in the model vocabulary - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_n_vocab(SafeLlamaModelHandle model); - - /// - /// Get the size of the context window for the model - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_n_ctx_train(SafeLlamaModelHandle model); - - /// - /// Get the dimension of embedding vectors from this model - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int 
llama_n_embd(SafeLlamaModelHandle model); - - /// - /// Get the model's RoPE frequency scaling factor - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model); - - /// - /// Get metadata value as a string by key name - /// - /// - /// - /// - /// - /// The length of the string on success, or -1 on failure - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size); - - /// - /// Get the number of metadata key/value pairs - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_meta_count(SafeLlamaModelHandle model); - - /// - /// Get metadata key name by index - /// - /// Model to fetch from - /// Index of key to fetch - /// buffer to write result into - /// The length of the string on success (even if the buffer is too small). -1 is the key does not exist. - public static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span dest) - { - unsafe - { - fixed (byte* destPtr = dest) - { - return llama_model_meta_key_by_index_native(model, index, destPtr, dest.Length); - } - } - - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")] - static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); - } - - /// - /// Get metadata value as a string by index - /// - /// Model to fetch from - /// Index of val to fetch - /// Buffer to write result into - /// The length of the string on success (even if the buffer is too small). -1 is the key does not exist. 
- public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span dest) - { - unsafe - { - fixed (byte* destPtr = dest) - { - return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length); - } - } - - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")] - static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); - } - - /// - /// Get a string describing the model type - /// - /// - /// - /// - /// The length of the string on success, or -1 on failure - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size); - - /// - /// Get the size of the model in bytes - /// - /// - /// The size of the model - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern ulong llama_model_size(SafeLlamaModelHandle model); - - /// - /// Get the number of parameters in this model - /// - /// - /// The functions return the length of the string on success, or -1 on failure - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern ulong llama_model_n_params(SafeLlamaModelHandle model); - /// /// Convert a single token into text /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index ae0c82e9..182ccb9b 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -19,33 +19,59 @@ namespace LLama.Native /// /// Total number of tokens in vocabulary of this model /// - public int VocabCount => NativeApi.llama_n_vocab(this); + public int VocabCount => llama_n_vocab(this); /// /// Total number of tokens in the context /// - public int ContextSize => NativeApi.llama_n_ctx_train(this); + public int ContextSize => llama_n_ctx_train(this); + 
+ /// + /// Get the RoPE frequency scaling factor this model was trained with + /// + public float RopeFrequency => llama_rope_freq_scale_train(this); /// /// Dimension of embedding vectors /// - public int EmbeddingSize => NativeApi.llama_n_embd(this); + public int EmbeddingSize => llama_n_embd(this); /// /// Get the size of this model in bytes /// - public ulong SizeInBytes => NativeApi.llama_model_size(this); + public ulong SizeInBytes => llama_model_size(this); /// /// Get the number of parameters in this model /// - public ulong ParameterCount => NativeApi.llama_model_n_params(this); + public ulong ParameterCount => llama_model_n_params(this); + + /// + /// Get a description of this model + /// + public string Description + { + get + { + unsafe + { + // Get description length + var size = llama_model_desc(this, null, 0); + var buf = new byte[size + 1]; + fixed (byte* bufPtr = buf) + { + size = llama_model_desc(this, bufPtr, buf.Length); + return Encoding.UTF8.GetString(buf, 0, size); + } + } + } + } /// /// Get the number of metadata key/value pairs /// /// - public int MetadataCount => NativeApi.llama_model_meta_count(this); + public int MetadataCount => llama_model_meta_count(this); /// protected override bool ReleaseHandle() @@ -86,16 +112,150 @@ namespace LLama.Native [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params); + /// + /// Apply a LoRA adapter to a loaded model + /// path_base_model is the path to a higher quality model to use as a base for + /// the layers modified by the adapter. Can be NULL to use the current loaded model. 
+ /// The model needs to be reloaded before applying a new adapter, otherwise the adapter + /// will be applied on top of the previous one + /// + /// + /// + /// + /// + /// + /// Returns 0 on success + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads); + /// /// Frees all allocated memory associated with a model /// /// [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern void llama_free_model(IntPtr model); + + /// + /// Get the number of metadata key/value pairs + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_model_meta_count(SafeLlamaModelHandle model); + + /// + /// Get metadata key name by index + /// + /// Model to fetch from + /// Index of key to fetch + /// buffer to write result into + /// The length of the string on success (even if the buffer is too small). -1 if the key does not exist. + private static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span dest) + { + unsafe + { + fixed (byte* destPtr = dest) + { + return llama_model_meta_key_by_index_native(model, index, destPtr, dest.Length); + } + } + + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")] + static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + } + + /// + /// Get metadata value as a string by index + /// + /// Model to fetch from + /// Index of val to fetch + /// Buffer to write result into + /// The length of the string on success (even if the buffer is too small). -1 if the key does not exist. 
+ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span dest) + { + unsafe + { + fixed (byte* destPtr = dest) + { + return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length); + } + } + + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")] + static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + } + + /// + /// Get metadata value as a string by key name + /// + /// + /// + /// + /// + /// The length of the string on success, or -1 on failure + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size); + + /// + /// Get the number of tokens in the model vocabulary + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_n_vocab(SafeLlamaModelHandle model); + + /// + /// Get the size of the context window for the model + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_n_ctx_train(SafeLlamaModelHandle model); + + /// + /// Get the dimension of embedding vectors from this model + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_n_embd(SafeLlamaModelHandle model); + + /// + /// Get a string describing the model type + /// + /// + /// + /// + /// The length of the string on success (even if the buffer is too small), or -1 on failure + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size); + + /// + /// Get the size of the model in bytes 
+ /// + /// + /// The size of the model + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern ulong llama_model_size(SafeLlamaModelHandle model); + + /// + /// Get the number of parameters in this model + /// + /// + /// The number of parameters in the model + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern ulong llama_model_n_params(SafeLlamaModelHandle model); + + /// + /// Get the model's RoPE frequency scaling factor + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model); #endregion #region LoRA - /// /// Apply a LoRA adapter to a loaded model /// @@ -107,7 +267,7 @@ namespace LLama.Native /// public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null) { - var err = NativeApi.llama_model_apply_lora_from_file( + var err = llama_model_apply_lora_from_file( this, lora, scale, @@ -231,13 +391,13 @@ namespace LLama.Native public Memory? MetadataKeyByIndex(int index) { // Check if the key exists, without getting any bytes of data - var keyLength = NativeApi.llama_model_meta_key_by_index(this, index, Array.Empty()); + var keyLength = llama_model_meta_key_by_index(this, index, Array.Empty()); if (keyLength < 0) return null; // get a buffer large enough to hold it var buffer = new byte[keyLength + 1]; - keyLength = NativeApi.llama_model_meta_key_by_index(this, index, buffer); + keyLength = llama_model_meta_key_by_index(this, index, buffer); Debug.Assert(keyLength >= 0); return buffer.AsMemory().Slice(0, keyLength); @@ -251,13 +411,13 @@ namespace LLama.Native public Memory? 
MetadataValueByIndex(int index) { // Check if the key exists, without getting any bytes of data - var valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, Array.Empty()); + var valueLength = llama_model_meta_val_str_by_index(this, index, Array.Empty()); if (valueLength < 0) return null; // get a buffer large enough to hold it var buffer = new byte[valueLength + 1]; - valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, buffer); + valueLength = llama_model_meta_val_str_by_index(this, index, buffer); Debug.Assert(valueLength >= 0); return buffer.AsMemory().Slice(0, valueLength);