
- Added a `Modify` method to `Conversation`. This grants **temporary** access to directly modify the KV cache (see the usage sketch after this list).

- Re-implemented `Rewind` as an extension method, using `Modify` internally.
- Implemented `ShiftLeft`, which shifts everything left except for some starting tokens. This is the same as the `StatelessExecutor` out-of-context handling.
- The batch now starts at epoch 1, which ensures that conversations (which start at epoch zero) are below the current epoch. It also means `0` can always be used as a value guaranteed to be below the current epoch.
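A minimal sketch of how the new pieces fit together (illustrative only, not code from this commit; assumes a conversation that has already been prompted and inferred):

    using LLama.Batched;

    static void TrimConversation(Conversation conversation)
    {
        // Rewind: drop the last 8 tokens (now an extension method built on Modify)
        conversation.Rewind(8);

        // ShiftLeft: discard 100 tokens after the first 32, shifting the rest left
        // to free context space while leaving the prompt intact
        conversation.ShiftLeft(count: 100, keep: 32);

        // Or edit the KV cache directly; the value returned by the delegate
        // becomes the conversation's new end position
        conversation.Modify((end, kv) =>
        {
            kv.Remove(start: end.Value - 4, count: 4); // drop the last four tokens
            return end.Value - 4;
        });

        // Modify resets the required epoch to 0, so the conversation must be
        // prompted again before it can be sampled
    }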
tags/v0.10.0
Martin Evans committed 949861a581 1 year ago
6 changed files with 156 additions and 34 deletions:

  1. LLama/Batched/BatchedExecutor.cs (+1, -1)
  2. LLama/Batched/Conversation.cs (+89, -26)
  3. LLama/Batched/ConversationExtensions.cs (+59, -0)
  4. LLama/Batched/Exceptions.cs (+5, -5)
  5. LLama/Native/NativeApi.cs (+1, -1)
  6. LLama/Native/SafeLLamaContextHandle.cs (+1, -1)

LLama/Batched/BatchedExecutor.cs (+1, -1)

@@ -45,9 +45,9 @@ public sealed class BatchedExecutor
     public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
     {
         Model = model;
-
         Batch = new LLamaBatch();
         Context = model.CreateContext(contextParams);
+        Epoch = 1;
     }
 
     /// <summary>

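Epoch bookkeeping after this change, summarised from this diff and the commit message (a sketch, not code from the repository):

    // new BatchedExecutor(...)  => Epoch = 1
    // new Conversation(...)     => _requiredEpoch = 0  (always below the executor's epoch)
    // Conversation.Prompt(...)  => _requiredEpoch = Executor.Epoch + 1  (sampleable after the next inference)
    // Conversation.Modify(...)  => _requiredEpoch = 0  (must be prompted again before sampling)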

LLama/Batched/Conversation.cs (+89, -26)

@@ -90,39 +90,17 @@ public sealed class Conversation
         if (RequiresInference)
             throw new CannotForkWhileRequiresInference();
 
-        // Assign tokens to the new sequence
-        var id2 = Executor.GetNextSequenceId();
-        NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, id2, 0, _end);
-
         // Create a new conversation which references the current position in this one
-        var c = new Conversation(Executor, id2, _end)
+        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
         {
             _batchIndex = _batchIndex,
             _requiredEpoch = _requiredEpoch,
         };
 
-        return c;
-    }
-
-    /// <summary>
-    /// Rewind this conversation back to an earlier state
-    /// </summary>
-    /// <param name="tokens"></param>
-    /// <exception cref="ObjectDisposedException"></exception>
-    /// <exception cref="CannotForkWhileRequiresInference"></exception>
-    /// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than NTokens</exception>
-    public void Rewind(int tokens)
-    {
-        AssertNotDisposed();
-
-        if (tokens > TokenCount)
-            throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens");
-
-        // Remove those tokens from KV
-        Executor.Context.NativeHandle.KvCacheRemove(ConversationId, _end.Value - tokens, _end);
+        // Assign tokens to the new sequence
+        NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
 
-        // Adjust "end" marker back
-        _end = _end.Value - tokens;
+        return c;
     }

#region sample
@@ -203,4 +181,89 @@ public sealed class Conversation
         _requiredEpoch = Executor.Epoch + 1;
     }
     #endregion
+
+    #region modify
+    /// <summary>
+    /// Directly modify the KV cache of this conversation
+    /// </summary>
+    /// <param name="modifier"></param>
+    /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
+    public void Modify(ModifyKvCache modifier)
+    {
+        AssertNotDisposed();
+
+        if (RequiresInference)
+            throw new CannotModifyWhileRequiresInference();
+
+        // do whatever the modification is
+        _end = modifier.Invoke(_end, new KvAccessor(this));
+
+        // Set the epoch down to zero, this ensures that this conversation
+        // cannot be sampled until it is prompted again.
+        _requiredEpoch = 0;
+    }
+
+    /// <summary>
+    /// Provides direct access to the KV cache of a <see cref="Conversation"/>.
+    /// See <see cref="Modify"/> for how to use this.
+    /// </summary>
+    public readonly ref struct KvAccessor
+    {
+        private readonly Conversation _conversation;
+
+        internal KvAccessor(Conversation conversation)
+        {
+            _conversation = conversation;
+        }
+
+        #region remove
+        /// <summary>
+        /// Removes all tokens that have positions in [start, end)
+        /// </summary>
+        /// <param name="start">Start position (inclusive)</param>
+        /// <param name="end">End position (exclusive)</param>
+        public void Remove(LLamaPos start, LLamaPos end)
+        {
+            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
+        }
+
+        /// <summary>
+        /// Removes all tokens starting from the given position
+        /// </summary>
+        /// <param name="start">Start position (inclusive)</param>
+        /// <param name="count">Number of tokens</param>
+        public void Remove(LLamaPos start, int count)
+        {
+            if (count <= 0)
+                return;
+
+            var end = start.Value + count;
+            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
+        }
+        #endregion
+
+        #region shift
+        /// <summary>
+        /// Adds relative position "delta" to all tokens that have positions in [p0, p1).
+        /// If the KV cache is RoPEd, the KV data is updated accordingly
+        /// </summary>
+        /// <param name="start">Start position (inclusive)</param>
+        /// <param name="end">End position (exclusive)</param>
+        /// <param name="delta">Amount to add on to each token position</param>
+        public void Shift(LLamaPos start, LLamaPos end, int delta)
+        {
+            _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
+        }
+        #endregion
+    }
+
+    /// <summary>
+    /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
+    /// </summary>
+    /// <param name="end">The current end token of this conversation</param>
+    /// <param name="kv">An <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
+    /// <returns>The new end token position</returns>
+    public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
+    #endregion
 }
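A usage note (a sketch, not code from this commit; `conversation` and `nextTokens` are illustrative): the delegate receives the current end position plus a KvAccessor, and must return the new end position. Because Modify drops the required epoch to zero, the conversation has to be prompted again before it can be sampled.

    conversation.Modify((end, kv) =>
    {
        kv.Remove(start: 10, count: 10);  // delete the tokens at positions [10, 20)
        kv.Shift(20, end, -10);           // slide the tail left so positions stay contiguous
        return end.Value - 10;            // the new end position
    });

    conversation.Prompt(nextTokens);      // required before sampling again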

LLama/Batched/ConversationExtensions.cs (+59, -0)

@@ -0,0 +1,59 @@
+using System;
+
+namespace LLama.Batched;
+
+/// <summary>
+/// Extension method for <see cref="Conversation"/>
+/// </summary>
+public static class ConversationExtensions
+{
+    /// <summary>
+    /// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end
+    /// </summary>
+    /// <param name="conversation">The conversation to rewind</param>
+    /// <param name="tokens">The number of tokens to rewind</param>
+    /// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than TokenCount</exception>
+    public static void Rewind(this Conversation conversation, int tokens)
+    {
+        if (tokens > conversation.TokenCount)
+            throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens");
+
+        conversation.Modify((end, kv) =>
+        {
+            // Remove those tokens from KV
+            kv.Remove(end.Value - tokens, tokens);
+
+            // Return adjusted end position
+            return end.Value - tokens;
+        });
+    }
+
+    /// <summary>
+    /// Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over.
+    /// Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context
+    /// gets full, keeping the prompt at the start intact.
+    /// </summary>
+    /// <param name="conversation">The conversation to rewind</param>
+    /// <param name="count">How much to shift tokens over by</param>
+    /// <param name="keep">The number of tokens at the start which should <b>not</b> be shifted</param>
+    public static void ShiftLeft(this Conversation conversation, int count, int keep)
+    {
+        // Given a setup like this (shift=5, keep=3):
+        //
+        // AAABBBBBCCCCCCCCC...
+        //
+        // We want to remove all the B's, shift all the C's and leave all the A's untouched
+
+        conversation.Modify((end, kv) =>
+        {
+            // Remove the B's
+            kv.Remove(keep, count);
+
+            // Shift the C's
+            kv.Shift(keep + count, end, -count);
+
+            // Update total count
+            return end.Value - count;
+        });
+    }
+}
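To make the diagram in ShiftLeft concrete, a worked trace with illustrative numbers (assume the conversation's current end position is 17):

    // Positions 0..2 are A's (kept), 3..7 are B's (removed), 8..16 are C's (shifted)
    conversation.ShiftLeft(count: 5, keep: 3);
    // kv.Remove(3, 5)      removes positions [3, 8)        -> the B's are gone
    // kv.Shift(8, 17, -5)  shifts positions [8, 17) by -5  -> the C's now occupy [3, 12)
    // returns 17 - 5 = 12  -> the conversation's new end position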

LLama/Batched/Exceptions.cs (+5, -5)

@@ -57,7 +57,7 @@ public class CannotSampleRequiresPromptException
 }
 
 /// <summary>
-/// This exception is thrown when "Fork()" is called on a <see cref="Conversation"/> with <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
 public class CannotForkWhileRequiresInference
     : ExperimentalBatchedExecutorException
@@ -69,13 +69,13 @@ public class CannotForkWhileRequiresInference
 }
 
 /// <summary>
-/// This exception is thrown when "Rewind()" is called on a <see cref="Conversation"/> with <see cref="Conversation.RequiresInference"/> = true
+/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotRewindWhileRequiresInference
+public class CannotModifyWhileRequiresInference
     : ExperimentalBatchedExecutorException
 {
-    internal CannotRewindWhileRequiresInference()
-        : base("Cannot `Rewind()` a conversation while RequiresInference is true")
+    internal CannotModifyWhileRequiresInference()
+        : base("Cannot `Modify()` a conversation while RequiresInference is true")
     {
     }
 }

LLama/Native/NativeApi.cs (+1, -1)

@@ -388,7 +388,7 @@ namespace LLama.Native
         /// <param name="p1"></param>
         /// <param name="delta"></param>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);
+        public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);
 
         /// <summary>
         /// Integer division of the positions by factor of `d > 1`

LLama/Native/SafeLLamaContextHandle.cs (+1, -1)

@@ -369,7 +369,7 @@ namespace LLama.Native
         /// <param name="p0"></param>
         /// <param name="p1"></param>
         /// <param name="delta"></param>
-        public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
+        public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
         {
             NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
         }
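The two native-layer changes exist because `delta` is a signed offset rather than an absolute position, so `int` fits better than `LLamaPos`; shifts used to reclaim space are typically negative. For example, the ShiftLeft trace above bottoms out in a call like this (a sketch; `ctx` and `seq` are illustrative):

    // Move the tokens at positions [8, 17) five slots to the left
    ctx.KvCacheSequenceShift(seq, 8, 17, -5);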

