|
|
|
@@ -1,4 +1,5 @@ |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
using System; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
|
|
|
|
namespace LLama.Native; |
|
|
|
|
|
|
|
@@ -18,7 +19,6 @@ public struct LLamaKvCacheViewCell |
|
|
|
/// <summary> |
|
|
|
/// An updateable view of the KV cache (llama_kv_cache_view) |
|
|
|
/// </summary> |
|
|
|
//todo: rewrite to safe handle? |
|
|
|
[StructLayout(LayoutKind.Sequential)] |
|
|
|
public unsafe struct LLamaKvCacheView |
|
|
|
{ |
|
|
|
@@ -52,6 +52,84 @@ public unsafe struct LLamaKvCacheView |
|
|
|
LLamaSeqId* cells_sequences; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// A safe handle for a <see cref="LLamaKvCacheView"/>. The wrapped view is passed to
/// `llama_kv_cache_view_free` when this handle is released.
/// </summary>
public class LLamaKvCacheViewSafeHandle
    : SafeLLamaHandleBase
{
    // Context the view was created from; passed to every native call which needs a context.
    private readonly SafeLLamaContextHandle _ctx;

    // The native view structure. Kept as a mutable field because it is handed to native code by reference.
    private LLamaKvCacheView _view;

    /// <summary>
    /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
    /// </summary>
    /// <param name="ctx">The context this view belongs to</param>
    /// <param name="view">The view to take ownership of</param>
    public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
        // The view is a struct held in a field, not a pointer, so a placeholder value is used as the handle.
        // NOTE(review): presumably chosen to be non-zero so the base class treats the handle as valid - confirm against SafeLLamaHandleBase.
        : base(IntPtr.MaxValue, true)
    {
        _ctx = ctx;
        _view = view;
    }

    /// <summary>
    /// Allocate a new KV cache view (by calling `llama_kv_cache_view_init`) and wrap it in a safe handle
    /// </summary>
    /// <param name="ctx">The context to create the view for</param>
    /// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
    /// <returns>A handle which owns the newly created view</returns>
    public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
    {
        var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences);
        return new LLamaKvCacheViewSafeHandle(ctx, result);
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        NativeApi.llama_kv_cache_view_free(ref _view);

        // Clear the placeholder handle value now that the view has been freed
        SetHandle(IntPtr.Zero);

        return true;
    }

    /// <summary>
    /// Update this view
    /// </summary>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has already been released</exception>
    public void Update()
    {
        ThrowIfReleased();

        NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
    }

    /// <summary>
    /// Count the number of used cells in the KV cache
    /// </summary>
    /// <returns>The value reported by `llama_get_kv_cache_used_cells` for the context of this view</returns>
    public int CountCells()
    {
        return NativeApi.llama_get_kv_cache_used_cells(_ctx);
    }

    /// <summary>
    /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
    /// </summary>
    /// <returns>The value reported by `llama_get_kv_cache_token_count` for the context of this view</returns>
    public int CountTokens()
    {
        return NativeApi.llama_get_kv_cache_token_count(_ctx);
    }

    /// <summary>
    /// Get the raw KV cache view
    /// </summary>
    /// <returns>A reference to the view structure held by this handle</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this handle has already been released</exception>
    public ref LLamaKvCacheView GetView()
    {
        ThrowIfReleased();

        return ref _view;
    }

    /// <summary>
    /// Guard against touching the view after it has been passed to `llama_kv_cache_view_free`
    /// </summary>
    private void ThrowIfReleased()
    {
        if (IsClosed)
        {
            throw new ObjectDisposedException(nameof(LLamaKvCacheViewSafeHandle));
        }
    }
}
|
|
|
|
|
|
|
partial class NativeApi |
|
|
|
{ |
|
|
|
/// <summary> |
|
|
|
@@ -66,9 +144,8 @@ partial class NativeApi |
|
|
|
/// <summary> |
|
|
|
/// Free a KV cache view. (use only for debugging purposes) |
|
|
|
/// </summary> |
|
|
|
/// <param name="view"></param> |
|
|
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
public static extern unsafe void llama_kv_cache_view_free(LLamaKvCacheView* view); |
|
|
|
public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) |
|
|
|
@@ -76,7 +153,7 @@ partial class NativeApi |
|
|
|
/// <param name="ctx"></param> |
|
|
|
/// <param name="view"></param> |
|
|
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] |
|
|
|
public static extern unsafe void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, LLamaKvCacheView* view); |
|
|
|
public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Returns the number of tokens in the KV cache (slow, use only for debug) |
|
|
|
|