diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs index 2fc3eabf..2009a687 100644 --- a/LLama/Batched/BatchedExecutor.cs +++ b/LLama/Batched/BatchedExecutor.cs @@ -45,9 +45,9 @@ public sealed class BatchedExecutor public BatchedExecutor(LLamaWeights model, IContextParams contextParams) { Model = model; - Batch = new LLamaBatch(); Context = model.CreateContext(contextParams); + Epoch = 1; } /// diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index 3dcdec54..a699bb66 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -90,39 +90,17 @@ public sealed class Conversation if (RequiresInference) throw new CannotForkWhileRequiresInference(); - // Assign tokens to the new sequence - var id2 = Executor.GetNextSequenceId(); - NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, id2, 0, _end); - // Create a new conversation which references the current position in this one - var c = new Conversation(Executor, id2, _end) + var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end) { _batchIndex = _batchIndex, _requiredEpoch = _requiredEpoch, }; - return c; - } - - /// - /// Rewind this conversation back to an earlier state - /// - /// - /// - /// - /// Thrown if `tokens` parameter is larger than NTokens - public void Rewind(int tokens) - { - AssertNotDisposed(); - - if (tokens > TokenCount) - throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens"); - - // Remove those tokens from KV - Executor.Context.NativeHandle.KvCacheRemove(ConversationId, _end.Value - tokens, _end); + // Assign tokens to the new sequence + NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end); - // Adjust "end" marker back - _end = _end.Value - tokens; + return c; } #region sample @@ -203,4 +181,89 @@ public sealed class Conversation _requiredEpoch = Executor.Epoch + 1; } #endregion + + #region modify + /// + /// Directly modify the KV cache of this conversation + /// + /// + /// Thrown if this method is called while == true + public void Modify(ModifyKvCache modifier) + { + AssertNotDisposed(); + + if (RequiresInference) + throw new CannotModifyWhileRequiresInference(); + + // do whatever the modification is + _end = modifier.Invoke(_end, new KvAccessor(this)); + + // Set the epoch down to zero, this ensures that this conversation + // cannot be sampled until it is prompted again. + _requiredEpoch = 0; + } + + /// + /// Provides direct access to the KV cache of a . + /// See for how to use this. + /// + public readonly ref struct KvAccessor + { + private readonly Conversation _conversation; + + internal KvAccessor(Conversation conversation) + { + _conversation = conversation; + } + + #region remove + /// + /// Removes all tokens that have positions in [start, end) + /// + /// Start position (inclusive) + /// End position (exclusive) + public void Remove(LLamaPos start, LLamaPos end) + { + _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end); + } + + /// + /// Removes all tokens starting from the given position + /// + /// Start position (inclusive) + /// Number of tokens + public void Remove(LLamaPos start, int count) + { + if (count <= 0) + return; + + var end = start.Value + count; + _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end); + } + #endregion + + #region shift + /// + /// Adds relative position "delta" to all tokens that have positions in [p0, p1). + /// If the KV cache is RoPEd, the KV data is updated + /// accordingly + /// + /// Start position (inclusive) + /// End position (exclusive) + /// Amount to add on to each token position + public void Shift(LLamaPos start, LLamaPos end, int delta) + { + _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta); + } + #endregion + } + + /// + /// A function which can temporarily access the KV cache of a to modify it directly + /// + /// The current end token of this conversation + /// An which allows direct access to modify the KV cache + /// The new end token position + public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv); + #endregion } \ No newline at end of file diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs new file mode 100644 index 00000000..5fca5e94 --- /dev/null +++ b/LLama/Batched/ConversationExtensions.cs @@ -0,0 +1,59 @@ +using System; + +namespace LLama.Batched; + +/// +/// Extension method for +/// +public static class ConversationExtensions +{ + /// + /// Rewind a back to an earlier state by removing tokens from the end + /// + /// The conversation to rewind + /// The number of tokens to rewind + /// Thrown if `tokens` parameter is larger than TokenCount + public static void Rewind(this Conversation conversation, int tokens) + { + if (tokens > conversation.TokenCount) + throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens"); + + conversation.Modify((end, kv) => + { + // Remove those tokens from KV + kv.Remove(end.Value - tokens, tokens); + + // Return adjusted end position + return end.Value - tokens; + }); + } + + /// + /// Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over. + /// Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context + /// gets full, keeping the prompt at the start intact. + /// + /// The conversation to rewind + /// How much to shift tokens over by + /// The number of tokens at the start which should not be shifted + public static void ShiftLeft(this Conversation conversation, int count, int keep) + { + // Given a setup like this (shift=5, keep=3): + // + // AAABBBBBCCCCCCCCC... + // + // We want to remove all the B's, shift all the C's and leave all the A's untouched + + conversation.Modify((end, kv) => + { + // Remove the B's + kv.Remove(keep, count); + + // Shift the C's + kv.Shift(keep + count, end, -count); + + // Update total count + return end.Value - count; + }); + } +} \ No newline at end of file diff --git a/LLama/Batched/Exceptions.cs b/LLama/Batched/Exceptions.cs index 7d847bfe..1feb270c 100644 --- a/LLama/Batched/Exceptions.cs +++ b/LLama/Batched/Exceptions.cs @@ -57,7 +57,7 @@ public class CannotSampleRequiresPromptException } /// -/// This exception is thrown when "Fork()" is called on a with = true +/// This exception is thrown when is called when = true /// public class CannotForkWhileRequiresInference : ExperimentalBatchedExecutorException @@ -69,13 +69,13 @@ public class CannotForkWhileRequiresInference } /// -/// This exception is thrown when "Rewind()" is called on a with = true +/// This exception is thrown when is called when = true /// -public class CannotRewindWhileRequiresInference +public class CannotModifyWhileRequiresInference : ExperimentalBatchedExecutorException { - internal CannotRewindWhileRequiresInference() - : base("Cannot `Rewind()` a conversation while RequiresInference is true") + internal CannotModifyWhileRequiresInference() + : base("Cannot `Modify()` a conversation while RequiresInference is true") { } } \ No newline at end of file diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index c953cb23..578cad40 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -388,7 +388,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta); + public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta); /// /// Integer division of the positions by factor of `d > 1` diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 91e82c85..da53491d 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -369,7 +369,7 @@ namespace LLama.Native /// /// /// - public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta) + public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta) { NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta); }