From d7f971fc2257e53fe1e9d791e4f528938a9562f6 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 12 Aug 2023 00:45:23 +0100 Subject: [PATCH] Improved `NativeApi` file a bit: - Added some more comments - Modified `llama_tokenize` to not allocate - Modified `llama_tokenize_native` to take a pointer instead of an array, allowing use with no allocations - Removed GgmlInitParams (not used) --- LLama/GlobalSuppressions.cs | 8 +++ LLama/Native/GgmlInitParams.cs | 15 ---- LLama/Native/NativeApi.cs | 126 +++++++++++++++++++++++++++------ 3 files changed, 114 insertions(+), 35 deletions(-) create mode 100644 LLama/GlobalSuppressions.cs delete mode 100644 LLama/Native/GgmlInitParams.cs diff --git a/LLama/GlobalSuppressions.cs b/LLama/GlobalSuppressions.cs new file mode 100644 index 00000000..4d4915ff --- /dev/null +++ b/LLama/GlobalSuppressions.cs @@ -0,0 +1,8 @@ +// This file is used by Code Analysis to maintain SuppressMessage +// attributes that are applied to this project. +// Project-level suppressions either have no target or are given +// a specific target and scoped to a namespace, type, member, etc. + +using System.Diagnostics.CodeAnalysis; + +[assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] diff --git a/LLama/Native/GgmlInitParams.cs b/LLama/Native/GgmlInitParams.cs deleted file mode 100644 index 834ceab9..00000000 --- a/LLama/Native/GgmlInitParams.cs +++ /dev/null @@ -1,15 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; - -namespace LLama.Native -{ - internal struct GgmlInitParams - { - public ulong mem_size; - public IntPtr mem_buffer; - [MarshalAs(UnmanagedType.I1)] - public bool no_alloc; - } -} diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 810e0e47..466e5263 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -1,9 +1,12 @@ using System; +using System.Buffers; using System.Runtime.InteropServices; using System.Text; using LLama.Common; using LLama.Exceptions; +#pragma warning disable IDE1006 // Naming Styles + namespace LLama.Native { using llama_token = Int32; @@ -13,6 +16,7 @@ namespace LLama.Native public unsafe partial class NativeApi { public static readonly int LLAMA_MAX_DEVICES = 1; + static NativeApi() { try @@ -32,6 +36,10 @@ namespace LLama.Native } private const string libraryName = "libllama"; + /// + /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. + /// + /// [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_empty_call(); @@ -59,10 +67,10 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_); + public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params); [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_); + public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); /// /// not great API - very likely to change. @@ -227,9 +235,6 @@ namespace LLama.Native /// /// Convert the provided text into tokens. - /// The tokens pointer must be large enough to hold the resulting tokens. - /// Returns the number of tokens on success, no more than n_max_tokens - /// Returns a negative number on failure - the number of tokens that would have been returned /// /// /// @@ -237,35 +242,72 @@ namespace LLama.Native /// /// /// - /// + /// Returns the number of tokens on success, no more than n_max_tokens. + /// Returns a negative number on failure - the number of tokens that would have been returned + /// public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) { - var bytes = encoding.GetBytes(text); - sbyte[] data = new sbyte[bytes.Length]; - for(int i = 0; i < bytes.Length; i++) + // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) + var byteCount = encoding.GetByteCount(text); + var array = ArrayPool.Shared.Rent(byteCount + 1); + try + { + // Convert to bytes + fixed (char* textPtr = text) + fixed (byte* arrayPtr = array) + { + encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length); + } + + // Add a zero byte to the end to terminate the string + array[byteCount] = 0; + + // Do the actual tokenization + fixed (byte* arrayPtr = array) + fixed (llama_token* tokensPtr = tokens) + return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos); + } + finally { - data[i] = (sbyte)bytes[i]; - //if (bytes[i] < 128) - //{ - // data[i] = (sbyte)bytes[i]; - //} - //else - //{ - // data[i] = (sbyte)(~((sbyte)(~bytes[i] + 1)) + 1); - //} + ArrayPool.Shared.Return(array); } - return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos); } + /// + /// Convert the provided text into tokens. + /// + /// + /// + /// + /// + /// + /// Returns the number of tokens on success, no more than n_max_tokens. + /// Returns a negative number on failure - the number of tokens that would have been returned + /// [DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos); + public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos); + /// + /// Get the number of tokens in the model vocabulary for this context + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); + /// + /// Get the size of the context window for the model for this context + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); + /// + /// Get the dimension of embedding vectors from the model for this context + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_embd(SafeLLamaContextHandle ctx); @@ -308,9 +350,17 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern llama_token llama_token_nl(); + /// + /// Print out timing information for this context + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_print_timings(SafeLLamaContextHandle ctx); + /// + /// Reset all collected timing information for this context + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_reset_timings(SafeLLamaContextHandle ctx); @@ -321,21 +371,57 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern IntPtr llama_print_system_info(); + /// + /// Get the number of tokens in the model vocabulary + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model); + /// + /// Get the size of the context window for the model + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model); + /// + /// Get the dimension of embedding vectors from this model + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model); + /// + /// Convert a single token into text + /// + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken); + /// + /// Convert text into tokens + /// + /// + /// + /// + /// + /// + /// Returns the number of tokens on success, no more than n_max_tokens. + /// Returns a negative number on failure - the number of tokens that would have been returned + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos); + /// + /// Register a callback to receive llama log messages + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_log_set(LLamaLogCallback logCallback); }