diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 3d0e2cab..42f2be3f 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -13,65 +13,101 @@ namespace LLama.Native
/// <summary>
/// RNG seed, -1 for random
/// </summary>
public int seed;
+
/// <summary>
/// text context
/// </summary>
public int n_ctx;
+
/// <summary>
/// prompt processing batch size
/// </summary>
public int n_batch;
+
+ /// <summary>
+ /// grouped-query attention (TEMP - will be moved to model hparams)
+ /// </summary>
+ public int n_gqa;
+
+ /// <summary>
+ /// rms norm epsilon (TEMP - will be moved to model hparams)
+ /// </summary>
+ public float rms_norm_eps;
+
/// <summary>
/// number of layers to store in VRAM
/// </summary>
public int n_gpu_layers;
+
/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
public int main_gpu;
+
/// <summary>
/// how to split layers across multiple GPUs
/// </summary>
- public TensorSplits tensor_split;
+ public float[] tensor_split;
+
+ /// <summary>
+ /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
+ /// RoPE base frequency
+ /// </summary>
+ public float rope_freq_base;
+
+ /// <summary>
+ /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
+ /// RoPE frequency scaling factor
+ /// </summary>
+ public float rope_freq_scale;
+
/// <summary>
/// called with a progress value between 0 and 1, pass NULL to disable
/// </summary>
public IntPtr progress_callback;
+
/// <summary>
/// context pointer passed to the progress callback
/// </summary>
public IntPtr progress_callback_user_data;
+
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool low_vram;
+
/// <summary>
/// use fp16 for KV cache
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool f16_kv;
+
/// <summary>
/// the llama_eval() call computes all logits, not just the last one
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool logits_all;
+
/// <summary>
/// only load the vocabulary, no weights
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool vocab_only;
+
/// <summary>
/// use mmap if possible
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mmap;
+
/// <summary>
/// force system to keep model in RAM
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mlock;
+
/// <summary>
/// embedding mode only
/// </summary>
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index b4d23007..527bea52 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -1,6 +1,4 @@
using System;
-using System.Collections.Generic;
-using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -29,7 +27,7 @@ namespace LLama.Native
}
private const string libraryName = "libllama";
- [DllImport("libllama", EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
+ [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_empty_call();
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -56,7 +54,10 @@ namespace LLama.Native
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern IntPtr llama_init_from_file(string path_model, LLamaContextParams params_);
+ public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_);
+
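+ /// <summary>
+ /// Create a new llama_context with the given model. Returns NULL on failure
+ /// (the managed wrapper checks for IntPtr.Zero).
+ /// </summary>
+ /// <param name="model"></param>
+ /// <param name="params_"></param>
+ /// <returns></returns>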
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_);
/// <summary>
/// not great API - very likely to change.
@@ -65,6 +66,7 @@ namespace LLama.Native
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_backend_init(bool numa);
+
/// <summary>
/// Frees all allocated memory
/// </summary>
@@ -72,6 +74,13 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free(IntPtr ctx);
+ /// <summary>
+ /// Frees all allocated memory associated with a model
+ /// </summary>
+ /// <param name="model"></param>
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_free_model(IntPtr model);
+
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
@@ -79,13 +88,13 @@ namespace LLama.Native
/// The model needs to be reloaded before applying a new adapter, otherwise the adapter
/// will be applied on top of the previous one
/// </summary>
- /// <param name="ctx"></param>
+ /// <param name="model_ptr"></param>
/// <param name="path_lora"></param>
/// <param name="path_base_model"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_apply_lora_from_file(SafeLLamaContextHandle ctx, string path_lora, string path_base_model, int n_threads);
+ public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);
/// <summary>
/// Returns the number of tokens in the KV cache
@@ -294,5 +303,20 @@ namespace LLama.Native
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_print_system_info();
+
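+ /// <summary>
+ /// Get the number of tokens in the vocabulary of this model
+ /// </summary>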
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);
+
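+ /// <summary>
+ /// Get the size of the context window used by this model
+ /// </summary>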
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);
+
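+ /// <summary>
+ /// Get the dimension of embedding vectors produced by this model
+ /// </summary>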
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);
+
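+ /// <summary>
+ /// Convert a single token into a pointer to a null terminated string of bytes, owned by the model
+ /// </summary>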
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
+
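+ /// <summary>
+ /// Convert text into tokens, writing up to n_max_tokens into the given buffer.
+ /// Returns the number of tokens written, or a negative count of the tokens required if the buffer is too small.
+ /// </summary>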
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
}
}
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 5c26cb13..ab102228 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,26 +1,61 @@
using System;
-using System.Collections.Generic;
-using System.Runtime.InteropServices;
-using System.Text;
+using LLama.Exceptions;
namespace LLama.Native
{
- public class SafeLLamaContextHandle: SafeLLamaHandleBase
+ /// <summary>
+ /// A safe wrapper around a llama_context
+ /// </summary>
+ public class SafeLLamaContextHandle
+ : SafeLLamaHandleBase
{
- protected SafeLLamaContextHandle()
- {
- }
+ /// <summary>
+ /// This field guarantees that a reference to the model is held for as long as this handle is held
+ /// </summary>
+ private SafeLlamaModelHandle? _model;
- public SafeLLamaContextHandle(IntPtr handle)
+ /// <summary>
+ /// Create a new SafeLLamaContextHandle
+ /// </summary>
+ /// <param name="handle">pointer to an allocated llama_context</param>
+ /// <param name="model">the model which this context was created from</param>
+ public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
: base(handle)
{
+ // Increment the model reference count while this context exists
+ _model = model;
+ var success = false;
+ _model.DangerousAddRef(ref success);
+ if (!success)
+ throw new RuntimeError("Failed to increment model refcount");
}
+ ///
protected override bool ReleaseHandle()
{
+ // Decrement refcount on model
+ _model?.DangerousRelease();
+ _model = null;
+
NativeApi.llama_free(handle);
SetHandle(IntPtr.Zero);
return true;
}
+
+ /// <summary>
+ /// Create a new llama_state for the given model
+ /// </summary>
+ /// <param name="model"></param>
+ /// <param name="lparams"></param>
+ /// <returns></returns>
+ /// <exception cref="RuntimeError"></exception>
+ public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
+ {
+ var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
+ if (ctx_ptr == IntPtr.Zero)
+ throw new RuntimeError("Failed to create context from model");
+
+ return new(ctx_ptr, model);
+ }
}
}
diff --git a/LLama/Native/SafeLLamaHandleBase.cs b/LLama/Native/SafeLLamaHandleBase.cs
index 023f8cdd..6371b327 100644
--- a/LLama/Native/SafeLLamaHandleBase.cs
+++ b/LLama/Native/SafeLLamaHandleBase.cs
@@ -1,11 +1,13 @@
using System;
-using System.Collections.Generic;
using System.Runtime.InteropServices;
-using System.Text;
namespace LLama.Native
{
- public abstract class SafeLLamaHandleBase: SafeHandle
+ /// <summary>
+ /// Base class for all llama handles to native resources
+ /// </summary>
+ public abstract class SafeLLamaHandleBase
+ : SafeHandle
{
private protected SafeLLamaHandleBase()
: base(IntPtr.Zero, ownsHandle: true)
@@ -24,8 +26,10 @@ namespace LLama.Native
SetHandle(handle);
}
+ /// <inheritdoc />
public override bool IsInvalid => handle == IntPtr.Zero;
+ /// <inheritdoc />
public override string ToString()
=> $"0x{handle.ToString("x16")}";
}
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
new file mode 100644
index 00000000..79714fea
--- /dev/null
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -0,0 +1,161 @@
+using System;
+using System.Text;
+using LLama.Exceptions;
+
+namespace LLama.Native
+{
+ /// <summary>
+ /// A reference to a set of llama model weights
+ /// </summary>
+ public class SafeLlamaModelHandle
+ : SafeLLamaHandleBase
+ {
+ /// <summary>
+ /// Total number of tokens in vocabulary of this model
+ /// </summary>
+ public int VocabCount { get; set; }
+
+ /// <summary>
+ /// Total number of tokens in the context
+ /// </summary>
+ public int ContextSize { get; set; }
+
+ /// <summary>
+ /// Dimension of embedding vectors
+ /// </summary>
+ public int EmbeddingCount { get; set; }
+
+ internal SafeLlamaModelHandle(IntPtr handle)
+ : base(handle)
+ {
+ VocabCount = NativeApi.llama_n_vocab_from_model(this);
+ ContextSize = NativeApi.llama_n_ctx_from_model(this);
+ EmbeddingCount = NativeApi.llama_n_embd_from_model(this);
+ }
+
+ /// <inheritdoc />
+ protected override bool ReleaseHandle()
+ {
+ NativeApi.llama_free_model(handle);
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+
+ /// <summary>
+ /// Load a model from the given file path into memory
+ /// </summary>
+ /// <param name="modelPath"></param>
+ /// <param name="lparams"></param>
+ /// <returns></returns>
+ /// <exception cref="RuntimeError"></exception>
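+ /// <example>
+ /// A minimal usage sketch; the model path is a placeholder and <c>lparams</c> is assumed to be an already-populated <see cref="LLamaContextParams"/>:
+ /// <code>
+ /// using var model = SafeLlamaModelHandle.LoadFromFile("path/to/model.bin", lparams);
+ /// using var ctx = SafeLLamaContextHandle.Create(model, lparams);
+ /// </code>
+ /// </example>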
+ public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
+ {
+ var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
+ if (model_ptr == IntPtr.Zero)
+ throw new RuntimeError($"Failed to load model {modelPath}.");
+
+ return new SafeLlamaModelHandle(model_ptr);
+ }
+
+ #region LoRA
+ /// <summary>
+ /// Apply a LoRA adapter to a loaded model
+ /// </summary>
+ /// <param name="lora"></param>
+ /// <param name="modelBase">
+ /// A path to a higher quality model to use as a base for the layers modified by the
+ /// adapter. Can be NULL to use the current loaded model.
+ /// </param>
+ /// <param name="threads"></param>
+ /// <exception cref="RuntimeError"></exception>
+ public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
+ {
+ var err = NativeApi.llama_model_apply_lora_from_file(
+ this,
+ lora,
+ string.IsNullOrEmpty(modelBase) ? null : modelBase,
+ threads
+ );
+
+ if (err != 0)
+ throw new RuntimeError("Failed to apply lora adapter.");
+ }
+ #endregion
+
+ #region tokenize
+ /// <summary>
+ /// Convert a single llama token into string bytes
+ /// </summary>
+ /// <param name="llama_token"></param>
+ /// <returns></returns>
+ public ReadOnlySpan<byte> TokenToSpan(int llama_token)
+ {
+ unsafe
+ {
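+ // The native call returns a pointer to a null terminated buffer owned by the model,
+ // so wrap it in a maximum-length span and slice it down at the terminator below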
+ var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
+ var terminator = bytes.IndexOf((byte)0);
+ return bytes.Slice(0, terminator);
+ }
+ }
+
+ /// <summary>
+ /// Convert a single llama token into a string
+ /// </summary>
+ /// <param name="llama_token"></param>
+ /// <param name="encoding">Encoding to use to decode the bytes into a string</param>
+ /// <returns></returns>
+ public string TokenToString(int llama_token, Encoding encoding)
+ {
+ var span = TokenToSpan(llama_token);
+
+ if (span.Length == 0)
+ return "";
+
+ unsafe
+ {
+ fixed (byte* ptr = &span[0])
+ {
+ return encoding.GetString(ptr, span.Length);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Convert a string of text into tokens
+ /// </summary>
+ /// <param name="text"></param>
+ /// <param name="add_bos"></param>
+ /// <param name="encoding"></param>
+ /// <returns></returns>
+ public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+ {
+ // Convert string to bytes, adding one extra byte to the end (null terminator)
+ var bytesCount = encoding.GetByteCount(text);
+ var bytes = new byte[bytesCount + 1];
+ unsafe
+ {
+ fixed (char* charPtr = text)
+ fixed (byte* bytePtr = &bytes[0])
+ {
+ encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
+ }
+ }
+
+ unsafe
+ {
+ fixed (byte* bytesPtr = &bytes[0])
+ {
+ // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
+ var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
+
+ // Tokenize again, this time outputting into an array of exactly the right size
+ var tokens = new int[count];
+ fixed (int* tokensPtr = &tokens[0])
+ {
+ count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
+ return tokens;
+ }
+ }
+ }
+ }
+ #endregion
+ }
+}
diff --git a/LLama/OldVersion/Utils.cs b/LLama/OldVersion/Utils.cs
index 4916a20d..df8adddd 100644
--- a/LLama/OldVersion/Utils.cs
+++ b/LLama/OldVersion/Utils.cs
@@ -31,24 +31,12 @@ namespace LLama.OldVersion
throw new FileNotFoundException($"The model file does not exist: {@params.model}");
}
- var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);
-
- if (ctx_ptr == IntPtr.Zero)
- {
- throw new RuntimeError($"Failed to load model {@params.model}.");
- }
-
- SafeLLamaContextHandle ctx = new(ctx_ptr);
+ var model = SafeLlamaModelHandle.LoadFromFile(@params.model, lparams);
+ var ctx = SafeLLamaContextHandle.Create(model, lparams);
if (!string.IsNullOrEmpty(@params.lora_adapter))
- {
- int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
- string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
- if (err != 0)
- {
- throw new RuntimeError("Failed to apply lora adapter.");
- }
- }
+ model.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
+
return ctx;
}
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index b6f1b7b4..c08912cf 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -28,40 +28,25 @@ namespace LLama
lparams.logits_all = @params.Perplexity;
lparams.embedding = @params.EmbeddingMode;
lparams.low_vram = @params.LowVram;
-
- if(@params.TensorSplits.Length != 1)
+
+ if (@params.TensorSplits.Length != 1)
{
throw new ArgumentException("Currently multi-gpu support is not supported by " +
"both llama.cpp and LLamaSharp.");
}
- lparams.tensor_split = new TensorSplits()
- {
- Item1 = @params.TensorSplits[0]
- };
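+ // Only a single split value is accepted for now (checked above), so the array can be passed straight through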
+ lparams.tensor_split = @params.TensorSplits;
if (!File.Exists(@params.ModelPath))
{
throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
}
- var ctx_ptr = NativeApi.llama_init_from_file(@params.ModelPath, lparams);
-
- if (ctx_ptr == IntPtr.Zero)
- {
- throw new RuntimeError($"Failed to load model {@params.ModelPath}.");
- }
-
- SafeLLamaContextHandle ctx = new(ctx_ptr);
+ var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+ var ctx = SafeLLamaContextHandle.Create(model, lparams);
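+ // The context handle takes a reference on the model handle, so the model stays
+ // alive for as long as the returned context is alive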
if (!string.IsNullOrEmpty(@params.LoraAdapter))
- {
- int err = NativeApi.llama_apply_lora_from_file(ctx, @params.LoraAdapter,
- string.IsNullOrEmpty(@params.LoraBase) ? null : @params.LoraBase, @params.Threads);
- if (err != 0)
- {
- throw new RuntimeError("Failed to apply lora adapter.");
- }
- }
+ model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
+
return ctx;
}
@@ -78,7 +63,7 @@ namespace LLama
return res.Take(n);
}
- public unsafe static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
+ public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
{
var logits = NativeApi.llama_get_logits(ctx);
return new Span<float>(logits, length);