using System;
using System.Collections.Generic;
using LLama.Native;
namespace LLama.Batched;
/// <summary>
/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
/// </summary>
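/// <example>
/// A minimal sketch of the prompt/infer/sample loop. The <c>Create()</c> factory
/// on the executor is assumed here for illustration (conversations cannot be
/// constructed directly); <c>Infer()</c> is the executor method referenced by
/// <see cref="Sample"/> below:
/// <code>
/// using var conversation = executor.Create();   // Create() assumed for illustration
/// conversation.Prompt("Once upon a time");      // queue tokens for the next batch
/// await executor.Infer();                       // run the batched forward pass
/// var logits = conversation.Sample();           // logits for the next token
/// </code>
/// </example>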
public sealed class Conversation
: IDisposable
{
private ulong _requiredEpoch;
private LLamaPos _end;
private int _batchIndex;
private bool _disposed;
/// <summary>
/// The executor which this conversation belongs to
/// </summary>
public BatchedExecutor Executor { get; }
/// <summary>
/// Unique ID for this conversation
/// </summary>
public LLamaSeqId ConversationId { get; }
/// <summary>
/// Total number of tokens in this conversation, cannot exceed the context length.
/// </summary>
public int TokenCount => _end.Value;
/// <summary>
/// Indicates if this conversation has been disposed; nothing can be done with a disposed conversation
/// </summary>
public bool IsDisposed => _disposed || Executor.IsDisposed;
/// <summary>
/// Indicates if this conversation is waiting for inference to be run on the executor.
/// <see cref="Prompt(string)"/> and <see cref="Sample"/> cannot be called when this is true.
/// </summary>
public bool RequiresInference => _requiredEpoch > Executor.Epoch;
/// <summary>
/// Indicates that this conversation should be sampled.
/// </summary>
public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
#region construction/destruction
internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
{
ConversationId = id;
Executor = batch;
_end = end;
}
/// <summary>
/// Finalizer for Conversation
/// </summary>
~Conversation()
{
Dispose();
}
/// <summary>
/// End this conversation, freeing all resources used by it
/// </summary>
public void Dispose()
{
if (IsDisposed)
return;
_disposed = true;
// Remove this conversation from the KV cache
Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);
// Prevent finalizer from running
GC.SuppressFinalize(this);
}
private void AssertNotDisposed()
{
if (Executor.IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));
if (IsDisposed)
throw new ObjectDisposedException(nameof(Conversation));
}
#endregion
/// <summary>
/// Create a copy of the current conversation
/// </summary>
/// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
/// <returns>A new <see cref="Conversation"/> continuing from the current position of this one</returns>
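/// <example>
/// A sketch of forking to explore two continuations from a shared prefix
/// (illustrative usage; <c>executor</c> is assumed in scope):
/// <code>
/// using var optimistic = conversation.Fork();
/// using var pessimistic = conversation.Fork();
/// optimistic.Prompt(" and they lived happily");
/// pessimistic.Prompt(" but then disaster struck");
/// await executor.Infer(); // both forks are evaluated in one batch
/// </code>
/// </example>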
public Conversation Fork()
{
AssertNotDisposed();
if (RequiresInference)
throw new CannotForkWhileRequiresInferenceException();
// Create a new conversation which references the current position in this one
var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
{
_batchIndex = _batchIndex,
_requiredEpoch = _requiredEpoch,
};
// Assign tokens to the new sequence
NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
return c;
}
#region sample
/// <summary>
/// Get the logits from this conversation, ready for sampling
/// </summary>
/// <returns>A span of logits for the next token</returns>
/// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
/// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
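/// <example>
/// A minimal greedy-decoding sketch over the returned span; the cast from the
/// argmax index back to a <see cref="LLamaToken"/> is assumed to exist
/// (illustrative only, not a complete sampling pipeline):
/// <code>
/// var logits = conversation.Sample();
/// var best = 0;
/// for (var i = 1; i &lt; logits.Length; i++)
///     if (logits[i] > logits[best])
///         best = i;
/// conversation.Prompt((LLamaToken)best); // feed the chosen token back in (cast assumed)
/// </code>
/// </example>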
public ReadOnlySpan<float> Sample()
{
AssertNotDisposed();
if (_requiredEpoch < Executor.Epoch)
throw new CannotSampleRequiresPromptException();
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();
return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
}
#endregion
#region prompt
private void AssertCanBePrompted()
{
AssertNotDisposed();
if (RequiresInference)
throw new AlreadyPromptedConversationException();
}
/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="input">The text to tokenize and add to this conversation</param>
/// <exception cref="AlreadyPromptedConversationException">Thrown if this conversation is already waiting for inference</exception>
public void Prompt(string input)
{
AssertCanBePrompted();
Prompt(Executor.Context.Tokenize(input));
}
/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens">The tokens to add</param>
/// <exception cref="ObjectDisposedException">Thrown if this conversation has been disposed</exception>
/// <exception cref="AlreadyPromptedConversationException">Thrown if this conversation is already waiting for inference</exception>
public void Prompt(IReadOnlyList<LLamaToken> tokens)
{
AssertCanBePrompted();
// Add the prompt to the batch
for (var i = 0; i < tokens.Count; i++)
_batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
}
/// <summary>
/// Add a single token to this conversation
/// </summary>
/// <param name="token">The token to add</param>
/// <exception cref="ObjectDisposedException">Thrown if this conversation has been disposed</exception>
/// <exception cref="AlreadyPromptedConversationException">Thrown if this conversation is already waiting for inference</exception>
public void Prompt(LLamaToken token)
{
AssertCanBePrompted();
// Add this token as input
_batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
}
#endregion
#region modify
/// <summary>
/// Directly modify the KV cache of this conversation
/// </summary>
/// <param name="modifier">A function which is given access to the KV cache to modify it</param>
/// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="RequiresInference"/> == true</exception>
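/// <example>
/// A sketch of rewinding the conversation by one token; it relies on the implicit
/// int-to-LLamaPos conversion used elsewhere in this file (illustrative only).
/// Note that after any modification the conversation must be prompted again
/// before it can be sampled:
/// <code>
/// conversation.Modify((end, kv) =>
/// {
///     kv.Remove(end.Value - 1, 1); // drop the last token from the KV cache
///     return end.Value - 1;        // report the new end position
/// });
/// </code>
/// </example>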
public void Modify(ModifyKvCache modifier)
{
AssertNotDisposed();
if (RequiresInference)
throw new CannotModifyWhileRequiresInferenceException();
// do whatever the modification is
_end = modifier.Invoke(_end, new KvAccessor(this));
// Set the epoch down to zero, this ensures that this conversation
// cannot be sampled until it is prompted again.
_requiredEpoch = 0;
}
/// <summary>
/// Provides direct access to the KV cache of a <see cref="Conversation"/>.
/// See <see cref="Modify"/> for how to use this.
/// </summary>
public readonly ref struct KvAccessor
{
private readonly Conversation _conversation;
internal KvAccessor(Conversation conversation)
{
_conversation = conversation;
}
#region remove
/// <summary>
/// Removes all tokens that have positions in [start, end)
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
public void Remove(LLamaPos start, LLamaPos end)
{
_conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
}
/// <summary>
/// Removes all tokens starting from the given position
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="count">Number of tokens</param>
public void Remove(LLamaPos start, int count)
{
if (count <= 0)
return;
var end = start.Value + count;
_conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
}
#endregion
#region shift
/// <summary>
/// Adds relative position "delta" to all tokens that have positions in [start, end).
/// If the KV cache is RoPEd, the KV data is updated accordingly.
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
/// <param name="delta">Amount to add on to each token position</param>
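/// <example>
/// A sketch of a "sliding window" inside a <see cref="Conversation.Modify"/>
/// callback: discard the oldest 32 tokens, then shift the survivors down so
/// positions stay contiguous (sizes are illustrative):
/// <code>
/// conversation.Modify((end, kv) =>
/// {
///     kv.Remove(0, 32);       // evict the oldest 32 tokens
///     kv.Shift(32, end, -32); // move positions [32, end) down by 32
///     return end.Value - 32;  // the conversation is now 32 tokens shorter
/// });
/// </code>
/// </example>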
public void Shift(LLamaPos start, LLamaPos end, int delta)
{
_conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
}
#endregion
#region divide
/// <summary>
/// Integer division of the positions by a factor `divisor > 1`.
/// If the KV cache is RoPEd, the KV data is updated accordingly.
/// </summary>
/// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
/// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
/// <param name="divisor">Amount to divide each position by.</param>
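/// <example>
/// A sketch in the style of "self-extend" position compression, inside a
/// <see cref="Conversation.Modify"/> callback (the grouping factor of 2 is
/// illustrative, not a recommendation):
/// <code>
/// conversation.Modify((end, kv) =>
/// {
///     kv.Divide(0, end, 2);        // halve every position: 0,1,2,3... becomes 0,0,1,1...
///     return (end.Value + 1) / 2;  // new end is the old end halved, rounded up
/// });
/// </code>
/// </example>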
public void Divide(LLamaPos start, LLamaPos end, int divisor)
{
if (divisor <= 0)
throw new ArgumentOutOfRangeException(nameof(divisor));
_conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
}
#endregion
}
/// <summary>
/// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
/// </summary>
/// <param name="end">The current end token of this conversation</param>
/// <param name="kv">An <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
/// <returns>The new end token position</returns>
public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
#endregion
}