- Re-implemented `Rewind` as an extension method using `Modify` internally
- Implemented `ShiftLeft`, which shifts everything over except for some starting tokens. This is the same as the `StatelessExecutor` out-of-context handling.
- Started the `BatchedExecutor` at epoch 1; this ensures that conversations (starting at zero) are below the current epoch. It also means `0` can always be used as a value guaranteed to be below the current epoch.
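A rough usage sketch of the two new extension methods on an existing conversation (everything except `Rewind` and `ShiftLeft` is assumed surrounding API, not part of this change):

// given: a Conversation that has already been prompted and had inference run on it
conversation.Rewind(1);                        // drop the last token from the KV cache
conversation.ShiftLeft(count: 128, keep: 16);  // free 128 positions, keeping the first 16 tokens intact
// after either call the conversation must be prompted again before it can be sampled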
@@ -45,9 +45,9 @@ public sealed class BatchedExecutor
public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
{
Model = model;
Batch = new LLamaBatch();
Context = model.CreateContext(contextParams);
Epoch = 1;
}
/// <summary>
@@ -90,39 +90,17 @@ public sealed class Conversation
if (RequiresInference)
throw new CannotForkWhileRequiresInference();
// Assign tokens to the new sequence
var id2 = Executor.GetNextSequenceId();
NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, id2, 0, _end);
// Create a new conversation which references the current position in this one
var c = new Conversation(Executor, id2, _end)
var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
{
_batchIndex = _batchIndex,
_requiredEpoch = _requiredEpoch,
};
return c;
}
/// <summary>
/// Rewind this conversation back to an earlier state
/// </summary>
/// <param name="tokens"></param>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="CannotForkWhileRequiresInference"></exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than NTokens</exception>
public void Rewind(int tokens)
{
AssertNotDisposed();
if (tokens > TokenCount)
throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens");
// Remove those tokens from KV
Executor.Context.NativeHandle.KvCacheRemove(ConversationId, _end.Value - tokens, _end);
// Assign tokens to the new sequence
NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
// Adjust "end" marker back
_end = _end.Value - tokens;
return c;
}
#region sample
@@ -203,4 +181,89 @@ public sealed class Conversation
_requiredEpoch = Executor.Epoch + 1;
}
#endregion
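// Hedged sketch (not part of this diff) of how the epoch values are assumed to fit together:
// the executor now starts at Epoch = 1, a new conversation starts with _requiredEpoch = 0,
// queuing tokens sets _requiredEpoch = Executor.Epoch + 1, and each inference pass advances
// Epoch. A plausible sampling guard under those assumptions:
//
//   if (_requiredEpoch < Executor.Epoch)
//       throw new CannotSampleRequiresPromptException(); // never prompted, or Modify() reset it to 0
//   if (_requiredEpoch > Executor.Epoch)
//       /* prompted, but the pending inference has not been run yet */ ;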
#region modify
/// <summary>
/// Directly modify the KV cache of this conversation
/// </summary>
/// <param name="modifier"></param>
/// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
public void Modify(ModifyKvCache modifier)
{
AssertNotDisposed();
if (RequiresInference)
throw new CannotModifyWhileRequiresInference();
// do whatever the modification is
_end = modifier.Invoke(_end, new KvAccessor(this));
// Set the epoch down to zero; this ensures that this conversation
// cannot be sampled until it is prompted again.
_requiredEpoch = 0;
}
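// Usage sketch (hedged): a Modify call which discards everything after the first 8 tokens.
// Only Modify and KvAccessor come from this diff; the conversation variable is illustrative.
conversation.Modify((end, kv) =>
{
    // Remove every token with position >= 8
    kv.Remove(8, end);
    // Report the new "end" of the conversation
    return 8;
});
// _requiredEpoch is now 0, so the conversation must be prompted again before sampling.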
/// <summary>
/// Provides direct access to the KV cache of a <see cref="Conversation"/>.
/// See <see cref="Modify"/> for how to use this.
/// </summary>
public readonly ref struct KvAccessor
{
private readonly Conversation _conversation;
internal KvAccessor(Conversation conversation)
{
_conversation = conversation;
}
#region remove
/// <summary>
/// Removes all tokens that have positions in [start, end)
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
public void Remove(LLamaPos start, LLamaPos end)
{
_conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
}
/// <summary>
/// Removes all tokens starting from the given position
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="count">Number of tokens</param>
public void Remove(LLamaPos start, int count)
{
if (count <= 0)
return;
var end = start.Value + count;
_conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
}
#endregion
#region shift
/// <summary>
/// Adds relative position "delta" to all tokens that have positions in [start, end).
/// If the KV cache is RoPEd, the KV data is updated accordingly.
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
/// <param name="delta">Amount to add on to each token position</param>
public void Shift(LLamaPos start, LLamaPos end, int delta)
{
_conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
}
#endregion
}
/// <summary>
/// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
/// </summary>
/// <param name="end">The current end token position of this conversation</param>
/// <param name="kv">A <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
/// <returns>The new end token position</returns>
public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
#endregion
}
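Since `ModifyKvCache` is a named delegate, a modification can also be written as a method rather than a lambda. A minimal sketch, assuming some surrounding class and a `conversation` variable (only `Modify`, `KvAccessor` and the delegate itself come from the diff above):

private static LLamaPos DropLastToken(LLamaPos end, Conversation.KvAccessor kv)
{
    // Remove just the final token from the KV cache
    kv.Remove(end.Value - 1, 1);
    return end.Value - 1;
}

// later:
conversation.Modify(DropLastToken);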
@@ -0,0 +1,59 @@
using System;
namespace LLama.Batched;
/// <summary>
/// Extension methods for <see cref="Conversation"/>
/// </summary>
public static class ConversationExtensions
{
/// <summary>
/// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end
/// </summary>
/// <param name="conversation">The conversation to rewind</param>
/// <param name="tokens">The number of tokens to rewind</param>
/// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than TokenCount</exception>
public static void Rewind(this Conversation conversation, int tokens)
{
if (tokens > conversation.TokenCount)
throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens");
conversation.Modify((end, kv) =>
{
// Remove those tokens from KV
kv.Remove(end.Value - tokens, tokens);
// Return adjusted end position
return end.Value - tokens;
});
}
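// Usage sketch (hedged): Rewind discards tokens from the end of the conversation, e.g. to
// re-generate a different continuation. `Prompt` and the variables below are assumptions;
// only Rewind/Modify come from this change.
conversation.Rewind(2);                 // drop the last 2 tokens from the KV cache
conversation.Prompt(alternativeTokens); // assumed API: must prompt again before sampling,
                                        // since Modify() reset _requiredEpoch to 0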
/// <summary>
/// Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over.
/// Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context
/// gets full, keeping the prompt at the start intact.
/// </summary>
/// <param name="conversation">The conversation to shift tokens in</param>
/// <param name="count">The number of tokens to remove (everything after them is shifted left by this amount)</param>
/// <param name="keep">The number of tokens at the start which should <b>not</b> be shifted</param>
public static void ShiftLeft(this Conversation conversation, int count, int keep)
{
// Given a setup like this (count=5, keep=3):
//
// AAABBBBBCCCCCCCCC...
//
// We want to remove all the B's, shift all the C's and leave all the A's untouched
conversation.Modify((end, kv) =>
{
// Remove the B's
kv.Remove(keep, count);
// Shift the C's
kv.Shift(keep + count, end, -count);
// Update total count
return end.Value - count;
});
}
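A minimal usage sketch for `ShiftLeft`, assuming a conversation whose first 32 tokens are a prompt that should survive (the token counts and the `Prompt` call are illustrative assumptions):

// Free up 256 positions of KV cache while keeping the 32-token prompt at the start intact
conversation.ShiftLeft(256, 32);
// As with any Modify-based operation, the conversation must be prompted again before sampling
conversation.Prompt(nextTokens);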
@@ -57,7 +57,7 @@ public class CannotSampleRequiresPromptException
}
/// <summary>
/// This exception is thrown when "Fork()" is called on a <see cref="Conversation"/> with <see cref="Conversation.RequiresInference"/> = true
/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>
public class CannotForkWhileRequiresInference
: ExperimentalBatchedExecutorException
@@ -69,13 +69,13 @@ public class CannotForkWhileRequiresInference
}
/// <summary>
/// This exception is thrown when "Rewind()" is called on a <see cref="Conversation"/> with <see cref="Conversation.RequiresInference"/> = true
/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>
public class CannotRewindWhileRequiresInference
public class CannotModifyWhileRequiresInference
: ExperimentalBatchedExecutorException
{
internal CannotRewindWhileRequiresInference()
: base("Cannot `Rewind()` a conversation while RequiresInference is true")
internal CannotModifyWhileRequiresInference()
: base("Cannot `Modify()` a conversation while RequiresInference is true")
{
}
}
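A hedged sketch of how callers are expected to avoid this exception: run any pending inference before modifying the conversation (`Infer` is an assumed executor method, not part of this diff):

if (conversation.RequiresInference)
    await executor.Infer();   // flush the pending batch first
conversation.Rewind(1);       // safe now: RequiresInference is false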
@@ -388,7 +388,7 @@ namespace LLama.Native
/// <param name="p1"></param>
/// <param name="delta"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);
public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);
/// <summary>
/// Integer division of the positions by factor of `d > 1`
@@ -369,7 +369,7 @@ namespace LLama.Native
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="delta"></param>
public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
{
NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
}
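The `delta` parameter becomes a plain `int` rather than a `LLamaPos` because shifts can be negative: `ShiftLeft` above passes `-count` to move tokens toward the start. A brief sketch against the wrapper (the `context`/`seqId` variables and the end position of 1024 are assumptions):

// After removing positions [32, 288), move the remaining tokens back to close the gap,
// i.e. what ShiftLeft(256, 32) does for a conversation that currently ends at position 1024
context.NativeHandle.KvCacheSequenceShift(seqId, 288, 1024, -256);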