|
|
|
@@ -1,72 +1,58 @@ |
|
|
|
using System; |
|
|
|
using System; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
|
|
|
|
namespace LLama.Native; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell) |
|
|
|
/// A safe handle for a LLamaKvCacheView |
|
|
|
/// </summary> |
|
|
|
[StructLayout(LayoutKind.Sequential)] |
|
|
|
public struct LLamaKvCacheViewCell |
|
|
|
public sealed class LLamaKvCacheViewSafeHandle |
|
|
|
: SafeLLamaHandleBase |
|
|
|
{ |
|
|
|
private readonly SafeLLamaContextHandle _ctx; |
|
|
|
private NativeLLamaKvCacheView _view; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// The position for this cell. Takes KV cache shifts into account. |
|
|
|
/// May be negative if the cell is not populated. |
|
|
|
/// Number of KV cache cells. This will be the same as the context size. |
|
|
|
/// </summary> |
|
|
|
public LLamaPos pos; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// An updateable view of the KV cache (llama_kv_cache_view) |
|
|
|
/// </summary> |
|
|
|
[StructLayout(LayoutKind.Sequential)] |
|
|
|
public unsafe struct LLamaKvCacheView |
|
|
|
{ |
|
|
|
// Number of KV cache cells. This will be the same as the context size. |
|
|
|
int n_cells; |
|
|
|
|
|
|
|
// Maximum number of sequences that can exist in a cell. It's not an error |
|
|
|
// if there are more sequences in a cell than this value, however they will |
|
|
|
// not be visible in the view cells_sequences. |
|
|
|
int n_seq_max; |
|
|
|
|
|
|
|
// Number of tokens in the cache. For example, if there are two populated |
|
|
|
// cells, the first with 1 sequence id in it and the second with 2 sequence |
|
|
|
// ids then you'll have 3 tokens. |
|
|
|
int token_count; |
|
|
|
|
|
|
|
// Number of populated cache cells. |
|
|
|
int used_cells; |
|
|
|
public int CellCount => GetNativeView().n_cells; |
|
|
|
|
|
|
|
// Maximum contiguous empty slots in the cache. |
|
|
|
int max_contiguous; |
|
|
|
|
|
|
|
// Index to the start of the max_contiguous slot range. Can be negative |
|
|
|
// when cache is full. |
|
|
|
int max_contiguous_idx; |
|
|
|
|
|
|
|
// Information for an individual cell. |
|
|
|
LLamaKvCacheViewCell* cells; |
|
|
|
/// <summary> |
|
|
|
/// Get the total number of tokens in the KV cache. |
|
|
|
/// |
|
|
|
/// For example, if there are two populated |
|
|
|
/// cells, the first with 1 sequence id in it and the second with 2 sequence |
|
|
|
/// ids then you'll have 3 tokens. |
|
|
|
/// </summary> |
|
|
|
public int TokenCount => GetNativeView().token_count; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Maximum number of sequences visible for a cell. There may be more sequences than this |
|
|
|
/// in reality, this is simply the maximum number this view can see. |
|
|
|
/// </summary> |
|
|
|
public int MaxSequenceCount => GetNativeView().n_seq_max; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Number of populated cache cells |
|
|
|
/// </summary> |
|
|
|
public int UsedCellCount => GetNativeView().used_cells; |
|
|
|
|
|
|
|
// The sequences for each cell. There will be n_seq_max items per cell. |
|
|
|
LLamaSeqId* cells_sequences; |
|
|
|
} |
|
|
|
/// <summary> |
|
|
|
/// Maximum contiguous empty slots in the cache. |
|
|
|
/// </summary> |
|
|
|
public int MaxContiguous => GetNativeView().max_contiguous; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// A safe handle for a LLamaKvCacheView |
|
|
|
/// </summary> |
|
|
|
public class LLamaKvCacheViewSafeHandle |
|
|
|
: SafeLLamaHandleBase |
|
|
|
{ |
|
|
|
private readonly SafeLLamaContextHandle _ctx; |
|
|
|
private LLamaKvCacheView _view; |
|
|
|
/// <summary> |
|
|
|
/// Index to the start of the MaxContiguous slot range. Can be negative when cache is full. |
|
|
|
/// </summary> |
|
|
|
public int MaxContiguousIdx => GetNativeView().max_contiguous; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed |
|
|
|
/// </summary> |
|
|
|
/// <param name="ctx"></param> |
|
|
|
/// <param name="view"></param> |
|
|
|
public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view) |
|
|
|
private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view) |
|
|
|
: base((IntPtr)1, true) |
|
|
|
{ |
|
|
|
_ctx = ctx; |
|
|
|
@@ -81,76 +67,176 @@ public class LLamaKvCacheViewSafeHandle |
|
|
|
/// <returns></returns> |
|
|
|
public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences) |
|
|
|
{ |
|
|
|
var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences); |
|
|
|
return new LLamaKvCacheViewSafeHandle(ctx, result); |
|
|
|
// Allocate the view |
|
|
|
var view = llama_kv_cache_view_init(ctx, maxSequences); |
|
|
|
var handle = new LLamaKvCacheViewSafeHandle(ctx, view); |
|
|
|
|
|
|
|
// Update the view so it has valid data after allocation. |
|
|
|
handle.Update(); |
|
|
|
|
|
|
|
return handle; |
|
|
|
} |
|
|
|
|
|
|
|
/// <inheritdoc /> |
|
|
|
protected override bool ReleaseHandle() |
|
|
|
{ |
|
|
|
NativeApi.llama_kv_cache_view_free(ref _view); |
|
|
|
llama_kv_cache_view_free(ref _view); |
|
|
|
SetHandle(IntPtr.Zero); |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Update this view |
|
|
|
/// Read the current KV cache state into this view. |
|
|
|
/// </summary> |
|
|
|
public void Update() |
|
|
|
{ |
|
|
|
NativeApi.llama_kv_cache_view_update(_ctx, ref _view); |
|
|
|
llama_kv_cache_view_update(_ctx, ref _view); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Get the raw KV cache view |
|
|
|
/// </summary> |
|
|
|
/// <returns></returns> |
|
|
|
public ref LLamaKvCacheView GetView() |
|
|
|
private ref NativeLLamaKvCacheView GetNativeView() |
|
|
|
{ |
|
|
|
if (IsClosed) |
|
|
|
throw new ObjectDisposedException("Cannot access LLamaKvCacheViewSafeHandle after is has been disposed"); |
|
|
|
|
|
|
|
return ref _view; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public static partial class NativeApi |
|
|
|
{ |
|
|
|
/// <summary> |
|
|
|
/// Get the cell at the given index |
|
|
|
/// </summary> |
|
|
|
/// <param name="index">The index of the cell [0, CellCount)</param> |
|
|
|
/// <returns>Data about the cell at the given index</returns> |
|
|
|
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 <= index < CellCount)</exception> |
|
|
|
public LLamaPos GetCell(int index) |
|
|
|
{ |
|
|
|
var view = GetNativeView(); |
|
|
|
|
|
|
|
if (index < 0) |
|
|
|
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0"); |
|
|
|
if (index >= view.n_cells) |
|
|
|
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount"); |
|
|
|
|
|
|
|
unsafe |
|
|
|
{ |
|
|
|
return view.cells[index].pos; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Get all of the sequences assigned to the cell at the given index. This will contain <see cref="MaxSequenceCount"/> entries |
|
|
|
/// sequences even if the cell actually has more than that many sequences, allocate a new view with a larger maxSequences parameter |
|
|
|
/// if necessary. Invalid sequences will be negative values. |
|
|
|
/// </summary> |
|
|
|
/// <param name="index">The index of the cell [0, CellCount)</param> |
|
|
|
/// <returns>A span containing the sequences assigned to this cell</returns> |
|
|
|
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 <= index < CellCount)</exception> |
|
|
|
public Span<LLamaSeqId> GetCellSequences(int index) |
|
|
|
{ |
|
|
|
var view = GetNativeView(); |
|
|
|
|
|
|
|
if (index < 0) |
|
|
|
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0"); |
|
|
|
if (index >= view.n_cells) |
|
|
|
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount"); |
|
|
|
|
|
|
|
unsafe |
|
|
|
{ |
|
|
|
return new Span<LLamaSeqId>(&view.cells_sequences[index * view.n_seq_max], view.n_seq_max); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#region native API |
|
|
|
/// <summary> |
|
|
|
/// Create an empty KV cache view. (use only for debugging purposes) |
|
|
|
/// </summary> |
|
|
|
/// <param name="ctx"></param> |
|
|
|
/// <param name="n_seq_max"></param> |
|
|
|
/// <returns></returns> |
|
|
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max); |
|
|
|
|
|
|
|
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Free a KV cache view. (use only for debugging purposes) |
|
|
|
/// </summary> |
|
|
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view); |
|
|
|
|
|
|
|
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) |
|
|
|
/// </summary> |
|
|
|
/// <param name="ctx"></param> |
|
|
|
/// <param name="view"></param> |
|
|
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view); |
|
|
|
|
|
|
|
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Returns the number of tokens in the KV cache (slow, use only for debug) |
|
|
|
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times |
|
|
|
/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell) |
|
|
|
/// </summary> |
|
|
|
/// <param name="ctx"></param> |
|
|
|
/// <returns></returns> |
|
|
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx); |
|
|
|
|
|
|
|
[StructLayout(LayoutKind.Sequential)] |
|
|
|
private struct NativeLLamaKvCacheViewCell |
|
|
|
{ |
|
|
|
/// <summary> |
|
|
|
/// The position for this cell. Takes KV cache shifts into account. |
|
|
|
/// May be negative if the cell is not populated. |
|
|
|
/// </summary> |
|
|
|
public LLamaPos pos; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) |
|
|
|
/// An updateable view of the KV cache (llama_kv_cache_view) |
|
|
|
/// </summary> |
|
|
|
/// <param name="ctx"></param> |
|
|
|
/// <returns></returns> |
|
|
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx); |
|
|
|
[StructLayout(LayoutKind.Sequential)] |
|
|
|
private unsafe struct NativeLLamaKvCacheView |
|
|
|
{ |
|
|
|
/// <summary> |
|
|
|
/// Number of KV cache cells. This will be the same as the context size. |
|
|
|
/// </summary> |
|
|
|
public int n_cells; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Maximum number of sequences that can exist in a cell. It's not an error |
|
|
|
/// if there are more sequences in a cell than this value, however they will |
|
|
|
/// not be visible in the view cells_sequences. |
|
|
|
/// </summary> |
|
|
|
public int n_seq_max; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Number of tokens in the cache. For example, if there are two populated |
|
|
|
/// cells, the first with 1 sequence id in it and the second with 2 sequence |
|
|
|
/// ids then you'll have 3 tokens. |
|
|
|
/// </summary> |
|
|
|
public int token_count; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Number of populated cache cells. |
|
|
|
/// </summary> |
|
|
|
public int used_cells; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Maximum contiguous empty slots in the cache. |
|
|
|
/// </summary> |
|
|
|
public int max_contiguous; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Index to the start of the max_contiguous slot range. Can be negative |
|
|
|
/// when cache is full. |
|
|
|
/// </summary> |
|
|
|
public int max_contiguous_idx; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Information for an individual cell. |
|
|
|
/// </summary> |
|
|
|
public NativeLLamaKvCacheViewCell* cells; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// The sequences for each cell. There will be n_seq_max items per cell. |
|
|
|
/// </summary> |
|
|
|
public LLamaSeqId* cells_sequences; |
|
|
|
} |
|
|
|
#endregion |
|
|
|
} |