using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Exceptions;
using LLama.Extensions;
namespace LLama.Native
{
    /// <summary>
    /// A reference to a set of llama model weights
    /// </summary>
    public sealed class SafeLlamaModelHandle
        : SafeLLamaHandleBase
    {
        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        public int VocabCount { get; }

        /// <summary>
        /// Number of tokens in the context the model was trained with
        /// (the value reported by <c>llama_n_ctx_train</c>).
        /// </summary>
        public int ContextSize { get; }

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize { get; }

        /// <summary>
        /// Get the size of this model in bytes
        /// </summary>
        public ulong SizeInBytes { get; }

        /// <summary>
        /// Get the number of parameters in this model
        /// </summary>
        public ulong ParameterCount { get; }

        /// <summary>
        /// Get the number of metadata key/value pairs
        /// </summary>
        public int MetadataCount { get; }

        /// <summary>
        /// Wrap a native model pointer and cache its basic properties.
        /// </summary>
        /// <param name="handle">Native model pointer (<see cref="LoadFromFile"/> only passes non-zero pointers).</param>
        internal SafeLlamaModelHandle(IntPtr handle)
            : base(handle)
        {
            VocabCount = NativeApi.llama_n_vocab(this);
            ContextSize = NativeApi.llama_n_ctx_train(this);
            EmbeddingSize = NativeApi.llama_n_embd(this);
            SizeInBytes = NativeApi.llama_model_size(this);
            ParameterCount = NativeApi.llama_model_n_params(this);
            MetadataCount = NativeApi.llama_model_meta_count(this);
        }

        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            NativeApi.llama_free_model(DangerousGetHandle());
            SetHandle(IntPtr.Zero);
            return true;
        }

        /// <summary>
        /// Load a model from the given file path into memory
        /// </summary>
        /// <param name="modelPath">Path of the model file to load.</param>
        /// <param name="lparams">Parameters passed to the native model loader.</param>
        /// <returns>A handle that owns the loaded model.</returns>
        /// <exception cref="RuntimeError">Thrown when the native loader returns a null pointer.</exception>
        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
        {
            var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
            if (model_ptr == IntPtr.Zero)
                throw new RuntimeError($"Failed to load model {modelPath}.");

            return new SafeLlamaModelHandle(model_ptr);
        }

        #region LoRA
        /// <summary>
        /// Apply a LoRA adapter to a loaded model
        /// </summary>
        /// <param name="lora">Path of the LoRA adapter file.</param>
        /// <param name="scale">Scale passed to the native call for this adapter.</param>
        /// <param name="modelBase">
        /// A path to a higher quality model to use as a base for the layers modified by the
        /// adapter. Can be null or empty to use the current loaded model.
        /// </param>
        /// <param name="threads">Thread count; defaults to half the processor count (at least 1).</param>
        /// <exception cref="RuntimeError">Thrown when the native call returns a non-zero error code.</exception>
        public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
        {
            var err = NativeApi.llama_model_apply_lora_from_file(
                this,
                lora,
                scale,
                string.IsNullOrEmpty(modelBase) ? null : modelBase,
                threads ?? Math.Max(1, Environment.ProcessorCount / 2)
            );

            if (err != 0)
                throw new RuntimeError("Failed to apply lora adapter.");
        }
        #endregion

        #region tokenize
        /// <summary>
        /// Convert a single llama token into bytes
        /// </summary>
        /// <param name="llama_token">Token to decode</param>
        /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
        /// <returns>The size of this token. <b>Nothing will be written</b> if this is larger than <paramref name="dest"/></returns>
        public int TokenToSpan(int llama_token, Span<byte> dest)
        {
            unsafe
            {
                fixed (byte* destPtr = dest)
                {
                    // A negative result signals insufficient space; its magnitude is the required size,
                    // so Math.Abs yields the token's byte length in both cases.
                    var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
                    return Math.Abs(length);
                }
            }
        }

        /// <summary>
        /// Convert a sequence of tokens into characters.
        /// </summary>
        /// <param name="tokens">Tokens to decode.</param>
        /// <param name="dest">Destination span for the decoded characters.</param>
        /// <param name="encoding">Encoding used to turn token bytes into characters.</param>
        /// <returns>
        /// The section of the span which has valid data in it.
        /// If there was insufficient space in the output span this will be
        /// filled with as many characters as possible, starting from the _last_ token.
        /// </returns>
        [Obsolete("Use a StreamingTokenDecoder instead")]
        internal Span<char> TokensToSpan(IReadOnlyList<int> tokens, Span<char> dest, Encoding encoding)
        {
            var decoder = new StreamingTokenDecoder(encoding, this);
            foreach (var token in tokens)
                decoder.Add(token);

            var str = decoder.Read();

            if (str.Length < dest.Length)
            {
                str.AsSpan().CopyTo(dest);
                return dest.Slice(0, str.Length);
            }

            // Not enough room: keep the tail of the decoded text
            str.AsSpan().Slice(str.Length - dest.Length).CopyTo(dest);
            return dest;
        }

        /// <summary>
        /// Convert a string of text into tokens
        /// </summary>
        /// <param name="text">Text to tokenize.</param>
        /// <param name="add_bos">Passed to <c>llama_tokenize</c> as its add_bos flag (presumably prepends a beginning-of-sequence token).</param>
        /// <param name="special">
        /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
        /// </param>
        /// <param name="encoding">Encoding used to convert <paramref name="text"/> to bytes.</param>
        /// <returns>The tokens; an empty array when the tokenizer produces none.</returns>
        public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            // Convert string to bytes, adding one extra (zeroed) byte to the end as a null terminator
            var bytesCount = encoding.GetByteCount(text);
            var bytes = new byte[bytesCount + 1];

            unsafe
            {
                fixed (char* charPtr = text)
                fixed (byte* bytesPtr = bytes)
                {
                    encoding.GetBytes(charPtr, text.Length, bytesPtr, bytes.Length);

                    // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
                    var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special);

                    // No tokens at all (e.g. empty text without BOS): taking &tokens[0] of an empty array would throw
                    if (count == 0)
                        return Array.Empty<int>();

                    // Tokenize again, this time outputting into an array of exactly the right size
                    var tokens = new int[count];
                    fixed (int* tokensPtr = tokens)
                    {
                        NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
                    }

                    return tokens;
                }
            }
        }
        #endregion

        #region context
        /// <summary>
        /// Create a new context for this model
        /// </summary>
        /// <param name="params">Parameters for the new context.</param>
        /// <returns>A handle to the newly created context.</returns>
        public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
        {
            return SafeLLamaContextHandle.Create(this, @params);
        }
        #endregion

        #region metadata
        /// <summary>
        /// Get the metadata key for the given index
        /// </summary>
        /// <param name="index">The index to get</param>
        /// <returns>The raw key bytes (excluding the extra terminator byte), or null if there is no such key</returns>
        public Memory<byte>? MetadataKeyByIndex(int index)
        {
            int keyLength;
            unsafe
            {
                // Check if the key exists, without getting any bytes of data
                keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)0, 0);
                if (keyLength < 0)
                    return null;
            }

            // get a buffer large enough to hold it, plus one byte for the terminator
            var buffer = new byte[keyLength + 1];
            unsafe
            {
                using var pin = buffer.AsMemory().Pin();
                keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)pin.Pointer, buffer.Length);
                Debug.Assert(keyLength >= 0);

                // Debug.Assert is compiled out of release builds; don't let Slice throw on a negative length there
                if (keyLength < 0)
                    return null;

                return buffer.AsMemory().Slice(0, keyLength);
            }
        }

        /// <summary>
        /// Get the metadata value for the given index
        /// </summary>
        /// <param name="index">The index to get</param>
        /// <returns>The raw value bytes (excluding the extra terminator byte), or null if there is no such value</returns>
        public Memory<byte>? MetadataValueByIndex(int index)
        {
            int valueLength;
            unsafe
            {
                // Check if the value exists, without getting any bytes of data
                valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)0, 0);
                if (valueLength < 0)
                    return null;
            }

            // get a buffer large enough to hold it, plus one byte for the terminator
            var buffer = new byte[valueLength + 1];
            unsafe
            {
                using var pin = buffer.AsMemory().Pin();
                valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)pin.Pointer, buffer.Length);
                Debug.Assert(valueLength >= 0);

                // Debug.Assert is compiled out of release builds; don't let Slice throw on a negative length there
                if (valueLength < 0)
                    return null;

                return buffer.AsMemory().Slice(0, valueLength);
            }
        }

        /// <summary>
        /// Read every metadata key/value pair of this model, decoding both as UTF-8.
        /// Entries whose key or value cannot be read are skipped.
        /// </summary>
        /// <returns>A dictionary mapping metadata keys to their values.</returns>
        internal IReadOnlyDictionary<string, string> ReadMetadata()
        {
            var result = new Dictionary<string, string>();

            for (var i = 0; i < MetadataCount; i++)
            {
                var keyBytes = MetadataKeyByIndex(i);
                if (keyBytes == null)
                    continue;
                var key = Encoding.UTF8.GetStringFromSpan(keyBytes.Value.Span);

                var valBytes = MetadataValueByIndex(i);
                if (valBytes == null)
                    continue;
                var val = Encoding.UTF8.GetStringFromSpan(valBytes.Value.Span);

                result[key] = val;
            }

            return result;
        }
        #endregion
    }
}