From bab6b65b61fc86c9462dff71c95139ea2dfbe416 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 13 Dec 2023 19:35:21 +0000 Subject: [PATCH] Added a safe handle for LLamaKvCacheView --- LLama.Examples/Program.cs | 6 ++- LLama/Native/LLamaKvCacheView.cs | 87 ++++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 6 deletions(-) diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs index 9feb6202..85ec34f2 100644 --- a/LLama.Examples/Program.cs +++ b/LLama.Examples/Program.cs @@ -7,7 +7,11 @@ Console.WriteLine(" __ __ ____ _ Console.WriteLine("======================================================================================================"); -NativeLibraryConfig.Instance.WithCuda().WithLogs(); +NativeLibraryConfig + .Instance + .WithCuda() + .WithLogs() + .WithAvx(NativeLibraryConfig.AvxLevel.Avx512); NativeApi.llama_empty_call(); Console.WriteLine(); diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs index 395d447c..5dad1fc3 100644 --- a/LLama/Native/LLamaKvCacheView.cs +++ b/LLama/Native/LLamaKvCacheView.cs @@ -1,4 +1,5 @@ -using System.Runtime.InteropServices; +using System; +using System.Runtime.InteropServices; namespace LLama.Native; @@ -18,7 +19,6 @@ public struct LLamaKvCacheViewCell /// /// An updateable view of the KV cache (llama_kv_cache_view) /// -//todo: rewrite to safe handle? [StructLayout(LayoutKind.Sequential)] public unsafe struct LLamaKvCacheView { @@ -52,6 +52,84 @@ public unsafe struct LLamaKvCacheView LLamaSeqId* cells_sequences; } +/// +/// A safe handle for a LLamaKvCacheView +/// +public class LLamaKvCacheViewSafeHandle + : SafeLLamaHandleBase +{ + private readonly SafeLLamaContextHandle _ctx; + private LLamaKvCacheView _view; + + /// + /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed + /// + /// + /// + public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view) + : base(IntPtr.MaxValue, true) + { + _ctx = ctx; + _view = view; + } + + /// + /// Allocate a new llama_kv_cache_view_free + /// + /// + /// The maximum number of sequences visible in this view per cell + /// + public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences) + { + var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences); + return new LLamaKvCacheViewSafeHandle(ctx, result); + } + + /// + protected override bool ReleaseHandle() + { + NativeApi.llama_kv_cache_view_free(ref _view); + SetHandle(IntPtr.Zero); + + return true; + } + + /// + /// Update this view + /// + public void Update() + { + NativeApi.llama_kv_cache_view_update(_ctx, ref _view); + } + + /// + /// Count the number of used cells in the KV cache + /// + /// + public int CountCells() + { + return NativeApi.llama_get_kv_cache_used_cells(_ctx); + } + + /// + /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be countered multiple times + /// + /// + public int CountTokens() + { + return NativeApi.llama_get_kv_cache_token_count(_ctx); + } + + /// + /// Get the raw KV cache view + /// + /// + public ref LLamaKvCacheView GetView() + { + return ref _view; + } +} + partial class NativeApi { /// @@ -66,9 +144,8 @@ partial class NativeApi /// /// Free a KV cache view. (use only for debugging purposes) /// - /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe void llama_kv_cache_view_free(LLamaKvCacheView* view); + public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view); /// /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) @@ -76,7 +153,7 @@ partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, LLamaKvCacheView* view); + public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view); /// /// Returns the number of tokens in the KV cache (slow, use only for debug)