- Re-implemented `Rewind` as an extension method using `Modify` internally
- Implemented `ShiftLeft`, which shifts everything over except for some starting tokens. This is the same as the `StatelessExecutor` out-of-context handling.
- Starting batch at epoch 1; this ensures that conversations (starting at zero) are below the current epoch. It also means `0` can always be used as a value guaranteed to be below the current epoch.

tags/v0.10.0
| @@ -45,9 +45,9 @@ public sealed class BatchedExecutor | |||||
| public BatchedExecutor(LLamaWeights model, IContextParams contextParams) | public BatchedExecutor(LLamaWeights model, IContextParams contextParams) | ||||
| { | { | ||||
| Model = model; | Model = model; | ||||
| Batch = new LLamaBatch(); | Batch = new LLamaBatch(); | ||||
| Context = model.CreateContext(contextParams); | Context = model.CreateContext(contextParams); | ||||
| Epoch = 1; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -90,39 +90,17 @@ public sealed class Conversation | |||||
| if (RequiresInference) | if (RequiresInference) | ||||
| throw new CannotForkWhileRequiresInference(); | throw new CannotForkWhileRequiresInference(); | ||||
| // Assign tokens to the new sequence | |||||
| var id2 = Executor.GetNextSequenceId(); | |||||
| NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, id2, 0, _end); | |||||
| // Create a new conversation which references the current position in this one | // Create a new conversation which references the current position in this one | ||||
| var c = new Conversation(Executor, id2, _end) | |||||
| var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end) | |||||
| { | { | ||||
| _batchIndex = _batchIndex, | _batchIndex = _batchIndex, | ||||
| _requiredEpoch = _requiredEpoch, | _requiredEpoch = _requiredEpoch, | ||||
| }; | }; | ||||
| return c; | |||||
| } | |||||
| /// <summary> | |||||
| /// Rewind this conversation back to an earlier state | |||||
| /// </summary> | |||||
| /// <param name="tokens"></param> | |||||
| /// <exception cref="ObjectDisposedException"></exception> | |||||
| /// <exception cref="CannotForkWhileRequiresInference"></exception> | |||||
| /// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than NTokens</exception> | |||||
| public void Rewind(int tokens) | |||||
| { | |||||
| AssertNotDisposed(); | |||||
| if (tokens > TokenCount) | |||||
| throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens"); | |||||
| // Remove those tokens from KV | |||||
| Executor.Context.NativeHandle.KvCacheRemove(ConversationId, _end.Value - tokens, _end); | |||||
| // Assign tokens to the new sequence | |||||
| NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end); | |||||
| // Adjust "end" marker back | |||||
| _end = _end.Value - tokens; | |||||
| return c; | |||||
| } | } | ||||
| #region sample | #region sample | ||||
| @@ -203,4 +181,89 @@ public sealed class Conversation | |||||
| _requiredEpoch = Executor.Epoch + 1; | _requiredEpoch = Executor.Epoch + 1; | ||||
| } | } | ||||
| #endregion | #endregion | ||||
| #region modify | |||||
| /// <summary> | |||||
| /// Directly modify the KV cache of this conversation | |||||
| /// </summary> | |||||
| /// <param name="modifier"></param> | |||||
| /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception> | |||||
| public void Modify(ModifyKvCache modifier) | |||||
| { | |||||
| AssertNotDisposed(); | |||||
| if (RequiresInference) | |||||
| throw new CannotModifyWhileRequiresInference(); | |||||
| // do whatever the modification is | |||||
| _end = modifier.Invoke(_end, new KvAccessor(this)); | |||||
| // Set the epoch down to zero, this ensures that this conversation | |||||
| // cannot be sampled until it is prompted again. | |||||
| _requiredEpoch = 0; | |||||
| } | |||||
| /// <summary> | |||||
| /// Provides direct access to the KV cache of a <see cref="Conversation"/>. | |||||
| /// See <see cref="Modify"/> for how to use this. | |||||
| /// </summary> | |||||
| public readonly ref struct KvAccessor | |||||
| { | |||||
| private readonly Conversation _conversation; | |||||
| internal KvAccessor(Conversation conversation) | |||||
| { | |||||
| _conversation = conversation; | |||||
| } | |||||
| #region remove | |||||
| /// <summary> | |||||
| /// Removes all tokens that have positions in [start, end) | |||||
| /// </summary> | |||||
| /// <param name="start">Start position (inclusive)</param> | |||||
| /// <param name="end">End position (exclusive)</param> | |||||
| public void Remove(LLamaPos start, LLamaPos end) | |||||
| { | |||||
| _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end); | |||||
| } | |||||
| /// <summary> | |||||
| /// Removes all tokens starting from the given position | |||||
| /// </summary> | |||||
| /// <param name="start">Start position (inclusive)</param> | |||||
| /// <param name="count">Number of tokens</param> | |||||
| public void Remove(LLamaPos start, int count) | |||||
| { | |||||
| if (count <= 0) | |||||
| return; | |||||
| var end = start.Value + count; | |||||
| _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end); | |||||
| } | |||||
| #endregion | |||||
| #region shift | |||||
| /// <summary> | |||||
| /// Adds relative position "delta" to all tokens that have positions in [p0, p1). | |||||
| /// If the KV cache is RoPEd, the KV data is updated | |||||
| /// accordingly | |||||
| /// </summary> | |||||
| /// <param name="start">Start position (inclusive)</param> | |||||
| /// <param name="end">End position (exclusive)</param> | |||||
| /// <param name="delta">Amount to add on to each token position</param> | |||||
| public void Shift(LLamaPos start, LLamaPos end, int delta) | |||||
| { | |||||
| _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta); | |||||
| } | |||||
| #endregion | |||||
| } | |||||
| /// <summary> | |||||
| /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly | |||||
| /// </summary> | |||||
| /// <param name="end">The current end token of this conversation</param> | |||||
| /// <param name="kv">An <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param> | |||||
| /// <returns>The new end token position</returns> | |||||
| public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv); | |||||
| #endregion | |||||
| } | } | ||||
| @@ -0,0 +1,59 @@ | |||||
| using System; | |||||
| namespace LLama.Batched; | |||||
| /// <summary> | |||||
| /// Extension method for <see cref="Conversation"/> | |||||
| /// </summary> | |||||
| public static class ConversationExtensions | |||||
| { | |||||
| /// <summary> | |||||
| /// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end | |||||
| /// </summary> | |||||
| /// <param name="conversation">The conversation to rewind</param> | |||||
| /// <param name="tokens">The number of tokens to rewind</param> | |||||
| /// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than TokenCount</exception> | |||||
| public static void Rewind(this Conversation conversation, int tokens) | |||||
| { | |||||
| if (tokens > conversation.TokenCount) | |||||
| throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens"); | |||||
| conversation.Modify((end, kv) => | |||||
| { | |||||
| // Remove those tokens from KV | |||||
| kv.Remove(end.Value - tokens, tokens); | |||||
| // Return adjusted end position | |||||
| return end.Value - tokens; | |||||
| }); | |||||
| } | |||||
| /// <summary> | |||||
| /// Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over. | |||||
| /// Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context | |||||
| /// gets full, keeping the prompt at the start intact. | |||||
| /// </summary> | |||||
| /// <param name="conversation">The conversation to rewind</param> | |||||
| /// <param name="count">How much to shift tokens over by</param> | |||||
| /// <param name="keep">The number of tokens at the start which should <b>not</b> be shifted</param> | |||||
| public static void ShiftLeft(this Conversation conversation, int count, int keep) | |||||
| { | |||||
| // Given a setup like this (shift=5, keep=3): | |||||
| // | |||||
| // AAABBBBBCCCCCCCCC... | |||||
| // | |||||
| // We want to remove all the B's, shift all the C's and leave all the A's untouched | |||||
| conversation.Modify((end, kv) => | |||||
| { | |||||
| // Remove the B's | |||||
| kv.Remove(keep, count); | |||||
| // Shift the C's | |||||
| kv.Shift(keep + count, end, -count); | |||||
| // Update total count | |||||
| return end.Value - count; | |||||
| }); | |||||
| } | |||||
| } | |||||
| @@ -57,7 +57,7 @@ public class CannotSampleRequiresPromptException | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// This exception is thrown when "Fork()" is called on a <see cref="Conversation"/> with <see cref="Conversation.RequiresInference"/> = true | |||||
| /// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true | |||||
| /// </summary> | /// </summary> | ||||
| public class CannotForkWhileRequiresInference | public class CannotForkWhileRequiresInference | ||||
| : ExperimentalBatchedExecutorException | : ExperimentalBatchedExecutorException | ||||
| @@ -69,13 +69,13 @@ public class CannotForkWhileRequiresInference | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// This exception is thrown when "Rewind()" is called on a <see cref="Conversation"/> with <see cref="Conversation.RequiresInference"/> = true | |||||
| /// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true | |||||
| /// </summary> | /// </summary> | ||||
| public class CannotRewindWhileRequiresInference | |||||
| public class CannotModifyWhileRequiresInference | |||||
| : ExperimentalBatchedExecutorException | : ExperimentalBatchedExecutorException | ||||
| { | { | ||||
| internal CannotRewindWhileRequiresInference() | |||||
| : base("Cannot `Rewind()` a conversation while RequiresInference is true") | |||||
| internal CannotModifyWhileRequiresInference() | |||||
| : base("Cannot `Modify()` a conversation while RequiresInference is true") | |||||
| { | { | ||||
| } | } | ||||
| } | } | ||||
| @@ -388,7 +388,7 @@ namespace LLama.Native | |||||
| /// <param name="p1"></param> | /// <param name="p1"></param> | ||||
| /// <param name="delta"></param> | /// <param name="delta"></param> | ||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta); | |||||
| public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta); | |||||
| /// <summary> | /// <summary> | ||||
| /// Integer division of the positions by factor of `d > 1` | /// Integer division of the positions by factor of `d > 1` | ||||
| @@ -369,7 +369,7 @@ namespace LLama.Native | |||||
| /// <param name="p0"></param> | /// <param name="p0"></param> | ||||
| /// <param name="p1"></param> | /// <param name="p1"></param> | ||||
| /// <param name="delta"></param> | /// <param name="delta"></param> | ||||
| public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta) | |||||
| public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta) | |||||
| { | { | ||||
| NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta); | NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta); | ||||
| } | } | ||||