using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
/// </summary>
public sealed class Conversation
    : IDisposable
{
    private ulong _requiredEpoch;
    private LLamaPos _end;
    private int _batchIndex;
    private bool _disposed;

    /// <summary>
    /// The executor which this conversation belongs to
    /// </summary>
    public BatchedExecutor Executor { get; }

    /// <summary>
    /// Unique ID for this conversation
    /// </summary>
    public LLamaSeqId ConversationId { get; }

    /// <summary>
    /// Total number of tokens in this conversation; cannot exceed the context length.
    /// </summary>
    public int TokenCount => _end.Value;

    /// <summary>
    /// Indicates if this conversation has been disposed. Nothing can be done with a disposed conversation.
    /// </summary>
    public bool IsDisposed => _disposed || Executor.IsDisposed;

    /// <summary>
    /// Indicates if this conversation is waiting for inference to be run on the executor.
    /// "Prompt" and "Sample" cannot be called when this is true.
    /// </summary>
    public bool RequiresInference => _requiredEpoch > Executor.Epoch;

    /// <summary>
    /// Indicates that this conversation should be sampled.
    /// </summary>
    public bool RequiresSampling => _requiredEpoch == Executor.Epoch;

    #region construction/destruction
    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
    {
        ConversationId = id;
        Executor = batch;

        _end = end;
    }

    /// <summary>
    /// Finalizer for Conversation
    /// </summary>
    ~Conversation()
    {
        Dispose();
    }

    /// <summary>
    /// End this conversation, freeing all resources used by it
    /// </summary>
    public void Dispose()
    {
        if (IsDisposed)
            return;
        _disposed = true;

        // Remove this conversation from the KV cache
        Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);

        // Prevent finalizer from running
        GC.SuppressFinalize(this);
    }

    private void AssertNotDisposed()
    {
        if (Executor.IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(Conversation));
    }
    #endregion
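    // ------------------------------------------------------------------
    // Usage sketch (illustrative only, not part of the class API).
    // The expected lifecycle is Prompt -> Infer -> Sample. `executor.Create()`
    // is an assumed factory method on BatchedExecutor (this constructor is
    // internal, so conversations must come from the executor), and `GreedyPick`
    // is a hypothetical helper returning the index of the largest logit.
    // `executor.Infer()` is the inference step referenced by the exception
    // documentation in this file.
    //
    //     using var conversation = executor.Create();
    //     conversation.Prompt("The quick brown fox");
    //
    //     for (var i = 0; i < 32; i++)
    //     {
    //         executor.Infer();                    // run the queued batch, advancing the epoch
    //
    //         var logits = conversation.Sample();  // ReadOnlySpan<float>, one value per vocab entry
    //         var token = GreedyPick(logits);      // hypothetical: pick the highest-scoring token
    //         conversation.Prompt(token);          // feed the sampled token back in and repeat
    //     }
    // ------------------------------------------------------------------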
    /// <summary>
    /// Create a copy of the current conversation
    /// </summary>
    /// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
    /// <returns>A new conversation which continues from the current position of this one</returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="CannotForkWhileRequiresInferenceException">Thrown if <see cref="RequiresInference"/> is true</exception>
    public Conversation Fork()
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new CannotForkWhileRequiresInferenceException();

        // Create a new conversation which references the current position in this one
        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
        {
            _batchIndex = _batchIndex,
            _requiredEpoch = _requiredEpoch,
        };

        // Assign tokens to the new sequence
        NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);

        return c;
    }

    #region sample
    /// <summary>
    /// Get the logits from this conversation, ready for sampling
    /// </summary>
    /// <returns>A span of logits, ready for sampling</returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
    /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
    public ReadOnlySpan<float> Sample()
    {
        AssertNotDisposed();

        if (_requiredEpoch < Executor.Epoch)
            throw new CannotSampleRequiresPromptException();
        if (_requiredEpoch > Executor.Epoch)
            throw new CannotSampleRequiresInferenceException();

        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
    }
    #endregion

    #region prompt
    private void AssertCanBePrompted()
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new AlreadyPromptedConversationException();
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="input">Text to tokenize and add to this conversation</param>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="AlreadyPromptedConversationException"></exception>
    public void Prompt(string input)
    {
        AssertCanBePrompted();

        Prompt(Executor.Context.Tokenize(input));
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="tokens">Tokens to add to this conversation</param>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="AlreadyPromptedConversationException"></exception>
    public void Prompt(IReadOnlyList<LLamaToken> tokens)
    {
        AssertCanBePrompted();

        // Add the prompt to the batch. Only the final token requests logits.
        for (var i = 0; i < tokens.Count; i++)
            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
    }

    /// <summary>
    /// Add a single token to this conversation
    /// </summary>
    /// <param name="token">Token to add to this conversation</param>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="AlreadyPromptedConversationException"></exception>
    public void Prompt(LLamaToken token)
    {
        AssertCanBePrompted();

        // Add this token as input
        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
    }
    #endregion

    #region modify
    /// <summary>
    /// Directly modify the KV cache of this conversation
    /// </summary>
    /// <param name="modifier">Function which modifies the KV cache and returns the new end position</param>
    /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="RequiresInference"/> == true</exception>
    public void Modify(ModifyKvCache modifier)
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new CannotModifyWhileRequiresInferenceException();

        // Do whatever the modification is
        _end = modifier.Invoke(_end, new KvAccessor(this));

        // Set the epoch down to zero. This ensures that this conversation
        // cannot be sampled until it is prompted again.
        _requiredEpoch = 0;
    }

    /// <summary>
    /// Provides direct access to the KV cache of a <see cref="Conversation"/>.
    /// See <see cref="Modify"/> for how to use this.
    /// </summary>
    public readonly ref struct KvAccessor
    {
        private readonly Conversation _conversation;

        internal KvAccessor(Conversation conversation)
        {
            _conversation = conversation;
        }

        #region remove
        /// <summary>
        /// Removes all tokens that have positions in [start, end)
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        public void Remove(LLamaPos start, LLamaPos end)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }

        /// <summary>
        /// Removes all tokens starting from the given position
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="count">Number of tokens to remove</param>
        public void Remove(LLamaPos start, int count)
        {
            if (count <= 0)
                return;

            var end = start.Value + count;
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }
        #endregion
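        // ------------------------------------------------------------------
        // Usage sketch (illustrative only): rewinding a conversation by removing
        // its most recent tokens through Modify. The modifier receives the
        // current end position and a KvAccessor, and must return the new end
        // position. After Modify the conversation cannot be sampled until it is
        // prompted again, since the required epoch is reset to zero. Implicit
        // int <-> LLamaPos conversion is assumed, as used elsewhere in this file.
        //
        //     conversation.Modify((end, kv) =>
        //     {
        //         kv.Remove(end.Value - 5, 5);   // drop the last 5 tokens from the KV cache
        //         return end.Value - 5;          // report the new end of the sequence
        //     });
        // ------------------------------------------------------------------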
        #region shift
        /// <summary>
        /// Adds relative position "delta" to all tokens that have positions in [start, end).
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        /// <param name="delta">Amount to add on to each token position</param>
        public void Shift(LLamaPos start, LLamaPos end, int delta)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
        }
        #endregion

        #region divide
        /// <summary>
        /// Integer division of the positions by a factor `divisor > 1`.
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        /// </summary>
        /// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
        /// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
        /// <param name="divisor">Amount to divide each position by</param>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="divisor"/> is not positive</exception>
        public void Divide(LLamaPos start, LLamaPos end, int divisor)
        {
            if (divisor <= 0)
                throw new ArgumentOutOfRangeException(nameof(divisor));

            _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
        }
        #endregion
    }

    /// <summary>
    /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
    /// </summary>
    /// <param name="end">The current end token position of this conversation</param>
    /// <param name="kv">A <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
    /// <returns>The new end token position</returns>
    public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
    #endregion
}
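// ------------------------------------------------------------------
// Usage sketch (illustrative only): a sliding-window style context shift built
// from the KvAccessor operations above. The oldest tokens are evicted and the
// remaining ones are shifted down to fill the gap, so generation can continue
// past the context limit. If the model uses RoPE, the native library re-rotates
// the shifted positions. The window size of 32 is arbitrary.
//
//     conversation.Modify((end, kv) =>
//     {
//         kv.Remove(0, 32);           // evict the 32 oldest tokens
//         kv.Shift(32, end, -32);     // slide the remaining tokens down by 32
//         return end.Value - 32;      // the sequence is now 32 tokens shorter
//     });
// ------------------------------------------------------------------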