From c5146bac2321033ee9b70c5789d8d2b878fb819b Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 7 Feb 2024 16:35:39 +0000 Subject: [PATCH] - Exposed KV debug view through `SafeLLamaContextHandle` - Added `KvCacheSequenceDivide` - Moved count tokens/cells methods to `SafeLLamaContextHandle` --- LLama/Native/LLamaKvCacheView.cs | 20 +----------- LLama/Native/SafeLLamaContextHandle.cs | 44 ++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs index 4cccd13c..2b408772 100644 --- a/LLama/Native/LLamaKvCacheView.cs +++ b/LLama/Native/LLamaKvCacheView.cs @@ -74,7 +74,7 @@ public class LLamaKvCacheViewSafeHandle } /// - /// Allocate a new llama_kv_cache_view_free + /// Allocate a new KV cache view which can be used to inspect the KV cache /// /// /// The maximum number of sequences visible in this view per cell @@ -102,24 +102,6 @@ public class LLamaKvCacheViewSafeHandle NativeApi.llama_kv_cache_view_update(_ctx, ref _view); } - /// - /// Count the number of used cells in the KV cache - /// - /// - public int CountCells() - { - return NativeApi.llama_get_kv_cache_used_cells(_ctx); - } - - /// - /// Count the number of tokens in the KV cache. 
If a token is assigned to multiple sequences it will be counted multiple times - /// - /// - public int CountTokens() - { - return NativeApi.llama_get_kv_cache_token_count(_ctx); - } - /// /// Get the raw KV cache view /// diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index d90d46d5..91e82c85 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -289,6 +289,35 @@ namespace LLama.Native } #region KV Cache Management + /// + /// Get a new KV cache view that can be used to debug the KV cache + /// + /// + /// + public LLamaKvCacheViewSafeHandle KvCacheGetDebugView(int maxSequences = 4) + { + return LLamaKvCacheViewSafeHandle.Allocate(this, maxSequences); + } + + /// + /// Count the number of used cells in the KV cache (i.e. cells that have at least one sequence assigned to them) + /// + /// + public int KvCacheCountCells() + { + return NativeApi.llama_get_kv_cache_used_cells(this); + } + + /// + /// Returns the number of tokens in the KV cache (slow, use only for debug) + /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times + /// + /// + public int KvCacheCountTokens() + { + return NativeApi.llama_get_kv_cache_token_count(this); + } + /// /// Clear the KV cache /// @@ -344,6 +373,21 @@ namespace LLama.Native { NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta); } + + /// + /// Integer division of the positions by a factor of `divisor > 1` + /// If the KV cache is RoPEd, the KV data is updated accordingly + /// p0 < 0 : [0, p1] + /// p1 < 0 : [p0, inf) + /// + /// + /// + /// + /// + public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int divisor) + { + NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor); + } #endregion } }