@@ -80,7 +80,16 @@ namespace LLama.Native
/// Call once at the start of the program
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_backend_init();
// Note: this is not implemented because we don't have a definition for `ggml_numa_strategy` in C#. That definition
// doesn't exist because it's not in llama.h but in ggml.h, which we don't currently build a wrapper for. If there's
// demand for better NUMA support, that will need adding.
///// <summary>
///// Optional, enable NUMA optimisations
///// </summary>
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//public static extern void llama_numa_init(ggml_numa_strategy numa);
/// <summary>
/// Sets the current RNG seed.
@@ -187,6 +196,13 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
/// <summary>
/// Get the embeddings for the ith sequence. Equivalent to: llama_get_embeddings(ctx) + i*n_embd
/// </summary>
/// <returns>A pointer to the embeddings for the ith sequence</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
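// Illustrative sketch: copy the ith embedding vector into a managed array. This assumes
// the llama_n_embd binding (declared elsewhere in this class) reports the embedding
// size; the helper name is hypothetical.
public static unsafe float[] GetEmbeddingsCopy(SafeLLamaContextHandle ctx, SafeLlamaModelHandle model, int i)
{
    var embd = new float[llama_n_embd(model)];
    float* src = llama_get_embeddings_ith(ctx, i);
    if (src == null)
        throw new System.InvalidOperationException("No embeddings available for this context");
    new System.Span<float>(src, embd.Length).CopyTo(embd);
    return embd;
}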
/// <summary>
/// Get the embeddings for the input
/// </summary>
@@ -204,6 +220,22 @@ namespace LLama.Native
static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
}
/// <summary>
/// Apply chat template. Inspired by hf apply_chat_template() in Python.
/// Both "model" and "tmpl" are optional, but at least one is required; "tmpl" takes precedence over "model".
/// NOTE: This function does not use a Jinja parser. It only supports a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
/// </summary>
/// <param name="model"></param>
/// <param name="tmpl">A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.</param>
/// <param name="chat">Pointer to a list of multiple llama_chat_message</param>
/// <param name="n_msg">Number of llama_chat_message in this chat</param>
/// <param name="add_ass">Whether to end the prompt with the token(s) that indicate the start of an assistant message.</param>
/// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
/// <param name="length">The size of the allocated buffer</param>
/// <returns>The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-allocate it and then re-apply the template.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nint n_msg, bool add_ass, char* buf, int length);
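// Illustrative sketch of the grow-and-retry pattern the return value implies: format
// into a guessed buffer and, if the reported size is larger, resize and re-apply. This
// assumes the buffer holds UTF-8 bytes and that a negative result signals failure; the
// helper name is hypothetical.
public static unsafe string ApplyChatTemplate(SafeLlamaModelHandle model, LLamaChatMessage* chat, nint n_msg, bool add_ass)
{
    var buf = new byte[1024];
    while (true)
    {
        int written;
        fixed (byte* bufPtr = buf)
            written = llama_chat_apply_template(model, null, chat, n_msg, add_ass, (char*)bufPtr, buf.Length);
        if (written < 0)
            throw new System.InvalidOperationException("Failed to apply chat template");
        if (written <= buf.Length)
            return System.Text.Encoding.UTF8.GetString(buf, 0, written);
        // Buffer was too small: grow to the reported size and re-apply.
        buf = new byte[written];
    }
}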
/// <summary>
/// Get the "Beginning of sentence" token
/// </summary>
@@ -371,7 +403,9 @@ namespace LLama.Native
/// <summary>
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
/// If the KV cache is RoPEd, the KV data is updated accordingly:
/// - lazily on next llama_decode()
/// - explicitly with llama_kv_cache_update()
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
@@ -379,12 +413,16 @@ namespace LLama.Native
/// <param name="p1"></param>
/// <param name="delta"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_add(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);
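// Illustrative sketch: drop the first `n` cached tokens of a sequence and shift the
// survivors back so positions stay contiguous. This assumes the llama_kv_cache_seq_rm
// binding declared elsewhere in this class and that LLamaPos/LLamaSeqId convert from
// int; the helper name and n_past parameter are hypothetical.
public static void DiscardFirstTokens(SafeLLamaContextHandle ctx, LLamaSeqId seq, int n, int n_past)
{
    llama_kv_cache_seq_rm(ctx, seq, (LLamaPos)0, (LLamaPos)n);           // remove positions [0, n)
    llama_kv_cache_seq_add(ctx, seq, (LLamaPos)n, (LLamaPos)n_past, -n); // shift [n, n_past) back by n
}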
/// <summary>
/// Integer division of the positions by a factor of `d > 1`
/// If the KV cache is RoPEd, the KV data is updated accordingly:
/// - lazily on next llama_decode()
/// - explicitly with llama_kv_cache_update()
/// <br />
/// p0 < 0 : [0, p1]
/// <br />
/// p1 < 0 : [p0, inf)
/// </summary>
/// <param name="ctx"></param>
@@ -395,6 +433,32 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_div(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int d);
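// Illustrative sketch: divide every cached position of a sequence by 2. Negative bounds
// select the whole range, per the notes above; the change is applied lazily on the next
// llama_decode() or immediately via llama_kv_cache_update(). Assumes LLamaPos converts
// from int; the helper name is hypothetical.
public static void HalvePositions(SafeLLamaContextHandle ctx, LLamaSeqId seq)
{
    llama_kv_cache_seq_div(ctx, seq, (LLamaPos)(-1), (LLamaPos)(-1), 2);
}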
/// <summary>
/// Returns the largest position present in the KV cache for the specified sequence
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
/// <returns>The largest position seen for the given sequence</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaPos llama_kv_cache_seq_pos_max(SafeLLamaContextHandle ctx, LLamaSeqId seq);
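// Illustrative sketch: the next position to decode at for a sequence is one past the
// largest cached position. Assumes LLamaPos converts to int; the helper name is
// hypothetical.
public static int NextDecodePosition(SafeLLamaContextHandle ctx, LLamaSeqId seq)
{
    return (int)llama_kv_cache_seq_pos_max(ctx, seq) + 1;
}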
/// <summary>
/// Defragment the KV cache. This will be applied:
/// - lazily on next llama_decode()
/// - explicitly with llama_kv_cache_update()
/// </summary>
/// <param name="ctx"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_defrag(SafeLLamaContextHandle ctx);
/// <summary>
/// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
/// </summary>
/// <param name="ctx"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx);
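// Illustrative sketch: schedule a defragmentation and apply it immediately rather than
// waiting for the next llama_decode(). The helper name is hypothetical.
public static void DefragNow(SafeLLamaContextHandle ctx)
{
    llama_kv_cache_defrag(ctx);  // queue the defragmentation
    llama_kv_cache_update(ctx);  // apply pending K-shifts / defragmentation now
}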
/// <summary>
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
@@ -438,5 +502,11 @@ namespace LLama.Native
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
/// <summary>
/// Get the type of vocabulary the model uses
/// </summary>
/// <param name="model"></param>
/// <returns>The vocabulary type</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaVocabType llama_vocab_type(SafeLlamaModelHandle model);
/// <summary>
/// Get the RoPE type used by the model
/// </summary>
/// <param name="model"></param>
/// <returns>The RoPE type</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaRopeType llama_rope_type(SafeLlamaModelHandle model);
}
}