using System;
using System.Buffers;
using System.Text;
using LLama.Exceptions;
namespace LLama.Native
{
/// <summary>
/// A safe wrapper around a llama_context
/// </summary>
public class SafeLLamaContextHandle
    : SafeLLamaHandleBase
{
    #region properties and fields
    /// <summary>
    /// Total number of tokens in vocabulary of this model
    /// </summary>
    public int VocabCount => ThrowIfDisposed().VocabCount;

    /// <summary>
    /// Total number of tokens in the context
    /// </summary>
    public int ContextSize => ThrowIfDisposed().ContextSize;

    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount;

    /// <summary>
    /// This field guarantees that a reference to the model is held for as long as this handle is held.
    /// Set to null once the handle has been released.
    /// </summary>
    private SafeLlamaModelHandle? _model;
    #endregion

    #region construction/destruction
    /// <summary>
    /// Create a new SafeLLamaContextHandle
    /// </summary>
    /// <param name="handle">pointer to an allocated llama_context</param>
    /// <param name="model">the model which this context was created from</param>
    /// <exception cref="RuntimeError">Thrown if the model reference count could not be incremented</exception>
    public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
        : base(handle)
    {
        // Increment the model reference count while this context exists. The native
        // llama_context keeps internal pointers into the model, so the model must
        // outlive this handle.
        var success = false;
        model.DangerousAddRef(ref success);
        if (!success)
            throw new RuntimeError("Failed to increment model refcount");

        // Only record the model once the AddRef has definitely succeeded, so that
        // ReleaseHandle never calls DangerousRelease without a matching AddRef.
        _model = model;
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        // Free the native context first, while the model it references is still
        // guaranteed alive, then drop our reference to the model.
        NativeApi.llama_free(handle);
        SetHandle(IntPtr.Zero);

        // Decrement refcount on model. If this was the last reference the model
        // itself is now free to be destroyed.
        _model?.DangerousRelease();
        _model = null;

        return true;
    }

    /// <summary>
    /// Get the model this context was created from, throwing if either the context
    /// or the model has already been disposed.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle or the model handle is closed</exception>
    private SafeLlamaModelHandle ThrowIfDisposed()
    {
        if (IsClosed)
            throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
        if (_model == null || _model.IsClosed)
            throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");
        return _model;
    }

    /// <summary>
    /// Create a new llama_context for the given model
    /// </summary>
    /// <param name="model">the model to create a context for</param>
    /// <param name="lparams">parameters controlling context creation</param>
    /// <returns>a handle owning the newly created native context</returns>
    /// <exception cref="RuntimeError">Thrown if the native call fails to create a context</exception>
    public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
    {
        var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
        if (ctx_ptr == IntPtr.Zero)
            throw new RuntimeError("Failed to create context from model");

        return new(ctx_ptr, model);
    }
    #endregion

    /// <summary>
    /// Convert the given text into tokens
    /// </summary>
    /// <param name="text">The text to tokenize</param>
    /// <param name="add_bos">Whether the "BOS" token should be added</param>
    /// <param name="encoding">Encoding to use for the text</param>
    /// <returns>The tokens for the given text</returns>
    /// <exception cref="RuntimeError">Thrown if the native tokenizer reports an error</exception>
    public int[] Tokenize(string text, bool add_bos, Encoding encoding)
    {
        ThrowIfDisposed();

        // Calculate number of bytes in string, this is a pessimistic estimate of token count
        // (every token consumes at least one byte). It can't possibly be more than this.
        var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);

        // "Rent" an array to write results into (avoiding an allocation of a large array)
        var temporaryArray = ArrayPool<int>.Shared.Rent(count);
        try
        {
            // Do the actual conversion
            var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
            if (n < 0)
            {
                throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                                       "specify the encoding.");
            }

            // Copy the results from the rented buffer into an array which is exactly the right size
            var result = new int[n];
            Array.Copy(temporaryArray, result, n);
            return result;
        }
        finally
        {
            ArrayPool<int>.Shared.Return(temporaryArray);
        }
    }

    /// <summary>
    /// Token logits obtained from the last call to llama_eval()
    /// The logits for the last token are stored in the last row
    /// Can be mutated in order to change the probabilities of the next token.
    /// Rows: n_tokens
    /// Cols: n_vocab
    /// </summary>
    /// <returns>A span over the native logits buffer (valid until the next eval)</returns>
    public Span<float> GetLogits()
    {
        var model = ThrowIfDisposed();

        unsafe
        {
            var logits = NativeApi.llama_get_logits(this);
            return new Span<float>(logits, model.VocabCount);
        }
    }

    /// <summary>
    /// Convert a token into a string
    /// </summary>
    /// <param name="token">the token to convert</param>
    /// <param name="encoding">encoding to use to decode the bytes of this token</param>
    /// <returns>the decoded text of this token</returns>
    public string TokenToString(int token, Encoding encoding)
    {
        return ThrowIfDisposed().TokenToString(token, encoding);
    }

    /// <summary>
    /// Convert a token into a span of bytes that could be decoded into a string
    /// </summary>
    /// <param name="token">the token to convert</param>
    /// <returns>the raw bytes of this token</returns>
    public ReadOnlySpan<byte> TokenToSpan(int token)
    {
        return ThrowIfDisposed().TokenToSpan(token);
    }

    /// <summary>
    /// Run the llama inference to obtain the logits and probabilities for the next token.
    /// </summary>
    /// <param name="tokens">The provided batch of new tokens to process</param>
    /// <param name="n_past">the number of tokens to use from previous eval calls</param>
    /// <param name="n_threads">number of threads to use for evaluation</param>
    /// <returns>Returns true on success</returns>
    public bool Eval(Memory<int> tokens, int n_past, int n_threads)
    {
        // Guard against use-after-free, consistent with every other public member.
        ThrowIfDisposed();

        // Pin the memory so the native code can safely read through the raw pointer.
        using var pin = tokens.Pin();
        unsafe
        {
            return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
        }
    }
}
}