Browse Source

Added a safe handle for LLamaKvCacheView

tags/0.9.1
Martin Evans 1 year ago
parent
commit
bab6b65b61
2 changed files with 87 additions and 6 deletions
  1. +5
    -1
      LLama.Examples/Program.cs
  2. +82
    -5
      LLama/Native/LLamaKvCacheView.cs

+ 5
- 1
LLama.Examples/Program.cs View File

@@ -7,7 +7,11 @@ Console.WriteLine(" __ __ ____ _

Console.WriteLine("======================================================================================================");

NativeLibraryConfig.Instance.WithCuda().WithLogs();
NativeLibraryConfig
.Instance
.WithCuda()
.WithLogs()
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);

NativeApi.llama_empty_call();
Console.WriteLine();


+ 82
- 5
LLama/Native/LLamaKvCacheView.cs View File

@@ -1,4 +1,5 @@
using System.Runtime.InteropServices;
using System;
using System.Runtime.InteropServices;

namespace LLama.Native;

@@ -18,7 +19,6 @@ public struct LLamaKvCacheViewCell
/// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view)
/// </summary>
//todo: rewrite to safe handle?
[StructLayout(LayoutKind.Sequential)]
public unsafe struct LLamaKvCacheView
{
@@ -52,6 +52,84 @@ public unsafe struct LLamaKvCacheView
LLamaSeqId* cells_sequences;
}

/// <summary>
/// A safe handle for a LLamaKvCacheView
/// </summary>
public class LLamaKvCacheViewSafeHandle
: SafeLLamaHandleBase
{
private readonly SafeLLamaContextHandle _ctx;
private LLamaKvCacheView _view;

/// <summary>
/// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
/// </summary>
/// <param name="ctx"></param>
/// <param name="view"></param>
public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
: base(IntPtr.MaxValue, true)
{
_ctx = ctx;
_view = view;
}

/// <summary>
/// Allocate a new llama_kv_cache_view_free
/// </summary>
/// <param name="ctx"></param>
/// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
/// <returns></returns>
public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
{
var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences);
return new LLamaKvCacheViewSafeHandle(ctx, result);
}

/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_kv_cache_view_free(ref _view);
SetHandle(IntPtr.Zero);

return true;
}

/// <summary>
/// Update this view
/// </summary>
public void Update()
{
NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
}

/// <summary>
/// Count the number of used cells in the KV cache
/// </summary>
/// <returns></returns>
public int CountCells()
{
return NativeApi.llama_get_kv_cache_used_cells(_ctx);
}

/// <summary>
/// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be countered multiple times
/// </summary>
/// <returns></returns>
public int CountTokens()
{
return NativeApi.llama_get_kv_cache_token_count(_ctx);
}

/// <summary>
/// Get the raw KV cache view
/// </summary>
/// <returns></returns>
public ref LLamaKvCacheView GetView()
{
return ref _view;
}
}

partial class NativeApi
{
/// <summary>
@@ -66,9 +144,8 @@ partial class NativeApi
/// <summary>
/// Free a KV cache view. (use only for debugging purposes)
/// </summary>
/// <param name="view"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe void llama_kv_cache_view_free(LLamaKvCacheView* view);
public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view);

/// <summary>
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
@@ -76,7 +153,7 @@ partial class NativeApi
/// <param name="ctx"></param>
/// <param name="view"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, LLamaKvCacheView* view);
public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view);

/// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug)


Loading…
Cancel
Save