using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
    /// <summary>
    /// A safe wrapper around a <c>llama_context</c>
    /// </summary>
    // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
    public sealed class SafeLLamaContextHandle
        : SafeLLamaHandleBase
    {
        #region properties and fields
        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        public int VocabCount => ThrowIfDisposed().VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public uint ContextSize => NativeApi.llama_n_ctx(this);

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;

        /// <summary>
        /// Get the maximum batch size for this context
        /// </summary>
        public uint BatchSize => NativeApi.llama_n_batch(this);

        /// <summary>
        /// Get the model which this context is using
        /// </summary>
        public SafeLlamaModelHandle ModelHandle => ThrowIfDisposed();

        // The model this context was created from. A reference is taken on it in Create
        // and given back in ReleaseHandle.
        private SafeLlamaModelHandle? _model;
        #endregion

        #region construction/destruction
        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            llama_free(handle);
            SetHandle(IntPtr.Zero);

            // Decrement refcount on model
            _model?.DangerousRelease();
            _model = null;

            return true;
        }

        /// <summary>
        /// Get the model this context uses, throwing if this context or that model has been disposed.
        /// </summary>
        /// <returns>The model handle this context was created from</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        private SafeLlamaModelHandle ThrowIfDisposed()
        {
            if (IsClosed)
                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");

            // Read the field once so the null check and the returned value refer to the same object
            var model = _model;
            if (model == null || model.IsClosed)
                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");

            return model;
        }

        /// <summary>
        /// Create a new llama_context for the given model
        /// </summary>
        /// <param name="model">The model to create a context for</param>
        /// <param name="lparams">Parameters used to create the context</param>
        /// <returns>A new context handle, which holds a reference on <paramref name="model"/> until it is released</returns>
        /// <exception cref="RuntimeError">Thrown if the native library fails to create the context</exception>
        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
        {
            var ctx = llama_new_context_with_model(model, lparams);

            // The P/Invoke marshaller creates a SafeHandle instance even when the native call returns NULL,
            // so native failure shows up as an invalid handle rather than a null reference.
            if (ctx is null || ctx.IsInvalid)
            {
                ctx?.Dispose();
                throw new RuntimeError("Failed to create context from model");
            }

            // Increment the model reference count while this context exists.
            // DangerousAddRef throws if it fails, so there is no need to check "success".
            // The reference is only recorded in `_model` once it has actually been taken, so
            // ReleaseHandle never releases a reference that was never acquired.
            var success = false;
            try
            {
                model.DangerousAddRef(ref success);
            }
            catch
            {
                // Free the native context instead of leaking it
                ctx.Dispose();
                throw;
            }
            ctx._model = model;

            return ctx;
        }
        #endregion

        #region Native API
        static SafeLLamaContextHandle()
        {
            // This ensures that `NativeApi` has been loaded before calling the native methods below
            NativeApi.llama_empty_call();
        }

        /// <summary>
        /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**!
        /// </summary>
        /// <param name="model">Model to create the context for</param>
        /// <param name="params">Context creation parameters</param>
        /// <returns>A context handle; invalid if the native call failed</returns>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);

        /// <summary>
        /// Frees all allocated memory in the given llama_context
        /// </summary>
        /// <param name="ctx">Raw pointer to the native context</param>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void llama_free(IntPtr ctx);

        /// <summary>
        /// Set a callback which can abort computation
        /// </summary>
        /// <param name="ctx">The context to set the callback on</param>
        /// <param name="abort_callback">Callback which is polled to decide whether to abort</param>
        /// <param name="abort_callback_data">Opaque pointer passed to <paramref name="abort_callback"/></param>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern unsafe void llama_set_abort_callback(SafeLLamaContextHandle ctx, GgmlAbortCallback abort_callback, void* abort_callback_data);

        /// <summary>
        /// If this returns true computation is cancelled
        /// </summary>
        /// <param name="data">The user data pointer registered with the callback</param>
        /// <returns>true to cancel computation</returns>
        // NOTE(review): if the native callback type returns a 1-byte C `bool`, consider adding
        // `[return: MarshalAs(UnmanagedType.U1)]` - confirm against ggml.h.
        private unsafe delegate bool GgmlAbortCallback(void* data);
        #endregion

        /// <summary>
        /// Token logits obtained from the last call to llama_decode.
        /// The logits for the last token are stored in the last row.
        /// Can be mutated in order to change the probabilities of the next token.<br />
        /// Rows: n_tokens<br />
        /// Cols: n_vocab
        /// </summary>
        /// <remarks>
        /// The returned span covers only the first <see cref="VocabCount"/> values (a single row).
        /// Use <see cref="GetLogitsIth"/> to address a specific row.
        /// </remarks>
        /// <returns>A span over the native logits buffer</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public Span<float> GetLogits()
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits(this);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        /// <summary>
        /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
        /// </summary>
        /// <param name="i">Index of the token whose logits should be returned</param>
        /// <returns>A span of <see cref="VocabCount"/> values over the native logits buffer</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public Span<float> GetLogitsIth(int i)
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits_ith(this, i);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        #region tokens
        /// <summary>
        /// Convert the given text into tokens
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether the "BOS" token should be added</param>
        /// <param name="encoding">Encoding to use for the text</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <returns>The tokens produced by the model's tokenizer</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            return ThrowIfDisposed().Tokenize(text, add_bos, special, encoding);
        }

        /// <summary>
        /// Convert a single llama token into bytes
        /// </summary>
        /// <param name="token">Token to decode</param>
        /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
        /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
        /// <exception cref="ObjectDisposedException">Thrown if this context or its model has been disposed</exception>
        public uint TokenToSpan(LLamaToken token, Span<byte> dest)
        {
            return ThrowIfDisposed().TokenToSpan(token, dest);
        }
        #endregion

        #region infer
        /// <summary>
        /// This object exists to ensure there is only ever 1 inference running at a time. This is a workaround for thread safety issues in llama.cpp itself.
        /// Most notably CUDA, which seems to use some global singleton resources and will crash if multiple inferences are run (even against different models).
        ///
        /// For more information see these issues:
        ///  - https://github.com/SciSharp/LLamaSharp/issues/596
        ///  - https://github.com/ggerganov/llama.cpp/issues/3960
        ///
        /// If these are ever resolved this lock can probably be removed.
        /// </summary>
        private static readonly object GlobalInferenceLock = new();

        /// <summary>
        /// Run llama_decode on the given batch, while holding the global inference lock.
        /// </summary>
        /// <param name="batch">The batch of tokens to process</param>
        /// <returns>Positive return values do not mean a fatal error, but rather a warning:<br />
        ///  - 0: success<br />
        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
        ///  - &lt; 0: error<br />
        /// </returns>
        public DecodeResult Decode(LLamaBatch batch)
        {
            lock (GlobalInferenceLock)
            {
                using (batch.ToNativeBatch(out var nb))
                    return (DecodeResult)NativeApi.llama_decode(this, nb);
            }
        }

        /// <summary>
        /// Decode a set of tokens in batch-size chunks.
        /// </summary>
        /// <param name="tokens">Tokens to decode</param>
        /// <param name="id">Sequence ID the tokens belong to</param>
        /// <param name="batch">Batch object reused (cleared and refilled) for every chunk</param>
        /// <param name="n_past">Position of the first token. It is advanced past every token in each successfully decoded chunk.
        /// If a chunk fails, the positions of that chunk are not counted.</param>
        /// <returns>A tuple, containing the decode result and the number of tokens that have not been decoded yet.</returns>
        internal (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
        {
            var batchSize = checked((int)BatchSize);

            // Evaluate the prompt, in chunks smaller than the max batch size
            var n_left = tokens.Count;
            for (var i = 0; i < tokens.Count; i += batchSize)
            {
                var n_eval = Math.Min(tokens.Count - i, batchSize);

                batch.Clear();
                for (var j = 0; j < n_eval; j++)
                {
                    // Only request logits for the very last token of the whole input
                    batch.Add(tokens[i + j], n_past++, id, (i + j) == tokens.Count - 1);
                }

                var returnCode = Decode(batch);
                if (returnCode != DecodeResult.Ok)
                {
                    // This chunk is reported as undecoded (it is included in n_left), so give back
                    // the positions it claimed to keep n_past consistent with n_left.
                    n_past -= n_eval;
                    return (returnCode, n_left);
                }

                n_left -= n_eval;
            }

            return (DecodeResult.Ok, 0);
        }
        #endregion

        #region state
        /// <summary>
        /// Get the size of the state, when saved as bytes
        /// </summary>
        /// <returns>Number of bytes required to hold the state</returns>
        public ulong GetStateSize()
        {
            return NativeApi.llama_get_state_size(this);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public unsafe ulong GetState(byte* dest, ulong size)
        {
            return GetState(new IntPtr(dest), size);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public ulong GetState(IntPtr dest, ulong size)
        {
            var required = GetStateSize();
            if (size < required)
                throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

            unsafe
            {
                return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer());
            }
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public unsafe ulong SetState(byte* src)
        {
            return SetState(new IntPtr(src));
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public ulong SetState(IntPtr src)
        {
            unsafe
            {
                return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer());
            }
        }
        #endregion

        /// <summary>
        /// Set the RNG seed
        /// </summary>
        /// <param name="seed">The new seed</param>
        public void SetSeed(uint seed)
        {
            NativeApi.llama_set_rng_seed(this, seed);
        }

        /// <summary>
        /// Set the number of threads used for decoding
        /// </summary>
        /// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
        /// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
        public void SetThreads(uint threads, uint threadsBatch)
        {
            NativeApi.llama_set_n_threads(this, threads, threadsBatch);
        }

        #region KV Cache Management
        /// <summary>
        /// Get a new KV cache view that can be used to debug the KV cache
        /// </summary>
        /// <param name="maxSequences">Maximum number of sequences the view should track</param>
        /// <returns>A new debug view handle</returns>
        public LLamaKvCacheViewSafeHandle KvCacheGetDebugView(int maxSequences = 4)
        {
            return LLamaKvCacheViewSafeHandle.Allocate(this, maxSequences);
        }

        /// <summary>
        /// Count the number of used cells in the KV cache (i.e. have at least one sequence assigned to them)
        /// </summary>
        /// <returns>Number of used cells</returns>
        public int KvCacheCountCells()
        {
            return NativeApi.llama_get_kv_cache_used_cells(this);
        }

        /// <summary>
        /// Returns the number of tokens in the KV cache (slow, use only for debug)
        /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
        /// </summary>
        /// <returns>Number of tokens in the KV cache</returns>
        public int KvCacheCountTokens()
        {
            return NativeApi.llama_get_kv_cache_token_count(this);
        }

        /// <summary>
        /// Clear the KV cache
        /// </summary>
        public void KvCacheClear()
        {
            NativeApi.llama_kv_cache_clear(this);
        }

        /// <summary>
        /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
        /// </summary>
        /// <param name="seq">Sequence to remove tokens from</param>
        /// <param name="p0">Start position (inclusive)</param>
        /// <param name="p1">End position (exclusive)</param>
        public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
        {
            NativeApi.llama_kv_cache_seq_rm(this, seq, p0, p1);
        }

        /// <summary>
        /// Copy all tokens that belong to the specified sequence to another sequence. Note that
        /// this does not allocate extra KV cache memory - it simply assigns the tokens to the
        /// new sequence
        /// </summary>
        /// <param name="src">Sequence to copy tokens from</param>
        /// <param name="dest">Sequence to assign the tokens to</param>
        /// <param name="p0">Start position (inclusive)</param>
        /// <param name="p1">End position (exclusive)</param>
        public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1)
        {
            NativeApi.llama_kv_cache_seq_cp(this, src, dest, p0, p1);
        }

        /// <summary>
        /// Removes all tokens that do not belong to the specified sequence
        /// </summary>
        /// <param name="seq">Sequence to keep</param>
        public void KvCacheSequenceKeep(LLamaSeqId seq)
        {
            NativeApi.llama_kv_cache_seq_keep(this, seq);
        }

        /// <summary>
        /// Adds relative position "delta" to all tokens that belong to the specified sequence
        /// and have positions in [p0, p1). If the KV cache is RoPEd, the KV data is updated
        /// accordingly
        /// </summary>
        /// <param name="seq">Sequence to shift</param>
        /// <param name="p0">Start position (inclusive)</param>
        /// <param name="p1">End position (exclusive)</param>
        /// <param name="delta">Amount to add to each position</param>
        public void KvCacheSequenceAdd(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
        {
            NativeApi.llama_kv_cache_seq_add(this, seq, p0, p1, delta);
        }

        /// <summary>
        /// Integer division of the positions by factor of `d > 1`.
        /// If the KV cache is RoPEd, the KV data is updated accordingly.<br />
        /// p0 &lt; 0 : [0,  p1]<br />
        /// p1 &lt; 0 : [p0, inf)
        /// </summary>
        /// <param name="seq">Sequence whose positions are divided</param>
        /// <param name="p0">Start position</param>
        /// <param name="p1">End position</param>
        /// <param name="divisor">Factor to divide positions by</param>
        public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int divisor)
        {
            NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor);
        }
        #endregion
    }
}