
LLamaKvCacheView.cs 8.7 kB

April 2024 Binary Update (#662)

* Updated binaries, using [this build](https://github.com/SciSharp/LLamaSharp/actions/runs/8654672719/job/23733195669) for llama.cpp commit `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
  - Added all new functions.
  - Moved some functions (e.g. `SafeLlamaModelHandle` specific functions) into `SafeLlamaModelHandle.cs`.
  - Exposed tokens on `SafeLlamaModelHandle` and `LLamaWeights` through a `Tokens` property. As new special tokens are added in the future they can be added here.
  - Changed all token properties to return nullable tokens, to handle models that lack some tokens.
  - Fixed `DefaultSamplingPipeline` to handle models with no newline token.
* Moved native methods to more specific locations.
  - Context specific things have been moved into `SafeLLamaContextHandle.cs` and made private; they're already exposed through C# properties and methods.
  - Check that the GPU layer count is zero if GPU offload is not supported.
  - Moved methods for creating default structs (`llama_model_quantize_default_params` and `llama_context_default_params`) into the relevant structs.
* Removed the exception thrown when `GpuLayerCount > 0` while the GPU is not supported.
* Per-sequence state load/save:
  - Added low level wrapper methods in `SafeLLamaContextHandle`.
  - Added high level wrapper methods (save/load with a `State` object or a memory mapped file) in `LLamaContext`.
  - Moved the native methods for per-sequence state load/save into `SafeLLamaContextHandle`.
* Added update and defrag methods for the KV cache in `SafeLLamaContextHandle`.
* Updated the submodule to `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`.
* Pass the sequence ID when saving a single sequence state.
1 year ago
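As a rough illustration of the changes described above (nullable token properties and the new per-sequence state save/load), the snippet below sketches what calling code might look like. This is an assumption-heavy sketch, not a verified API reference: the `SaveState`/`LoadState` overloads taking a sequence id and the `Tokens.Newline` property name are inferred from the commit description.

using System;
using LLama;
using LLama.Common;
using LLama.Native;

// Load a model and create a context (path and parameters are placeholders).
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// Token properties are now nullable, since some models lack certain special tokens.
LLamaToken? newline = weights.Tokens.Newline;       // assumed property name
if (newline is null)
    Console.WriteLine("This model has no newline token.");

// Hypothetical per-sequence state save/load through the new LLamaContext wrappers.
LLamaSeqId sequence = default;                      // sequence 0
context.SaveState("sequence0.state", sequence);     // assumed overload taking a sequence id
context.LoadState("sequence0.state", sequence);     // assumed counterpart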
using System;
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// A safe handle for a LLamaKvCacheView
/// </summary>
public sealed class LLamaKvCacheViewSafeHandle
    : SafeLLamaHandleBase
{
    private readonly SafeLLamaContextHandle _ctx;
    private NativeLLamaKvCacheView _view;

    /// <summary>
    /// Number of KV cache cells. This will be the same as the context size.
    /// </summary>
    public int CellCount => GetNativeView().n_cells;

    /// <summary>
    /// Get the total number of tokens in the KV cache.
    ///
    /// For example, if there are two populated
    /// cells, the first with 1 sequence id in it and the second with 2 sequence
    /// ids then you'll have 3 tokens.
    /// </summary>
    public int TokenCount => GetNativeView().token_count;

    /// <summary>
    /// Maximum number of sequences visible for a cell. There may be more sequences than this
    /// in reality, this is simply the maximum number this view can see.
    /// </summary>
    public int MaxSequenceCount => GetNativeView().n_seq_max;

    /// <summary>
    /// Number of populated cache cells
    /// </summary>
    public int UsedCellCount => GetNativeView().used_cells;

    /// <summary>
    /// Maximum contiguous empty slots in the cache.
    /// </summary>
    public int MaxContiguous => GetNativeView().max_contiguous;

    /// <summary>
    /// Index to the start of the MaxContiguous slot range. Can be negative when cache is full.
    /// </summary>
    public int MaxContiguousIdx => GetNativeView().max_contiguous_idx;

    /// <summary>
    /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="view"></param>
    private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view)
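        // The handle value itself is never used as a native pointer (the view is stored by value
        // in _view), so an arbitrary non-zero value marks the handle as valid until ReleaseHandle runs.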
        : base((IntPtr)1, true)
    {
        _ctx = ctx;
        _view = view;
    }

    /// <summary>
    /// Allocate a new KV cache view which can be used to inspect the KV cache
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
    /// <returns></returns>
    public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
    {
        // Allocate the view
        var view = llama_kv_cache_view_init(ctx, maxSequences);
        var handle = new LLamaKvCacheViewSafeHandle(ctx, view);

        // Update the view so it has valid data after allocation.
        handle.Update();

        return handle;
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        llama_kv_cache_view_free(ref _view);
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Read the current KV cache state into this view.
    /// </summary>
    public void Update()
    {
        llama_kv_cache_view_update(_ctx, ref _view);
    }

    /// <summary>
    /// Get the raw KV cache view
    /// </summary>
    /// <returns></returns>
    private ref NativeLLamaKvCacheView GetNativeView()
    {
        if (IsClosed)
            throw new ObjectDisposedException("Cannot access LLamaKvCacheViewSafeHandle after it has been disposed");
        return ref _view;
    }

    /// <summary>
    /// Get the cell at the given index
    /// </summary>
    /// <param name="index">The index of the cell [0, CellCount)</param>
    /// <returns>Data about the cell at the given index</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
    public LLamaPos GetCell(int index)
    {
        var view = GetNativeView();
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
        if (index >= view.n_cells)
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");

        unsafe
        {
            return view.cells[index].pos;
        }
    }

    /// <summary>
    /// Get all of the sequences assigned to the cell at the given index. This will contain <see cref="MaxSequenceCount"/>
    /// entries, even if the cell actually has more than that many sequences; allocate a new view with a larger
    /// maxSequences parameter if necessary. Invalid sequences will be negative values.
    /// </summary>
    /// <param name="index">The index of the cell [0, CellCount)</param>
    /// <returns>A span containing the sequences assigned to this cell</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
    public Span<LLamaSeqId> GetCellSequences(int index)
    {
        var view = GetNativeView();
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
        if (index >= view.n_cells)
            throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");

        unsafe
        {
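            // cells_sequences is a flat array with n_seq_max entries per cell, so the
            // sequences for cell `index` start at offset index * n_seq_max.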
            return new Span<LLamaSeqId>(&view.cells_sequences[index * view.n_seq_max], view.n_seq_max);
        }
    }

    #region native API
    /// <summary>
    /// Create an empty KV cache view. (use only for debugging purposes)
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="n_seq_max"></param>
    /// <returns></returns>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);

    /// <summary>
    /// Free a KV cache view. (use only for debugging purposes)
    /// </summary>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view);

    /// <summary>
    /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="view"></param>
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);

    /// <summary>
    /// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    private struct NativeLLamaKvCacheViewCell
    {
        /// <summary>
        /// The position for this cell. Takes KV cache shifts into account.
        /// May be negative if the cell is not populated.
        /// </summary>
        public LLamaPos pos;
    }

    /// <summary>
    /// An updateable view of the KV cache (llama_kv_cache_view)
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    private unsafe struct NativeLLamaKvCacheView
    {
        /// <summary>
        /// Number of KV cache cells. This will be the same as the context size.
        /// </summary>
        public int n_cells;

        /// <summary>
        /// Maximum number of sequences that can exist in a cell. It's not an error
        /// if there are more sequences in a cell than this value, however they will
        /// not be visible in the view cells_sequences.
        /// </summary>
        public int n_seq_max;

        /// <summary>
        /// Number of tokens in the cache. For example, if there are two populated
        /// cells, the first with 1 sequence id in it and the second with 2 sequence
        /// ids then you'll have 3 tokens.
        /// </summary>
        public int token_count;

        /// <summary>
        /// Number of populated cache cells.
        /// </summary>
        public int used_cells;

        /// <summary>
        /// Maximum contiguous empty slots in the cache.
        /// </summary>
        public int max_contiguous;

        /// <summary>
        /// Index to the start of the max_contiguous slot range. Can be negative
        /// when cache is full.
        /// </summary>
        public int max_contiguous_idx;

        /// <summary>
        /// Information for an individual cell.
        /// </summary>
        public NativeLLamaKvCacheViewCell* cells;

        /// <summary>
        /// The sequences for each cell. There will be n_seq_max items per cell.
        /// </summary>
        public LLamaSeqId* cells_sequences;
    }
    #endregion
}
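For reference, here is a minimal sketch of how this handle might be used to inspect a context's KV cache. It relies only on members defined in the file above, plus the assumption that `LLamaPos` and `LLamaSeqId` expose an integer `Value` field; the `KvCacheInspector` class and `Dump` method are illustrative, not part of the library.

using System;
using LLama.Native;

public static class KvCacheInspector
{
    /// <summary>
    /// Print a short summary of the KV cache behind the given context.
    /// </summary>
    public static void Dump(SafeLLamaContextHandle ctx)
    {
        // Allocate a view that reports up to 4 sequence ids per cell (Allocate calls Update once).
        using var view = LLamaKvCacheViewSafeHandle.Allocate(ctx, maxSequences: 4);

        Console.WriteLine($"{view.UsedCellCount}/{view.CellCount} cells used, " +
                          $"{view.TokenCount} tokens, largest free run: {view.MaxContiguous}");

        // Walk every cell and print its position and the sequences assigned to it.
        for (var i = 0; i < view.CellCount; i++)
        {
            var pos = view.GetCell(i);
            foreach (var seq in view.GetCellSequences(i))
            {
                // Negative sequence ids mark unused slots in the per-cell sequence list.
                if (seq.Value >= 0)
                    Console.WriteLine($"cell {i}: pos {pos.Value}, seq {seq.Value}");
            }
        }
    }
}

Call `view.Update()` again after further decode calls if the same view is kept around; allocating a new view each time, as above, avoids stale data at the cost of an extra native allocation.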