using System;
using System.Buffers;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
    /// <summary>
    /// A safe wrapper around a llama_context
    /// </summary>
    public class SafeLLamaContextHandle
        : SafeLLamaHandleBase
    {
        #region properties and fields
        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        public int VocabCount => ThrowIfDisposed().VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public int ContextSize => ThrowIfDisposed().ContextSize;

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount;

        /// <summary>
        /// This field guarantees that a reference to the model is held for as long as this handle is held
        /// </summary>
        private SafeLlamaModelHandle? _model;
        #endregion

        #region construction/destruction
        /// <summary>
        /// Create a new SafeLLamaContextHandle
        /// </summary>
        /// <param name="handle">pointer to an allocated llama_context</param>
        /// <param name="model">the model which this context was created from</param>
        /// <exception cref="RuntimeError">Thrown if the model reference count could not be incremented</exception>
        public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
            : base(handle)
        {
            // Increment the model reference count while this context exists.
            // `_model` is only assigned once the AddRef has actually succeeded: if DangerousAddRef
            // throws (or reports failure), this handle can still be released later, and
            // ReleaseHandle must not call DangerousRelease for a reference that was never taken.
            var success = false;
            model.DangerousAddRef(ref success);
            if (!success)
                throw new RuntimeError("Failed to increment model refcount");

            _model = model;
        }

        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            // Free the native context first, while the model it was created from is still
            // kept alive by the reference taken in the constructor...
            NativeApi.llama_free(handle);
            SetHandle(IntPtr.Zero);

            // ...and only then drop that model reference (decrement refcount on model).
            _model?.DangerousRelease();
            _model = null;

            return true;
        }

        /// <summary>
        /// Check that both this context and the model it holds are still usable.
        /// </summary>
        /// <returns>The model this context was created from</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this handle or its model has been closed</exception>
        private SafeLlamaModelHandle ThrowIfDisposed()
        {
            // The single-string ObjectDisposedException constructor takes an *object name*, not a
            // message, so use the (objectName, message) overload to surface the intended text.
            if (IsClosed)
                throw new ObjectDisposedException(nameof(SafeLLamaContextHandle), "Cannot use this `SafeLLamaContextHandle` - it has been disposed");
            if (_model == null || _model.IsClosed)
                throw new ObjectDisposedException(nameof(SafeLlamaModelHandle), "Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");

            return _model;
        }

        /// <summary>
        /// Create a new llama_state for the given model
        /// </summary>
        /// <param name="model">The model to create a context from</param>
        /// <param name="lparams">Parameters passed to the native context constructor</param>
        /// <returns>A handle owning the newly created native context</returns>
        /// <exception cref="RuntimeError">Thrown if the native call returns a null context pointer</exception>
        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
        {
            var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
            if (ctx_ptr == IntPtr.Zero)
                throw new RuntimeError("Failed to create context from model");

            return new(ctx_ptr, model);
        }
        #endregion

        /// <summary>
        /// Convert the given text into tokens
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether the "BOS" token should be added</param>
        /// <param name="encoding">Encoding to use for the text</param>
        /// <returns>The tokens, in an array sized exactly to the number produced</returns>
        /// <exception cref="RuntimeError">Thrown if the native tokenizer reports an error (negative result)</exception>
        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
        {
            ThrowIfDisposed();

            // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
            // possibly be more than this.
            var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);

            // "Rent" an array to write results into (avoiding an allocation of a large array).
            // The rented array may be longer than `count`; only `count` is passed as the limit.
            var temporaryArray = ArrayPool<int>.Shared.Rent(count);
            try
            {
                // Do the actual conversion
                var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
                if (n < 0)
                {
                    throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                                           "specify the encoding.");
                }

                // Copy the results from the rented into an array which is exactly the right size
                var result = new int[n];
                Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
                return result;
            }
            finally
            {
                ArrayPool<int>.Shared.Return(temporaryArray);
            }
        }

        /// <summary>
        /// Token logits obtained from the last call to llama_eval()
        /// The logits for the last token are stored in the last row
        /// Can be mutated in order to change the probabilities of the next token.<br />
        /// Rows: n_tokens<br />
        /// Cols: n_vocab
        /// </summary>
        /// <remarks>
        /// The returned span covers exactly <see cref="VocabCount"/> floats starting at the pointer
        /// returned by llama_get_logits. It aliases native memory owned by the context, so it must
        /// not be used after this handle is disposed.
        /// </remarks>
        /// <returns>A mutable view over the native logits buffer</returns>
        public Span<float> GetLogits()
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits(this);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        /// <summary>
        /// Convert a token into a string
        /// </summary>
        /// <param name="token">Token to decode</param>
        /// <param name="encoding">Encoding used to decode the token's bytes</param>
        /// <returns>The decoded text, as produced by the model handle</returns>
        public string TokenToString(int token, Encoding encoding)
        {
            return ThrowIfDisposed().TokenToString(token, encoding);
        }

        /// <summary>
        /// Convert a token into a span of bytes that could be decoded into a string
        /// </summary>
        /// <param name="token">Token to convert</param>
        /// <returns>The raw bytes for the token, as produced by the model handle</returns>
        public ReadOnlySpan<byte> TokenToSpan(int token)
        {
            return ThrowIfDisposed().TokenToSpan(token);
        }

        /// <summary>
        /// Run the llama inference to obtain the logits and probabilities for the next token.
        /// </summary>
        /// <param name="tokens">The provided batch of new tokens to process</param>
        /// <param name="n_past">the number of tokens to use from previous eval calls</param>
        /// <param name="n_threads">Number of threads passed through to the native eval call</param>
        /// <returns>Returns true on success</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public bool Eval(Memory<int> tokens, int n_past, int n_threads)
        {
            // Every other public member validates the handle first; do the same here so a
            // disposed model is reported before touching native code.
            ThrowIfDisposed();

            // Pin the managed buffer so its address is stable for the duration of the native call
            using var pin = tokens.Pin();
            unsafe
            {
                return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
            }
        }
    }
}