Browse Source

- Exposed KV debug view through `SafeLLamaContextHandle`

- Added `KvCacheSequenceDivide`
 - Moved count tokens/cells methods to `SafeLLamaContextHandle`
tags/v0.10.0
Martin Evans 1 year ago
parent
commit
c5146bac23
2 changed files with 45 additions and 19 deletions
  1. +1
    -19
      LLama/Native/LLamaKvCacheView.cs
  2. +44
    -0
      LLama/Native/SafeLLamaContextHandle.cs

+ 1
- 19
LLama/Native/LLamaKvCacheView.cs View File

@@ -74,7 +74,7 @@ public class LLamaKvCacheViewSafeHandle
}

/// <summary>
/// Allocate a new llama_kv_cache_view_free
/// Allocate a new KV cache view which can be used to inspect the KV cache
/// </summary>
/// <param name="ctx"></param>
/// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
@@ -102,24 +102,6 @@ public class LLamaKvCacheViewSafeHandle
NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
}

/// <summary>
/// Count the number of used cells in the KV cache
/// </summary>
/// <returns></returns>
public int CountCells()
{
return NativeApi.llama_get_kv_cache_used_cells(_ctx);
}

/// <summary>
/// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
/// </summary>
/// <returns></returns>
public int CountTokens()
{
return NativeApi.llama_get_kv_cache_token_count(_ctx);
}

/// <summary>
/// Get the raw KV cache view
/// </summary>


+ 44
- 0
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -289,6 +289,35 @@ namespace LLama.Native
}

#region KV Cache Management
/// <summary>
/// Get a new KV cache view that can be used to debug the KV cache
/// </summary>
/// <param name="maxSequences"></param>
/// <returns></returns>
public LLamaKvCacheViewSafeHandle KvCacheGetDebugView(int maxSequences = 4)
{
return LLamaKvCacheViewSafeHandle.Allocate(this, maxSequences);
}

/// <summary>
/// Count the number of used cells in the KV cache (i.e. have at least one sequence assigned to them)
/// </summary>
/// <returns></returns>
public int KvCacheCountCells()
{
return NativeApi.llama_get_kv_cache_used_cells(this);
}

/// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug)
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
/// </summary>
/// <returns></returns>
public int KvCacheCountTokens()
{
return NativeApi.llama_get_kv_cache_token_count(this);
}

/// <summary>
/// Clear the KV cache
/// </summary>
@@ -344,6 +373,21 @@ namespace LLama.Native
{
NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
}

/// <summary>
/// Integer division of the positions by factor of `d > 1`
/// If the KV cache is RoPEd, the KV data is updated accordingly
/// p0 &lt; 0 : [0, p1]
/// p1 &lt; 0 : [p0, inf)
/// </summary>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="divisor"></param>
public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int divisor)
{
NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor);
}
#endregion
}
}

Loading…
Cancel
Save