using System;
using System.Runtime.InteropServices;
namespace LLama.Native;
///
/// A safe handle for a LLamaKvCacheView
///
/// <summary>
/// A safe handle for a LLamaKvCacheView. Disposing it calls <c>llama_kv_cache_view_free</c>.
/// </summary>
public sealed class LLamaKvCacheViewSafeHandle
    : SafeLLamaHandleBase
{
    private readonly SafeLLamaContextHandle _ctx;
    private NativeLLamaKvCacheView _view;

    /// <summary>
    /// Number of KV cache cells. This will be the same as the context size.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public int CellCount => GetNativeView().n_cells;

    /// <summary>
    /// Get the total number of tokens in the KV cache.
    /// </summary>
    /// <remarks>
    /// For example, if there are two populated cells, the first with 1 sequence id in it and
    /// the second with 2 sequence ids then you'll have 3 tokens.
    /// </remarks>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public int TokenCount => GetNativeView().token_count;

    /// <summary>
    /// Maximum number of sequences visible for a cell. There may be more sequences than this
    /// in reality, this is simply the maximum number this view can see.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public int MaxSequenceCount => GetNativeView().n_seq_max;

    /// <summary>
    /// Number of populated cache cells
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public int UsedCellCount => GetNativeView().used_cells;

    /// <summary>
    /// Maximum contiguous empty slots in the cache.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public int MaxContiguous => GetNativeView().max_contiguous;

    /// <summary>
    /// Index to the start of the <see cref="MaxContiguous"/> slot range. Can be negative when cache is full.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public int MaxContiguousIdx => GetNativeView().max_contiguous_idx;

    /// <summary>
    /// Initialize a LLamaKvCacheViewSafeHandle which will call <c>llama_kv_cache_view_free</c> when disposed
    /// </summary>
    /// <remarks>
    /// The native view is a struct stored in <see cref="_view"/>, not a pointer, so the SafeHandle
    /// value passed to the base is only a non-zero placeholder. NOTE(review): presumably this keeps
    /// SafeLLamaHandleBase from treating the handle as invalid (IntPtr.Zero) - confirm.
    /// </remarks>
    /// <param name="ctx">Context whose KV cache this view inspects</param>
    /// <param name="view">Native view returned by <c>llama_kv_cache_view_init</c>; ownership passes to this handle</param>
    private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view)
        : base((IntPtr)1, true)
    {
        _ctx = ctx;
        _view = view;
    }

    /// <summary>
    /// Allocate a new KV cache view which can be used to inspect the KV cache
    /// </summary>
    /// <param name="ctx">Context whose KV cache will be inspected</param>
    /// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
    /// <returns>A new view, already populated with the current KV cache state</returns>
    public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
    {
        // Allocate the view
        var view = llama_kv_cache_view_init(ctx, maxSequences);
        var handle = new LLamaKvCacheViewSafeHandle(ctx, view);

        // Update the view so it has valid data after allocation. If that fails the caller never
        // receives the handle, so release the native allocation here rather than leaking it.
        try
        {
            handle.Update();
        }
        catch
        {
            handle.Dispose();
            throw;
        }

        return handle;
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        // Access _view directly: GetNativeView() would throw because the handle is already closing.
        llama_kv_cache_view_free(ref _view);
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Read the current KV cache state into this view.
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public void Update()
    {
        // Go through GetNativeView so a disposed (already freed) view is never handed to native code.
        llama_kv_cache_view_update(_ctx, ref GetNativeView());
    }

    /// <summary>
    /// Get the raw KV cache view
    /// </summary>
    /// <returns>A reference to the native view stored in this handle</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    private ref NativeLLamaKvCacheView GetNativeView()
    {
        if (IsClosed)
        {
            throw new ObjectDisposedException(nameof(LLamaKvCacheViewSafeHandle), "Cannot access LLamaKvCacheViewSafeHandle after it has been disposed");
        }

        return ref _view;
    }

    /// <summary>
    /// Throw if <paramref name="index"/> is not a valid cell index for <paramref name="view"/>.
    /// </summary>
    /// <param name="view">The native view being indexed</param>
    /// <param name="index">Candidate cell index</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
    private static void ValidateCellIndex(in NativeLLamaKvCacheView view, int index)
    {
        if (index < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
        }

        if (index >= view.n_cells)
        {
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");
        }
    }

    /// <summary>
    /// Get the cell at the given index
    /// </summary>
    /// <param name="index">The index of the cell [0, CellCount)</param>
    /// <returns>Data about the cell at the given index</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public LLamaPos GetCell(int index)
    {
        ref var view = ref GetNativeView();
        ValidateCellIndex(in view, index);

        unsafe
        {
            return view.cells[index].pos;
        }
    }

    /// <summary>
    /// Get all of the sequences assigned to the cell at the given index. This will contain
    /// <see cref="MaxSequenceCount"/> entries even if the cell actually has more sequences than
    /// that; allocate a new view with a larger maxSequences parameter if necessary. Invalid
    /// sequences will be negative values.
    /// </summary>
    /// <remarks>
    /// The span points into native memory owned by this view, which is freed when the handle is
    /// disposed. NOTE(review): it may also be invalidated by <see cref="Update"/> if the native
    /// side reallocates - do not hold it across calls.
    /// </remarks>
    /// <param name="index">The index of the cell [0, CellCount)</param>
    /// <returns>A span containing the sequences assigned to this cell</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has been disposed</exception>
    public Span<LLamaSeqId> GetCellSequences(int index)
    {
        ref var view = ref GetNativeView();
        ValidateCellIndex(in view, index);

        unsafe
        {
            // cells_sequences is a flat array with n_seq_max entries per cell.
            return new Span<LLamaSeqId>(&view.cells_sequences[index * view.n_seq_max], view.n_seq_max);
        }
    }

    #region native API
    /// <summary>
    /// Create an empty KV cache view. (use only for debugging purposes)
    /// </summary>
    /// <param name="ctx">Context whose KV cache will be viewed</param>
    /// <param name="n_seq_max">Maximum number of sequences visible per cell</param>
    /// <returns>A newly allocated native view</returns>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);

    /// <summary>
    /// Free a KV cache view. (use only for debugging purposes)
    /// </summary>
    /// <param name="view">The view to free</param>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view);

    /// <summary>
    /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
    /// </summary>
    /// <param name="ctx">Context whose KV cache is read</param>
    /// <param name="view">The view to update in place</param>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);

    /// <summary>
    /// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    private struct NativeLLamaKvCacheViewCell
    {
        /// <summary>
        /// The position for this cell. Takes KV cache shifts into account.
        /// May be negative if the cell is not populated.
        /// </summary>
        public LLamaPos pos;
    }

    /// <summary>
    /// An updateable view of the KV cache (llama_kv_cache_view)
    /// </summary>
    /// <remarks>Field order must match the native struct exactly (sequential layout).</remarks>
    [StructLayout(LayoutKind.Sequential)]
    private unsafe struct NativeLLamaKvCacheView
    {
        /// <summary>
        /// Number of KV cache cells. This will be the same as the context size.
        /// </summary>
        public int n_cells;

        /// <summary>
        /// Maximum number of sequences that can exist in a cell. It's not an error
        /// if there are more sequences in a cell than this value, however they will
        /// not be visible in the view cells_sequences.
        /// </summary>
        public int n_seq_max;

        /// <summary>
        /// Number of tokens in the cache. For example, if there are two populated
        /// cells, the first with 1 sequence id in it and the second with 2 sequence
        /// ids then you'll have 3 tokens.
        /// </summary>
        public int token_count;

        /// <summary>
        /// Number of populated cache cells.
        /// </summary>
        public int used_cells;

        /// <summary>
        /// Maximum contiguous empty slots in the cache.
        /// </summary>
        public int max_contiguous;

        /// <summary>
        /// Index to the start of the max_contiguous slot range. Can be negative
        /// when cache is full.
        /// </summary>
        public int max_contiguous_idx;

        /// <summary>
        /// Information for an individual cell.
        /// </summary>
        public NativeLLamaKvCacheViewCell* cells;

        /// <summary>
        /// The sequences for each cell. There will be n_seq_max items per cell.
        /// </summary>
        public LLamaSeqId* cells_sequences;
    }
    #endregion
}