using System;
using System.Buffers;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
    /// <summary>
    /// A safe wrapper around a llama_context
    /// </summary>
    public sealed class SafeLLamaContextHandle
        : SafeLLamaHandleBase
    {
        #region properties and fields
        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        public int VocabCount => ThrowIfDisposed().VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public int ContextSize => NativeApi.llama_n_ctx(this);

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;

        /// <summary>
        /// Get the model which this context is using
        /// </summary>
        public SafeLlamaModelHandle ModelHandle => ThrowIfDisposed();

        // The model this context was created from. Null until the constructor has successfully taken a
        // reference on it, and null again once the context has been released.
        private SafeLlamaModelHandle? _model;
        #endregion

        #region construction/destruction
        /// <summary>
        /// Create a new SafeLLamaContextHandle
        /// </summary>
        /// <param name="handle">pointer to an allocated llama_context</param>
        /// <param name="model">the model which this context was created from</param>
        /// <exception cref="RuntimeError">Thrown if the reference count of the model could not be incremented</exception>
        public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
            : base(handle)
        {
            // Increment the model reference count while this context exists. The field is only assigned
            // once the increment has succeeded, so that `ReleaseHandle` never releases a reference which
            // this context did not acquire.
            var success = false;
            model.DangerousAddRef(ref success);
            if (!success)
                throw new RuntimeError("Failed to increment model refcount");

            _model = model;
        }

        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            NativeApi.llama_free(DangerousGetHandle());
            SetHandle(IntPtr.Zero);

            // Decrement refcount on model (only if the constructor managed to increment it)
            _model?.DangerousRelease();
            _model = null!;
            return true;
        }

        /// <summary>
        /// Check that neither this context nor the model it depends on has been closed.
        /// </summary>
        /// <returns>The model handle which this context is using</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        private SafeLlamaModelHandle ThrowIfDisposed()
        {
            // Note: the single-argument constructor of ObjectDisposedException takes the *object name*, not
            // the message, so the two-argument overload is used to get the text into `Message`.
            if (IsClosed)
                throw new ObjectDisposedException(nameof(SafeLLamaContextHandle), "Cannot use this `SafeLLamaContextHandle` - it has been disposed");
            if (_model == null || _model.IsClosed)
                throw new ObjectDisposedException(nameof(SafeLlamaModelHandle), "Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");

            return _model!;
        }

        /// <summary>
        /// Create a new llama_state for the given model
        /// </summary>
        /// <param name="model">The model to create the context from</param>
        /// <param name="lparams">Parameters passed to the native context constructor</param>
        /// <returns>A handle wrapping the newly created context</returns>
        /// <exception cref="RuntimeError">Thrown if the native call returns a null pointer</exception>
        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
        {
            var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
            if (ctx_ptr == IntPtr.Zero)
                throw new RuntimeError("Failed to create context from model");

            return new(ctx_ptr, model);
        }
        #endregion

        /// <summary>
        /// Token logits obtained from the last call to llama_eval()
        /// The logits for the last token are stored in the last row
        /// Can be mutated in order to change the probabilities of the next token.<br />
        /// Rows: n_tokens<br />
        /// Cols: n_vocab
        /// </summary>
        /// <returns>
        /// A span of <see cref="VocabCount"/> items, starting at the pointer returned by llama_get_logits
        /// </returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public Span<float> GetLogits()
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits(this);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        /// <summary>
        /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
        /// </summary>
        /// <param name="i">Index of the token to get the logits for</param>
        /// <returns>
        /// A span of <see cref="VocabCount"/> items, starting at the pointer returned by llama_get_logits_ith
        /// </returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public Span<float> GetLogitsIth(int i)
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits_ith(this, i);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        #region tokens
        /// <summary>
        /// Convert the given text into tokens
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether the "BOS" token should be added</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <param name="encoding">Encoding to use for the text</param>
        /// <returns>The tokens produced for the text, in an array of exactly the right size</returns>
        /// <exception cref="RuntimeError">Thrown if the native tokenizer reports a negative token count</exception>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            ThrowIfDisposed();

            // Nothing to tokenize and no BOS token requested: the result is empty
            if (string.IsNullOrEmpty(text) && !add_bos)
                return Array.Empty<int>();

            // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
            // possibly be more than this.
            var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);

            // "Rent" an array to write results into (avoiding an allocation of a large array)
            var temporaryArray = ArrayPool<int>.Shared.Rent(count);
            try
            {
                // Do the actual conversion
                var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special);
                if (n < 0)
                {
                    throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                                           "specify the encoding.");
                }

                // Copy the results from the rented into an array which is exactly the right size
                var result = new int[n];
                Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);

                return result;
            }
            finally
            {
                // Always hand the rented buffer back, even if tokenization failed
                ArrayPool<int>.Shared.Return(temporaryArray);
            }
        }

        /// <summary>
        /// Convert a single llama token into bytes
        /// </summary>
        /// <param name="token">Token to decode</param>
        /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
        /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public int TokenToSpan(int token, Span<byte> dest)
        {
            return ThrowIfDisposed().TokenToSpan(token, dest);
        }
        #endregion

        /// <summary>
        /// Run the llama inference to obtain the logits and probabilities for the next token.
        /// </summary>
        /// <param name="tokens">The provided batch of new tokens to process</param>
        /// <param name="n_past">the number of tokens to use from previous eval calls</param>
        /// <returns>Returns true on success</returns>
        [Obsolete("use llama_decode() instead")]
        public bool Eval(ReadOnlySpan<int> tokens, int n_past)
        {
            unsafe
            {
                // Pin the tokens so that a stable pointer can be handed to the native call
                fixed (int* pinned = tokens)
                {
                    // the entire `eval` system needs replacing with the new batch system!
                    var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
                    return ret == 0;
                }
            }
        }

        /// <summary>
        /// Pass the given batch to the native llama_decode function for this context
        /// </summary>
        /// <param name="batch">The batch to decode</param>
        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
        ///  - 0: success<br />
        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
        ///  - &lt; 0: error<br />
        /// </returns>
        public int Decode(LLamaBatchSafeHandle batch)
        {
            return NativeApi.llama_decode(this, batch.NativeBatch);
        }

        #region state
        /// <summary>
        /// Get the size of the state, when saved as bytes
        /// </summary>
        public ulong GetStateSize()
        {
            return NativeApi.llama_get_state_size(this);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public unsafe ulong GetState(byte* dest, ulong size)
        {
            return GetState(new IntPtr(dest), size);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public ulong GetState(IntPtr dest, ulong size)
        {
            // Refuse to copy if the caller has not provided at least as much space as the state requires
            var required = GetStateSize();
            if (size < required)
                throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

            unsafe
            {
                return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer());
            }
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public unsafe ulong SetState(byte* src)
        {
            return SetState(new IntPtr(src));
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public ulong SetState(IntPtr src)
        {
            unsafe
            {
                return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer());
            }
        }
        #endregion
    }
}