diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 1684e501..2846b2d3 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -1,7 +1,5 @@
using System;
-using System.Buffers;
using System.Runtime.InteropServices;
-using System.Text;
#pragma warning disable IDE1006 // Naming Styles
@@ -77,22 +75,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_backend_init(bool numa);
- /// <summary>
- /// Apply a LoRA adapter to a loaded model
- /// path_base_model is the path to a higher quality model to use as a base for
- /// the layers modified by the adapter. Can be NULL to use the current loaded model.
- /// The model needs to be reloaded before applying a new adapter, otherwise the adapter
- /// will be applied on top of the previous one
- /// </summary>
- /// <param name="model_ptr"></param>
- /// <param name="path_lora"></param>
- /// <param name="scale"></param>
- /// <param name="path_base_model"></param>
- /// <param name="n_threads"></param>
- /// <returns>Returns 0 on success</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);
-
///
/// Sets the current rng seed.
///
@@ -166,50 +148,6 @@ namespace LLama.Native
[Obsolete("use llama_decode() instead")]
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past);
- /// <summary>
- /// Convert the provided text into tokens.
- /// </summary>
- /// <param name="ctx"></param>
- /// <param name="text"></param>
- /// <param name="encoding"></param>
- /// <param name="tokens"></param>
- /// <param name="n_max_tokens"></param>
- /// <param name="add_bos"></param>
- /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
- /// <returns>Returns the number of tokens on success, no more than n_max_tokens.
- /// Returns a negative number on failure - the number of tokens that would have been returned
- /// </returns>
- public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, LLamaToken[] tokens, int n_max_tokens, bool add_bos, bool special)
- {
- unsafe
- {
- // Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
- var byteCount = encoding.GetByteCount(text);
- var array = ArrayPool<byte>.Shared.Rent(byteCount + 1);
- try
- {
- // Convert to bytes
- fixed (char* textPtr = text)
- fixed (byte* arrayPtr = array)
- {
- encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
- }
-
- // Add a zero byte to the end to terminate the string
- array[byteCount] = 0;
-
- // Do the actual tokenization
- fixed (byte* arrayPtr = array)
- fixed (LLamaToken* tokensPtr = tokens)
- return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
- }
- finally
- {
- ArrayPool<byte>.Shared.Return(array);
- }
- }
- }
-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);
@@ -349,125 +287,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_print_system_info();
- /// <summary>
- /// Get the number of tokens in the model vocabulary
- /// </summary>
- /// <param name="model"></param>
- /// <returns></returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_n_vocab(SafeLlamaModelHandle model);
-
- /// <summary>
- /// Get the size of the context window for the model
- /// </summary>
- /// <param name="model"></param>
- /// <returns></returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_n_ctx_train(SafeLlamaModelHandle model);
-
- /// <summary>
- /// Get the dimension of embedding vectors from this model
- /// </summary>
- /// <param name="model"></param>
- /// <returns></returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_n_embd(SafeLlamaModelHandle model);
-
- /// <summary>
- /// Get the model's RoPE frequency scaling factor
- /// </summary>
- /// <param name="model"></param>
- /// <returns></returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model);
-
- /// <summary>
- /// Get metadata value as a string by key name
- /// </summary>
- /// <param name="model"></param>
- /// <param name="key"></param>
- /// <param name="buf"></param>
- /// <param name="buf_size"></param>
- /// <returns>The length of the string on success, or -1 on failure</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
-
- /// <summary>
- /// Get the number of metadata key/value pairs
- /// </summary>
- /// <param name="model"></param>
- /// <returns></returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_meta_count(SafeLlamaModelHandle model);
-
- /// <summary>
- /// Get metadata key name by index
- /// </summary>
- /// <param name="model">Model to fetch from</param>
- /// <param name="index">Index of key to fetch</param>
- /// <param name="dest">buffer to write result into</param>
- /// <returns>The length of the string on success (even if the buffer is too small). -1 is the key does not exist.</returns>
- public static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
- {
- unsafe
- {
- fixed (byte* destPtr = dest)
- {
- return llama_model_meta_key_by_index_native(model, index, destPtr, dest.Length);
- }
- }
-
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")]
- static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
- }
-
- /// <summary>
- /// Get metadata value as a string by index
- /// </summary>
- /// <param name="model">Model to fetch from</param>
- /// <param name="index">Index of val to fetch</param>
- /// <param name="dest">Buffer to write result into</param>
- /// <returns>The length of the string on success (even if the buffer is too small). -1 is the key does not exist.</returns>
- public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
- {
- unsafe
- {
- fixed (byte* destPtr = dest)
- {
- return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length);
- }
- }
-
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")]
- static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
- }
-
- /// <summary>
- /// Get a string describing the model type
- /// </summary>
- /// <param name="model"></param>
- /// <param name="buf"></param>
- /// <param name="buf_size"></param>
- /// <returns>The length of the string on success, or -1 on failure</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
-
- /// <summary>
- /// Get the size of the model in bytes
- /// </summary>
- /// <param name="model"></param>
- /// <returns>The size of the model</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern ulong llama_model_size(SafeLlamaModelHandle model);
-
- /// <summary>
- /// Get the number of parameters in this model
- /// </summary>
- /// <param name="model"></param>
- /// <returns>The functions return the length of the string on success, or -1 on failure</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern ulong llama_model_n_params(SafeLlamaModelHandle model);
-
///
/// Convert a single token into text
///
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index ae0c82e9..182ccb9b 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -19,33 +19,59 @@ namespace LLama.Native
///
/// Total number of tokens in vocabulary of this model
///
- public int VocabCount => NativeApi.llama_n_vocab(this);
+ public int VocabCount => llama_n_vocab(this);
///
/// Total number of tokens in the context
///
- public int ContextSize => NativeApi.llama_n_ctx_train(this);
+ public int ContextSize => llama_n_ctx_train(this);
+
+ /// <summary>
+ /// Get the rope frequency this model was trained with
+ /// </summary>
+ public float RopeFrequency => llama_rope_freq_scale_train(this);
///
/// Dimension of embedding vectors
///
- public int EmbeddingSize => NativeApi.llama_n_embd(this);
+ public int EmbeddingSize => llama_n_embd(this);
///
/// Get the size of this model in bytes
///
- public ulong SizeInBytes => NativeApi.llama_model_size(this);
+ public ulong SizeInBytes => llama_model_size(this);
///
/// Get the number of parameters in this model
///
- public ulong ParameterCount => NativeApi.llama_model_n_params(this);
+ public ulong ParameterCount => llama_model_n_params(this);
+
+ /// <summary>
+ /// Get a description of this model
+ /// </summary>
+ public string Description
+ {
+ get
+ {
+ unsafe
+ {
+ // Get description length
+ var size = llama_model_desc(this, null, 0);
+ var buf = new byte[size + 1];
+ fixed (byte* bufPtr = buf)
+ {
+ size = llama_model_desc(this, bufPtr, buf.Length);
+ return Encoding.UTF8.GetString(buf, 0, size);
+ }
+ }
+ }
+ }
///
/// Get the number of metadata key/value pairs
///
///
- public int MetadataCount => NativeApi.llama_model_meta_count(this);
+ public int MetadataCount => llama_model_meta_count(this);
///
protected override bool ReleaseHandle()
@@ -86,16 +112,150 @@ namespace LLama.Native
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
+ /// <summary>
+ /// Apply a LoRA adapter to a loaded model
+ /// path_base_model is the path to a higher quality model to use as a base for
+ /// the layers modified by the adapter. Can be NULL to use the current loaded model.
+ /// The model needs to be reloaded before applying a new adapter, otherwise the adapter
+ /// will be applied on top of the previous one
+ /// </summary>
+ /// <param name="model_ptr"></param>
+ /// <param name="path_lora"></param>
+ /// <param name="scale"></param>
+ /// <param name="path_base_model"></param>
+ /// <param name="n_threads"></param>
+ /// <returns>Returns 0 on success</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);
+
///
/// Frees all allocated memory associated with a model
///
///
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_free_model(IntPtr model);
+
+ /// <summary>
+ /// Get the number of metadata key/value pairs
+ /// </summary>
+ /// <param name="model"></param>
+ /// <returns></returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_model_meta_count(SafeLlamaModelHandle model);
+
+ /// <summary>
+ /// Get metadata key name by index
+ /// </summary>
+ /// <param name="model">Model to fetch from</param>
+ /// <param name="index">Index of key to fetch</param>
+ /// <param name="dest">buffer to write result into</param>
+ /// <returns>The length of the string on success (even if the buffer is too small). -1 is the key does not exist.</returns>
+ private static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
+ {
+ unsafe
+ {
+ fixed (byte* destPtr = dest)
+ {
+ return llama_model_meta_key_by_index_native(model, index, destPtr, dest.Length);
+ }
+ }
+
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")]
+ static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+ }
+
+ /// <summary>
+ /// Get metadata value as a string by index
+ /// </summary>
+ /// <param name="model">Model to fetch from</param>
+ /// <param name="index">Index of val to fetch</param>
+ /// <param name="dest">Buffer to write result into</param>
+ /// <returns>The length of the string on success (even if the buffer is too small). -1 is the key does not exist.</returns>
+ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
+ {
+ unsafe
+ {
+ fixed (byte* destPtr = dest)
+ {
+ return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length);
+ }
+ }
+
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")]
+ static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+ }
+
+ /// <summary>
+ /// Get metadata value as a string by key name
+ /// </summary>
+ /// <param name="model"></param>
+ /// <param name="key"></param>
+ /// <param name="buf"></param>
+ /// <param name="buf_size"></param>
+ /// <returns>The length of the string on success, or -1 on failure</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
+
+ /// <summary>
+ /// Get the number of tokens in the model vocabulary
+ /// </summary>
+ /// <param name="model"></param>
+ /// <returns></returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_n_vocab(SafeLlamaModelHandle model);
+
+ /// <summary>
+ /// Get the size of the context window for the model
+ /// </summary>
+ /// <param name="model"></param>
+ /// <returns></returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_n_ctx_train(SafeLlamaModelHandle model);
+
+ /// <summary>
+ /// Get the dimension of embedding vectors from this model
+ /// </summary>
+ /// <param name="model"></param>
+ /// <returns></returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_n_embd(SafeLlamaModelHandle model);
+
+ /// <summary>
+ /// Get a string describing the model type
+ /// </summary>
+ /// <param name="model"></param>
+ /// <param name="buf"></param>
+ /// <param name="buf_size"></param>
+ /// <returns>The length of the string on success (even if the buffer is too small)., or -1 on failure</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
+
+ /// <summary>
+ /// Get the size of the model in bytes
+ /// </summary>
+ /// <param name="model"></param>
+ /// <returns>The size of the model</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern ulong llama_model_size(SafeLlamaModelHandle model);
+
+ /// <summary>
+ /// Get the number of parameters in this model
+ /// </summary>
+ /// <param name="model"></param>
+ /// <returns>The functions return the length of the string on success, or -1 on failure</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern ulong llama_model_n_params(SafeLlamaModelHandle model);
+
+ /// <summary>
+ /// Get the model's RoPE frequency scaling factor
+ /// </summary>
+ /// <param name="model"></param>
+ /// <returns></returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model);
#endregion
#region LoRA
-
///
/// Apply a LoRA adapter to a loaded model
///
@@ -107,7 +267,7 @@ namespace LLama.Native
///
public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
{
- var err = NativeApi.llama_model_apply_lora_from_file(
+ var err = llama_model_apply_lora_from_file(
this,
lora,
scale,
@@ -231,13 +391,13 @@ namespace LLama.Native
public Memory<byte>? MetadataKeyByIndex(int index)
{
// Check if the key exists, without getting any bytes of data
- var keyLength = NativeApi.llama_model_meta_key_by_index(this, index, Array.Empty<byte>());
+ var keyLength = llama_model_meta_key_by_index(this, index, Array.Empty<byte>());
if (keyLength < 0)
return null;
// get a buffer large enough to hold it
var buffer = new byte[keyLength + 1];
- keyLength = NativeApi.llama_model_meta_key_by_index(this, index, buffer);
+ keyLength = llama_model_meta_key_by_index(this, index, buffer);
Debug.Assert(keyLength >= 0);
return buffer.AsMemory().Slice(0, keyLength);
@@ -251,13 +411,13 @@ namespace LLama.Native
public Memory<byte>? MetadataValueByIndex(int index)
{
// Check if the key exists, without getting any bytes of data
- var valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, Array.Empty<byte>());
+ var valueLength = llama_model_meta_val_str_by_index(this, index, Array.Empty<byte>());
if (valueLength < 0)
return null;
// get a buffer large enough to hold it
var buffer = new byte[valueLength + 1];
- valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, buffer);
+ valueLength = llama_model_meta_val_str_by_index(this, index, buffer);
Debug.Assert(valueLength >= 0);
return buffer.AsMemory().Slice(0, valueLength);