using System;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
    /// <summary>
    /// A safe wrapper around a llama_context
    /// </summary>
    // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
    public sealed class SafeLLamaContextHandle
        : SafeLLamaHandleBase
    {
        #region properties and fields
        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public int VocabCount => ThrowIfDisposed().VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public uint ContextSize => NativeApi.llama_n_ctx(this);

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;

        /// <summary>
        /// Get the model which this context is using
        /// </summary>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public SafeLlamaModelHandle ModelHandle => ThrowIfDisposed();

        // The model this context was created from. Create() takes a reference on it (DangerousAddRef)
        // before assigning this field, and ReleaseHandle() gives that reference back.
        private SafeLlamaModelHandle? _model;
        #endregion

        #region construction/destruction
        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            llama_free(handle);
            SetHandle(IntPtr.Zero);

            // Decrement refcount on model. The field is only non-null if Create() successfully added a reference.
            _model?.DangerousRelease();
            _model = null;

            return true;
        }

        /// <summary>
        /// Check that neither this context nor the model it uses has been disposed.
        /// </summary>
        /// <returns>The model this context is using</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        private SafeLlamaModelHandle ThrowIfDisposed()
        {
            if (IsClosed)
                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");

            // Read the field once so that the check and the returned value agree
            var model = _model;
            if (model == null || model.IsClosed)
                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");

            return model;
        }

        /// <summary>
        /// Create a new llama_state for the given model
        /// </summary>
        /// <param name="model">The model to create a context for. A reference on it is held until this context is released.</param>
        /// <param name="lparams">Parameters for the new context</param>
        /// <returns>A new context handle</returns>
        /// <exception cref="RuntimeError">Thrown if the native library failed to create the context</exception>
        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
        {
            var ctx = llama_new_context_with_model(model, lparams);

            // The P/Invoke marshaller always creates a SafeHandle instance for a SafeHandle return value,
            // so a native NULL shows up as an invalid handle rather than a null reference. Check both.
            if (ctx == null || ctx.IsInvalid)
            {
                ctx?.Dispose();
                throw new RuntimeError("Failed to create context from model");
            }

            // Increment the model reference count while this context exists.
            // DangerousAddRef throws if it fails, so there is no need to check "success".
            // The field is assigned only after the reference has been taken, so that ReleaseHandle
            // never releases a reference which was never added.
            var success = false;
            try
            {
                model.DangerousAddRef(ref success);
            }
            catch
            {
                ctx.Dispose();
                throw;
            }
            ctx._model = model;

            return ctx;
        }
        #endregion

        #region Native API
        static SafeLLamaContextHandle()
        {
            // This ensures that `NativeApi` has been loaded before calling the two native methods below
            NativeApi.llama_empty_call();
        }

        /// <summary>
        /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**!
        /// </summary>
        /// <param name="model">The model to create a context for</param>
        /// <param name="params">Parameters for the new context</param>
        /// <returns>A handle which is invalid (not null) if the native call returned NULL</returns>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);

        /// <summary>
        /// Frees all allocated memory in the given llama_context
        /// </summary>
        /// <param name="ctx">Raw pointer to the native context</param>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void llama_free(IntPtr ctx);
        #endregion

        /// <summary>
        /// Token logits obtained from the last evaluation (<see cref="Eval"/> / <see cref="Decode"/>).
        /// Can be mutated in order to change the probabilities of the next token.
        /// </summary>
        /// <remarks>
        /// The native logits buffer is laid out as rows of n_vocab floats (Rows: n_tokens, Cols: n_vocab), with the
        /// logits for the last token stored in the last row. The span returned here covers only the first
        /// <see cref="VocabCount"/> values of that buffer, i.e. the first row. Use <see cref="GetLogitsIth"/>
        /// to access the row for a specific token.
        /// </remarks>
        /// <returns>A span over the first row of the native logits buffer</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public Span<float> GetLogits()
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits(this);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        /// <summary>
        /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
        /// </summary>
        /// <param name="i">Index of the token (row) to get logits for</param>
        /// <returns>A span of n_vocab logits for token <paramref name="i"/></returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public Span<float> GetLogitsIth(int i)
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits_ith(this, i);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        #region tokens
        /// <summary>
        /// Convert the given text into tokens
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether the "BOS" token should be added</param>
        /// <param name="encoding">Encoding to use for the text</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <returns>The tokens produced by the model's tokenizer</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            return ThrowIfDisposed().Tokenize(text, add_bos, special, encoding);
        }

        /// <summary>
        /// Convert a single llama token into bytes
        /// </summary>
        /// <param name="token">Token to decode</param>
        /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
        /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public uint TokenToSpan(LLamaToken token, Span<byte> dest)
        {
            return ThrowIfDisposed().TokenToSpan(token, dest);
        }
        #endregion

        #region infer
        /// <summary>
        /// Run the llama inference to obtain the logits and probabilities for the next token.
        /// </summary>
        /// <param name="tokens">The provided batch of new tokens to process</param>
        /// <param name="n_past">the number of tokens to use from previous eval calls</param>
        /// <returns>Returns true on success</returns>
        [Obsolete("use llama_decode() instead")]
        public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
        {
            unsafe
            {
                fixed (LLamaToken* pinned = tokens)
                {
                    // the entire `eval` system needs replacing with the new batch system!
                    var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
                    return ret == 0;
                }
            }
        }

        /// <summary>
        /// Process a batch of tokens with llama_decode.
        /// </summary>
        /// <param name="batch">The batch to process</param>
        /// <returns>
        /// Positive return values do not mean a fatal error, but rather a warning:
        ///  - 0: success
        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
        ///  - &lt; 0: error
        /// </returns>
        public int Decode(LLamaBatch batch)
        {
            // The native batch view is only valid while the object returned by ToNativeBatch is alive
            using (batch.ToNativeBatch(out var nb))
            {
                return NativeApi.llama_decode(this, nb);
            }
        }
        #endregion

        #region state
        /// <summary>
        /// Get the size of the state, when saved as bytes
        /// </summary>
        /// <returns>Number of bytes required to hold the state</returns>
        public ulong GetStateSize()
        {
            return NativeApi.llama_get_state_size(this);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public unsafe ulong GetState(byte* dest, ulong size)
        {
            return GetState(new IntPtr(dest), size);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public ulong GetState(IntPtr dest, ulong size)
        {
            var required = GetStateSize();
            if (size < required)
                throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

            unsafe
            {
                return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer());
            }
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <remarks>
        /// No size is passed or checked here. NOTE(review): presumably src must point to at least as many bytes
        /// as were written by <see cref="GetState(IntPtr, ulong)"/> — confirm against the native implementation.
        /// </remarks>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public unsafe ulong SetState(byte* src)
        {
            return SetState(new IntPtr(src));
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <remarks>
        /// No size is passed or checked here. NOTE(review): presumably src must point to at least as many bytes
        /// as were written by <see cref="GetState(IntPtr, ulong)"/> — confirm against the native implementation.
        /// </remarks>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public ulong SetState(IntPtr src)
        {
            unsafe
            {
                return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer());
            }
        }
        #endregion

        /// <summary>
        /// Set the RNG seed
        /// </summary>
        /// <param name="seed">The new seed</param>
        public void SetSeed(uint seed)
        {
            NativeApi.llama_set_rng_seed(this, seed);
        }

        /// <summary>
        /// Set the number of threads used for decoding
        /// </summary>
        /// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
        /// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
        public void SetThreads(uint threads, uint threadsBatch)
        {
            NativeApi.llama_set_n_threads(this, threads, threadsBatch);
        }

        #region KV Cache Management
        /// <summary>
        /// Get a new KV cache view that can be used to debug the KV cache
        /// </summary>
        /// <param name="maxSequences">
        /// Forwarded to <c>LLamaKvCacheViewSafeHandle.Allocate</c>. NOTE(review): presumably the maximum number of
        /// sequences per cell recorded by the view — confirm.
        /// </param>
        /// <returns>A newly allocated KV cache view</returns>
        public LLamaKvCacheViewSafeHandle KvCacheGetDebugView(int maxSequences = 4)
        {
            return LLamaKvCacheViewSafeHandle.Allocate(this, maxSequences);
        }

        /// <summary>
        /// Count the number of used cells in the KV cache (i.e. have at least one sequence assigned to them)
        /// </summary>
        /// <returns>The number of used cells</returns>
        public int KvCacheCountCells()
        {
            return NativeApi.llama_get_kv_cache_used_cells(this);
        }

        /// <summary>
        /// Returns the number of tokens in the KV cache (slow, use only for debug)
        /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
        /// </summary>
        /// <returns>The number of tokens in the KV cache</returns>
        public int KvCacheCountTokens()
        {
            return NativeApi.llama_get_kv_cache_token_count(this);
        }

        /// <summary>
        /// Clear the KV cache
        /// </summary>
        public void KvCacheClear()
        {
            NativeApi.llama_kv_cache_clear(this);
        }

        /// <summary>
        /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
        /// </summary>
        /// <param name="seq">The sequence to remove tokens from</param>
        /// <param name="p0">Start of the position range (inclusive)</param>
        /// <param name="p1">End of the position range (exclusive)</param>
        public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
        {
            NativeApi.llama_kv_cache_seq_rm(this, seq, p0, p1);
        }

        /// <summary>
        /// Copy all tokens that belong to the specified sequence to another sequence. Note that
        /// this does not allocate extra KV cache memory - it simply assigns the tokens to the
        /// new sequence
        /// </summary>
        /// <param name="src">The sequence to copy from</param>
        /// <param name="dest">The sequence to copy to</param>
        /// <param name="p0">Start of the position range</param>
        /// <param name="p1">End of the position range</param>
        public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1)
        {
            NativeApi.llama_kv_cache_seq_cp(this, src, dest, p0, p1);
        }

        /// <summary>
        /// Removes all tokens that do not belong to the specified sequence
        /// </summary>
        /// <param name="seq">The sequence to keep</param>
        public void KvCacheSequenceKeep(LLamaSeqId seq)
        {
            NativeApi.llama_kv_cache_seq_keep(this, seq);
        }

        /// <summary>
        /// Adds relative position "delta" to all tokens that belong to the specified sequence
        /// and have positions in [p0, p1). If the KV cache is RoPEd, the KV data is updated
        /// accordingly
        /// </summary>
        /// <param name="seq">The sequence to shift</param>
        /// <param name="p0">Start of the position range (inclusive)</param>
        /// <param name="p1">End of the position range (exclusive)</param>
        /// <param name="delta">Amount to add to each position</param>
        public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
        {
            NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
        }

        /// <summary>
        /// Integer division of the positions by factor of `d &gt; 1`.
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        ///   p0 &lt; 0 : [0,  p1]
        ///   p1 &lt; 0 : [p0, inf)
        /// </summary>
        /// <param name="seq">The sequence whose positions are divided</param>
        /// <param name="p0">Start of the position range</param>
        /// <param name="p1">End of the position range</param>
        /// <param name="divisor">The factor `d` to divide positions by</param>
        public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int divisor)
        {
            NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor);
        }
        #endregion
    }
}