using System;
using System.Text;
using LLama.Exceptions;
namespace LLama.Native
{
/// <summary>
/// A reference to a set of llama model weights
/// </summary>
public sealed class SafeLlamaModelHandle
: SafeLLamaHandleBase
{

/// <summary>
/// Total number of tokens in the vocabulary of this model
/// </summary>
public int VocabCount { get; }

/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize { get; }

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize { get; }

internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
{
VocabCount = NativeApi.llama_n_vocab_from_model(this);
ContextSize = NativeApi.llama_n_ctx_from_model(this);
EmbeddingSize = NativeApi.llama_n_embd_from_model(this);
}

/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_free_model(handle);
SetHandle(IntPtr.Zero);
return true;
}

/// <summary>
/// Load a model from the given file path into memory
/// </summary>
/// <param name="modelPath">Path of the model file to load</param>
/// <param name="lparams">Parameters to use when loading the model</param>
/// <returns>A handle to the loaded model weights</returns>
/// <exception cref="RuntimeError">Thrown if the native call fails to load the model</exception>
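/// <example>
/// A minimal usage sketch (illustrative: the model path is a placeholder, and the use of
/// NativeApi.llama_context_default_params() to obtain default parameters is an assumption):
/// <code>
/// var lparams = NativeApi.llama_context_default_params();
/// using var model = SafeLlamaModelHandle.LoadFromFile("/path/to/model.bin", lparams);
/// Console.WriteLine($"vocab={model.VocabCount} ctx={model.ContextSize} embd={model.EmbeddingSize}");
/// </code>
/// </example>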
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
{
var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
if (model_ptr == IntPtr.Zero)
throw new RuntimeError($"Failed to load model {modelPath}.");
return new SafeLlamaModelHandle(model_ptr);
}
#region LoRA

/// <summary>
/// Apply a LoRA adapter to a loaded model
/// </summary>
/// <param name="lora">Path of the LoRA adapter file</param>
/// <param name="modelBase">
/// A path to a higher quality model to use as a base for the layers modified by the
/// adapter. Can be NULL to use the current loaded model.
/// </param>
/// <param name="threads">Number of threads to use, or -1 to let the native library decide</param>
/// <exception cref="RuntimeError">Thrown if the adapter fails to apply</exception>
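/// <example>
/// A usage sketch (illustrative; both file paths are placeholders):
/// <code>
/// // Patch the currently loaded weights directly
/// model.ApplyLoraFromFile("/path/to/adapter.bin");
/// // Or apply the adapter over a separate, higher quality base model using 4 threads
/// model.ApplyLoraFromFile("/path/to/adapter.bin", "/path/to/base-model.bin", threads: 4);
/// </code>
/// </example>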
public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
{
var err = NativeApi.llama_model_apply_lora_from_file(
this,
lora,
string.IsNullOrEmpty(modelBase) ? null : modelBase,
threads
);
if (err != 0)
throw new RuntimeError($"Failed to apply LoRA adapter '{lora}'.");
}
#endregion
#region tokenize

/// <summary>
/// Convert a single llama token into string bytes
/// </summary>
/// <param name="llama_token">Token to convert</param>
/// <returns>The raw bytes of the token text, without the trailing null terminator</returns>
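/// <example>
/// A usage sketch (illustrative; the token id is arbitrary):
/// <code>
/// var bytes = model.TokenToSpan(42);
/// Console.WriteLine($"token 42 decodes to {bytes.Length} bytes");
/// </code>
/// </example>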
public ReadOnlySpan<byte> TokenToSpan(int llama_token)
{
unsafe
{
// The native call returns a pointer to a null-terminated buffer. Wrap it in an
// oversized span, then cut the span down at the first null byte.
var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
var terminator = bytes.IndexOf((byte)0);
return bytes.Slice(0, terminator);
}
}

/// <summary>
/// Convert a single llama token into a string
/// </summary>
/// <param name="llama_token">Token to convert</param>
/// <param name="encoding">Encoding to use to decode the bytes into a string</param>
/// <returns>The decoded token text</returns>
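/// <example>
/// A usage sketch (illustrative; the token id is arbitrary):
/// <code>
/// var text = model.TokenToString(42, Encoding.UTF8);
/// </code>
/// </example>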
public string TokenToString(int llama_token, Encoding encoding)
{
var span = TokenToSpan(llama_token);
if (span.Length == 0)
return "";
unsafe
{
fixed (byte* ptr = &span[0])
{
return encoding.GetString(ptr, span.Length);
}
}
}

/// <summary>
/// Convert a string of text into tokens
/// </summary>
/// <param name="text">Text to tokenize</param>
/// <param name="add_bos">Whether to prepend the BOS (beginning of sequence) token</param>
/// <param name="encoding">Encoding to use to convert the text into bytes</param>
/// <returns>The tokens of the given text</returns>
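/// <example>
/// A usage sketch (illustrative) round-tripping text through the tokenizer and back via
/// TokenToString:
/// <code>
/// var tokens = model.Tokenize("Hello, world!", add_bos: true, Encoding.UTF8);
/// var sb = new StringBuilder();
/// foreach (var token in tokens)
///     sb.Append(model.TokenToString(token, Encoding.UTF8));
/// // sb now approximately reproduces the input, plus whatever the BOS token decodes to
/// </code>
/// </example>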
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
var bytes = new byte[bytesCount + 1];
unsafe
{
fixed (char* charPtr = text)
fixed (byte* bytePtr = &bytes[0])
{
encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
}
}
unsafe
{
fixed (byte* bytesPtr = &bytes[0])
{
// Tokenize once with no output buffer, to get the required token count. The return
// value will be negative, indicating that there was insufficient space.
var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
if (count == 0)
return Array.Empty<int>();
// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new int[count];
fixed (int* tokensPtr = &tokens[0])
{
NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
return tokens;
}
}
}
}
#endregion
#region context

/// <summary>
/// Create a new context for this model
/// </summary>
/// <param name="params">Parameters to use when creating the context</param>
/// <returns>A handle to the newly created context</returns>
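/// <example>
/// A usage sketch (illustrative; assumes NativeApi.llama_context_default_params() provides
/// usable defaults):
/// <code>
/// using var context = model.CreateContext(NativeApi.llama_context_default_params());
/// </code>
/// </example>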
public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
{
return SafeLLamaContextHandle.Create(this, @params);
}
#endregion
}
}