using System;
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// A safe handle for a LLamaKvCacheView. Owns a native <c>llama_kv_cache_view</c> structure
/// and frees it with <c>llama_kv_cache_view_free</c> when released.
/// </summary>
public sealed class LLamaKvCacheViewSafeHandle
    : SafeLLamaHandleBase
{
    private readonly SafeLLamaContextHandle _ctx;
    private NativeLLamaKvCacheView _view;

    /// <summary>
    /// Number of KV cache cells. This will be the same as the context size.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public int CellCount => GetNativeView().n_cells;

    /// <summary>
    /// Get the total number of tokens in the KV cache.
    /// </summary>
    /// <remarks>
    /// For example, if there are two populated cells, the first with 1 sequence id in it and
    /// the second with 2 sequence ids, then you'll have 3 tokens.
    /// </remarks>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public int TokenCount => GetNativeView().token_count;

    /// <summary>
    /// Maximum number of sequences visible for a cell. There may be more sequences than this
    /// in reality; this is simply the maximum number this view can see.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public int MaxSequenceCount => GetNativeView().n_seq_max;

    /// <summary>
    /// Number of populated cache cells.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public int UsedCellCount => GetNativeView().used_cells;

    /// <summary>
    /// Maximum contiguous empty slots in the cache.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public int MaxContiguous => GetNativeView().max_contiguous;

    /// <summary>
    /// Index to the start of the <see cref="MaxContiguous"/> slot range. Can be negative when the cache is full.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public int MaxContiguousIdx => GetNativeView().max_contiguous_idx;

    /// <summary>
    /// Initialize a LLamaKvCacheViewSafeHandle which will call <c>llama_kv_cache_view_free</c> when disposed.
    /// </summary>
    /// <param name="ctx">The context whose KV cache this view inspects.</param>
    /// <param name="view">The native view returned by <c>llama_kv_cache_view_init</c>; this handle takes ownership of it.</param>
    private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view)
        // The real resource is the _view struct, not a pointer. A non-zero dummy handle value is
        // used, presumably so the base class does not treat the handle as invalid (ReleaseHandle resets it to zero).
        : base((IntPtr)1, true)
    {
        _ctx = ctx;
        _view = view;
    }

    /// <summary>
    /// Allocate a new KV cache view which can be used to inspect the KV cache.
    /// The returned view is already populated (<see cref="Update"/> is called once before returning).
    /// </summary>
    /// <param name="ctx">The context whose KV cache will be inspected.</param>
    /// <param name="maxSequences">The maximum number of sequences visible in this view per cell.</param>
    /// <returns>A new handle owning the native view.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="maxSequences"/> is negative.</exception>
    public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
    {
        // A negative count would be used by native code to size per-cell buffers; reject it up front.
        if (maxSequences < 0)
            throw new ArgumentOutOfRangeException(nameof(maxSequences), "Maximum sequence count must be >= 0");

        // Allocate the view
        var view = llama_kv_cache_view_init(ctx, maxSequences);
        var handle = new LLamaKvCacheViewSafeHandle(ctx, view);

        // Update the view so it has valid data after allocation.
        handle.Update();

        return handle;
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        llama_kv_cache_view_free(ref _view);
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Read the current KV cache state into this view.
    /// </summary>
    /// <remarks>
    /// Native code receives the view by reference and may modify it, so spans previously returned by
    /// <see cref="GetCellSequences"/> should not be relied upon after calling this.
    /// </remarks>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public void Update()
    {
        // Go through GetNativeView so a disposed handle throws instead of passing the
        // already-freed view back into native code.
        llama_kv_cache_view_update(_ctx, ref GetNativeView());
    }

    /// <summary>
    /// Get a reference to the raw KV cache view, checking that this handle has not been disposed.
    /// </summary>
    /// <returns>A reference to the native view owned by this handle.</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    private ref NativeLLamaKvCacheView GetNativeView()
    {
        if (IsClosed)
            throw new ObjectDisposedException(nameof(LLamaKvCacheViewSafeHandle), "Cannot access LLamaKvCacheViewSafeHandle after it has been disposed");

        return ref _view;
    }

    /// <summary>
    /// Throw if <paramref name="index"/> is outside [0, <paramref name="cellCount"/>).
    /// </summary>
    /// <param name="index">The cell index supplied by the caller.</param>
    /// <param name="cellCount">The number of cells in the view.</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range.</exception>
    private static void ValidateCellIndex(int index, int cellCount)
    {
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
        if (index >= cellCount)
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");
    }

    /// <summary>
    /// Get the position stored in the cell at the given index.
    /// </summary>
    /// <param name="index">The index of the cell [0, CellCount)</param>
    /// <returns>The position of the cell. May be negative if the cell is not populated.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public LLamaPos GetCell(int index)
    {
        ref var view = ref GetNativeView();
        ValidateCellIndex(index, view.n_cells);

        unsafe
        {
            return view.cells[index].pos;
        }
    }

    /// <summary>
    /// Get all of the sequences assigned to the cell at the given index. The span always contains
    /// <see cref="MaxSequenceCount"/> entries. If the cell actually has more sequences than that, only
    /// <see cref="MaxSequenceCount"/> of them are visible; allocate a new view with a larger
    /// <c>maxSequences</c> parameter if necessary. Invalid sequences will be negative values.
    /// </summary>
    /// <remarks>
    /// The span points directly into native memory owned by this view. It must not be used after
    /// this handle is disposed, and should be re-fetched after calling <see cref="Update"/>.
    /// </remarks>
    /// <param name="index">The index of the cell [0, CellCount)</param>
    /// <returns>A span containing the sequences assigned to this cell</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed.</exception>
    public Span<LLamaSeqId> GetCellSequences(int index)
    {
        ref var view = ref GetNativeView();
        ValidateCellIndex(index, view.n_cells);

        unsafe
        {
            // The buffer holds n_seq_max entries per cell, laid out cell after cell. Widen to long
            // before multiplying so the offset cannot overflow int for large caches.
            var first = view.cells_sequences + (long)index * view.n_seq_max;
            return new Span<LLamaSeqId>(first, view.n_seq_max);
        }
    }

    #region native API
    /// <summary>
    /// Create an empty KV cache view. (use only for debugging purposes)
    /// </summary>
    /// <param name="ctx">The context whose KV cache will be viewed.</param>
    /// <param name="n_seq_max">Maximum number of sequences visible per cell.</param>
    /// <returns>The newly created native view; must be released with <c>llama_kv_cache_view_free</c>.</returns>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);

    /// <summary>
    /// Free a KV cache view. (use only for debugging purposes)
    /// </summary>
    /// <param name="view">The view to free.</param>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view);

    /// <summary>
    /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
    /// </summary>
    /// <param name="ctx">The context whose KV cache is read.</param>
    /// <param name="view">The view to update in place.</param>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);

    /// <summary>
    /// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    private struct NativeLLamaKvCacheViewCell
    {
        /// <summary>
        /// The position for this cell. Takes KV cache shifts into account.
        /// May be negative if the cell is not populated.
        /// </summary>
        public LLamaPos pos;
    }

    /// <summary>
    /// An updateable view of the KV cache (llama_kv_cache_view)
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    private unsafe struct NativeLLamaKvCacheView
    {
        /// <summary>
        /// Number of KV cache cells. This will be the same as the context size.
        /// </summary>
        public int n_cells;

        /// <summary>
        /// Maximum number of sequences that can exist in a cell. It's not an error
        /// if there are more sequences in a cell than this value, however they will
        /// not be visible in the view cells_sequences.
        /// </summary>
        public int n_seq_max;

        /// <summary>
        /// Number of tokens in the cache. For example, if there are two populated
        /// cells, the first with 1 sequence id in it and the second with 2 sequence
        /// ids then you'll have 3 tokens.
        /// </summary>
        public int token_count;

        /// <summary>
        /// Number of populated cache cells.
        /// </summary>
        public int used_cells;

        /// <summary>
        /// Maximum contiguous empty slots in the cache.
        /// </summary>
        public int max_contiguous;

        /// <summary>
        /// Index to the start of the max_contiguous slot range. Can be negative
        /// when cache is full.
        /// </summary>
        public int max_contiguous_idx;

        /// <summary>
        /// Information for an individual cell.
        /// </summary>
        public NativeLLamaKvCacheViewCell* cells;

        /// <summary>
        /// The sequences for each cell. There will be n_seq_max items per cell.
        /// </summary>
        public LLamaSeqId* cells_sequences;
    }
    #endregion
}