Browse Source

Moved some native methods into `SafeLlamaModelHandle`. These methods are all wrapped in safer accessors with no extra cost, so there is no need to expose them.

tags/v0.10.0
Martin Evans 1 year ago
parent
commit
ce1d302e7e
2 changed files with 172 additions and 193 deletions
  1. +0
    -181
      LLama/Native/NativeApi.cs
  2. +172
    -12
      LLama/Native/SafeLlamaModelHandle.cs

+ 0
- 181
LLama/Native/NativeApi.cs View File

@@ -1,7 +1,5 @@
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Text;

#pragma warning disable IDE1006 // Naming Styles

@@ -77,22 +75,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_backend_init(bool numa);

/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
/// the layers modified by the adapter. Can be NULL to use the current loaded model.
/// The model needs to be reloaded before applying a new adapter, otherwise the adapter
/// will be applied on top of the previous one
/// </summary>
/// <param name="model_ptr"></param>
/// <param name="path_lora"></param>
/// <param name="scale"></param>
/// <param name="path_base_model"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);

/// <summary>
/// Sets the current rng seed.
/// </summary>
@@ -166,50 +148,6 @@ namespace LLama.Native
[Obsolete("use llama_decode() instead")]
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past);

/// <summary>
/// Convert the provided text into tokens.
/// </summary>
/// <param name="ctx"></param>
/// <param name="text"></param>
/// <param name="encoding"></param>
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
/// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, LLamaToken[] tokens, int n_max_tokens, bool add_bos, bool special)
{
    unsafe
    {
        // Calculate number of bytes in text and borrow an array that large (+1 for nul byte).
        // Rented arrays may be larger than requested, so byteCount (not array.Length) is the
        // authoritative length passed to the native tokenizer below.
        var byteCount = encoding.GetByteCount(text);
        var array = ArrayPool<byte>.Shared.Rent(byteCount + 1);
        try
        {
            // Convert to bytes
            fixed (char* textPtr = text)
            fixed (byte* arrayPtr = array)
            {
                encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
            }

            // Add a zero byte to the end to terminate the string
            array[byteCount] = 0;

            // Do the actual tokenization. The native overload takes the model handle, so
            // tokenization goes through ctx.ModelHandle rather than the context itself.
            // Returns the token count on success, or a negative count on failure (see XML doc).
            fixed (byte* arrayPtr = array)
            fixed (LLamaToken* tokensPtr = tokens)
                return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
        }
        finally
        {
            // Always return the rented buffer, even if the native call throws
            ArrayPool<byte>.Shared.Return(array);
        }
    }
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);

@@ -349,125 +287,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_print_system_info();

/// <summary>
/// Get the number of tokens in the model vocabulary
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_vocab(SafeLlamaModelHandle model);

/// <summary>
/// Get the size of the context window for the model
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_ctx_train(SafeLlamaModelHandle model);

/// <summary>
/// Get the dimension of embedding vectors from this model
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_embd(SafeLlamaModelHandle model);

/// <summary>
/// Get the model's RoPE frequency scaling factor
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model);

/// <summary>
/// Get metadata value as a string by key name
/// </summary>
/// <param name="model"></param>
/// <param name="key"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);

/// <summary>
/// Get the number of metadata key/value pairs
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_meta_count(SafeLlamaModelHandle model);

/// <summary>
/// Get metadata key name by index
/// </summary>
/// <param name="model">Model to fetch from</param>
/// <param name="index">Index of key to fetch</param>
/// <param name="dest">buffer to write result into</param>
/// <returns>The length of the string on success (even if the buffer is too small). -1 if the key does not exist.</returns>
public static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
{
    unsafe
    {
        // Pin the destination span and forward its address to the native entry point.
        fixed (byte* bufferPtr = dest)
            return llama_model_meta_key_by_index_native(model, index, bufferPtr, dest.Length);
    }

    // Raw native declaration is kept local so external callers must use the span-based wrapper.
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")]
    static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
}

/// <summary>
/// Get metadata value as a string by index
/// </summary>
/// <param name="model">Model to fetch from</param>
/// <param name="index">Index of val to fetch</param>
/// <param name="dest">Buffer to write result into</param>
/// <returns>The length of the string on success (even if the buffer is too small). -1 if the key does not exist.</returns>
public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
{
    unsafe
    {
        // Pin the destination span and forward its address to the native entry point.
        fixed (byte* bufferPtr = dest)
            return llama_model_meta_val_str_by_index_native(model, index, bufferPtr, dest.Length);
    }

    // Raw native declaration is kept local so external callers must use the span-based wrapper.
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")]
    static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
}

/// <summary>
/// Get a string describing the model type
/// </summary>
/// <param name="model"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);

/// <summary>
/// Get the size of the model in bytes
/// </summary>
/// <param name="model"></param>
/// <returns>The size of the model</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_model_size(SafeLlamaModelHandle model);

/// <summary>
/// Get the number of parameters in this model
/// </summary>
/// <param name="model"></param>
/// <returns>The number of parameters in the model</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_model_n_params(SafeLlamaModelHandle model);

/// <summary>
/// Convert a single token into text
/// </summary>


+ 172
- 12
LLama/Native/SafeLlamaModelHandle.cs View File

@@ -19,33 +19,59 @@ namespace LLama.Native
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
public int VocabCount => NativeApi.llama_n_vocab(this);
public int VocabCount => llama_n_vocab(this);

/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize => NativeApi.llama_n_ctx_train(this);
public int ContextSize => llama_n_ctx_train(this);

/// <summary>
/// Get the rope frequency this model was trained with
/// </summary>
public float RopeFrequency => llama_rope_freq_scale_train(this);

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => NativeApi.llama_n_embd(this);
public int EmbeddingSize => llama_n_embd(this);

/// <summary>
/// Get the size of this model in bytes
/// </summary>
public ulong SizeInBytes => NativeApi.llama_model_size(this);
public ulong SizeInBytes => llama_model_size(this);

/// <summary>
/// Get the number of parameters in this model
/// </summary>
public ulong ParameterCount => NativeApi.llama_model_n_params(this);
public ulong ParameterCount => llama_model_n_params(this);

/// <summary>
/// Get a description of this model
/// </summary>
public string Description
{
    get
    {
        unsafe
        {
            // First call with a null buffer asks for the required length.
            // llama_model_desc returns -1 on failure (see its XML doc); guard
            // against that so we never allocate a zero/negative sized buffer.
            var size = llama_model_desc(this, null, 0);
            if (size < 0)
                return string.Empty;

            // +1 so the native side has room for its nul terminator
            var buf = new byte[size + 1];
            fixed (byte* bufPtr = buf)
            {
                // Second call actually writes the description; it can also fail with -1,
                // which would otherwise make GetString throw on a negative count.
                size = llama_model_desc(this, bufPtr, buf.Length);
                return size < 0 ? string.Empty : Encoding.UTF8.GetString(buf, 0, size);
            }
        }
    }
}

/// <summary>
/// Get the number of metadata key/value pairs
/// </summary>
/// <returns></returns>
public int MetadataCount => NativeApi.llama_model_meta_count(this);
public int MetadataCount => llama_model_meta_count(this);

/// <inheritdoc />
protected override bool ReleaseHandle()
@@ -86,16 +112,150 @@ namespace LLama.Native
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);

/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
/// the layers modified by the adapter. Can be NULL to use the current loaded model.
/// The model needs to be reloaded before applying a new adapter, otherwise the adapter
/// will be applied on top of the previous one
/// </summary>
/// <param name="model_ptr"></param>
/// <param name="path_lora"></param>
/// <param name="scale"></param>
/// <param name="path_base_model"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);

/// <summary>
/// Frees all allocated memory associated with a model
/// </summary>
/// <param name="model"></param>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_free_model(IntPtr model);

/// <summary>
/// Get the number of metadata key/value pairs
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_model_meta_count(SafeLlamaModelHandle model);

/// <summary>
/// Get metadata key name by index
/// </summary>
/// <param name="model">Model to fetch from</param>
/// <param name="index">Index of key to fetch</param>
/// <param name="dest">buffer to write result into</param>
/// <returns>The length of the string on success (even if the buffer is too small). -1 if the key does not exist.</returns>
private static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
{
    unsafe
    {
        // Pin the destination span for the duration of the native call.
        fixed (byte* bufferPtr = dest)
            return llama_model_meta_key_by_index_native(model, index, bufferPtr, dest.Length);
    }

    // Raw native declaration stays local; callers go through the span-based wrapper above.
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")]
    static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
}

/// <summary>
/// Get metadata value as a string by index
/// </summary>
/// <param name="model">Model to fetch from</param>
/// <param name="index">Index of val to fetch</param>
/// <param name="dest">Buffer to write result into</param>
/// <returns>The length of the string on success (even if the buffer is too small). -1 if the key does not exist.</returns>
private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
{
    unsafe
    {
        // Pin the destination span for the duration of the native call.
        fixed (byte* bufferPtr = dest)
            return llama_model_meta_val_str_by_index_native(model, index, bufferPtr, dest.Length);
    }

    // Raw native declaration stays local; callers go through the span-based wrapper above.
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")]
    static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
}

/// <summary>
/// Get metadata value as a string by key name
/// </summary>
/// <param name="model"></param>
/// <param name="key"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);

/// <summary>
/// Get the number of tokens in the model vocabulary
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_n_vocab(SafeLlamaModelHandle model);

/// <summary>
/// Get the size of the context window for the model
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_n_ctx_train(SafeLlamaModelHandle model);

/// <summary>
/// Get the dimension of embedding vectors from this model
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_n_embd(SafeLlamaModelHandle model);

/// <summary>
/// Get a string describing the model type
/// </summary>
/// <param name="model"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns>The length of the string on success (even if the buffer is too small), or -1 on failure</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);

/// <summary>
/// Get the size of the model in bytes
/// </summary>
/// <param name="model"></param>
/// <returns>The size of the model</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern ulong llama_model_size(SafeLlamaModelHandle model);

/// <summary>
/// Get the number of parameters in this model
/// </summary>
/// <param name="model"></param>
/// <returns>The number of parameters in the model</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern ulong llama_model_n_params(SafeLlamaModelHandle model);

/// <summary>
/// Get the model's RoPE frequency scaling factor
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model);
#endregion

#region LoRA

/// <summary>
/// Apply a LoRA adapter to a loaded model
/// </summary>
@@ -107,7 +267,7 @@ namespace LLama.Native
/// <exception cref="RuntimeError"></exception>
public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
{
var err = NativeApi.llama_model_apply_lora_from_file(
var err = llama_model_apply_lora_from_file(
this,
lora,
scale,
@@ -231,13 +391,13 @@ namespace LLama.Native
public Memory<byte>? MetadataKeyByIndex(int index)
{
// Check if the key exists, without getting any bytes of data
var keyLength = NativeApi.llama_model_meta_key_by_index(this, index, Array.Empty<byte>());
var keyLength = llama_model_meta_key_by_index(this, index, Array.Empty<byte>());
if (keyLength < 0)
return null;

// get a buffer large enough to hold it
var buffer = new byte[keyLength + 1];
keyLength = NativeApi.llama_model_meta_key_by_index(this, index, buffer);
keyLength = llama_model_meta_key_by_index(this, index, buffer);
Debug.Assert(keyLength >= 0);

return buffer.AsMemory().Slice(0, keyLength);
@@ -251,13 +411,13 @@ namespace LLama.Native
public Memory<byte>? MetadataValueByIndex(int index)
{
// Check if the key exists, without getting any bytes of data
var valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, Array.Empty<byte>());
var valueLength = llama_model_meta_val_str_by_index(this, index, Array.Empty<byte>());
if (valueLength < 0)
return null;

// get a buffer large enough to hold it
var buffer = new byte[valueLength + 1];
valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, buffer);
valueLength = llama_model_meta_val_str_by_index(this, index, buffer);
Debug.Assert(valueLength >= 0);

return buffer.AsMemory().Slice(0, valueLength);


Loading…
Cancel
Save