diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index b15d445e..5d637c8e 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -55,14 +55,6 @@ public sealed class BatchedExecutor
         Epoch = 1;
     }
 
-    /// <summary>
-    /// Finalizer for BatchedExecutor
-    /// </summary>
-    ~BatchedExecutor()
-    {
-        Dispose();
-    }
-
     /// <summary>
     /// Start a new <see cref="Conversation"/> with the given prompt
     /// </summary>
@@ -89,7 +81,7 @@ public sealed class BatchedExecutor
         if (IsDisposed)
             throw new ObjectDisposedException(nameof(BatchedExecutor));
 
-        return new Conversation(this, GetNextSequenceId(), 0);
+        return new Conversation(this, GetNextSequenceId());
     }
 
     /// <summary>
@@ -123,8 +115,6 @@ public sealed class BatchedExecutor
             return;
         IsDisposed = true;
 
-        GC.SuppressFinalize(this);
-
         Context.Dispose();
     }
 
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index 29759cf9..cfdb0a1f 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 
 namespace LLama.Batched;
@@ -14,6 +16,7 @@ public sealed class Conversation
     private LLamaPos _end;
     private int _batchIndex;
     private bool _disposed;
+    private bool _forked;
 
     /// <summary>
     /// The executor which this conversation belongs to
@@ -46,12 +49,10 @@ public sealed class Conversation
     public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
 
     #region construction/destruction
-    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
+    internal Conversation(BatchedExecutor batch, LLamaSeqId id)
     {
         ConversationId = id;
         Executor = batch;
-
-        _end = end;
     }
 
     /// <summary>
@@ -98,16 +99,24 @@ public sealed class Conversation
     {
         AssertNotDisposed();
 
-        if (RequiresInference)
-            throw new CannotForkWhileRequiresInferenceException();
-
         // Create a new conversation which references the current position in this one
-        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
+        var c = new Conversation(Executor, Executor.GetNextSequenceId())
         {
-            _batchIndex = _batchIndex,
+            // Because these values are copied to the forked conversation, it will share the exact same output
+            // logits the next time sampling is done. This is a problem because the sampling process is allowed to
+            // modify those logits, so sampling one conversation could corrupt the fork! Setting the "forked" flag on
+            // both sequences ensures they each copy the logits before the next sampling run, fixing this issue.
             _requiredEpoch = _requiredEpoch,
+            _batchIndex = _batchIndex,
+            _forked = true,
+
+            _end = _end,
         };
 
+        // Setting this flag means that the logits will be copied the next time sampling is called, ensuring that
+        // the forked conversation doesn't share logits with this one.
+        _forked = true;
+
         // Assign tokens to the new sequence
         NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
 
@@ -131,7 +140,14 @@ public sealed class Conversation
         if (_requiredEpoch > Executor.Epoch)
             throw new CannotSampleRequiresInferenceException();
 
-        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+
+        // If necessary, copy the span to protect it from modification. This is only done when
+        // this conversation has been forked in this epoch.
+        if (_forked)
+            span = span.ToArray();
+
+        return span;
     }
     #endregion
 
@@ -162,20 +178,56 @@ public sealed class Conversation
     ///
     ///
     ///
-    public void Prompt(IReadOnlyList<LLamaToken> tokens)
+    ///
+    public void Prompt(List<LLamaToken> tokens)
+    {
+        AssertCanBePrompted();
+
+#if NET6_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(tokens);
+        Prompt(span);
+#else
+        // Borrow an array and copy tokens into it
+        var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
+        try
+        {
+            for (var i = 0; i < tokens.Count; i++)
+                arr[i] = tokens[i];
+
+            // Slice to tokens.Count: the rented array may be longer than requested
+            Prompt(arr.AsSpan(0, tokens.Count));
+        }
+        finally
+        {
+            ArrayPool<LLamaToken>.Shared.Return(arr);
+        }
+#endif
+    }
+
+    /// <summary>
+    /// Add tokens to this conversation
+    /// </summary>
+    ///
+    ///
+    ///
+    ///
+    public void Prompt(ReadOnlySpan<LLamaToken> tokens)
     {
         AssertCanBePrompted();
 
         // No point doing anything if there is no actual prompt!
-        if (tokens.Count == 0)
+        if (tokens.Length == 0)
             return;
 
         // Add the prompt to the batch
-        for (var i = 0; i < tokens.Count; i++)
-            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
+        for (var i = 0; i < tokens.Length; i++)
+            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
 
         // Mark this conversation as needing inference/sampling
         _requiredEpoch = Executor.Epoch + 1;
+
+        // Unset the forked flag. Since this conversation has just been prompted, it's no longer
+        // sharing anything with any other conversations.
+        _forked = false;
     }
 
     /// <summary>
@@ -184,16 +236,16 @@ public sealed class Conversation
     ///
     ///
     ///
-    ///
+    ///
     public void Prompt(LLamaToken token)
    {
         AssertCanBePrompted();
 
-        // Add this token as input
-        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
-
-        // Mark this conversation as needing inference/sampling
-        _requiredEpoch = Executor.Epoch + 1;
+        unsafe
+        {
+            Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
+            Prompt(span);
+        }
     }
     #endregion
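
Taken together, these changes let a Conversation be forked even while it is waiting to be sampled: the `_forked` flag simply forces each sequence to take a private copy of the shared logits the next time `Sample()` is called. The sketch below is only an illustration of that flow, not part of the patch: `Create()`, `Infer()` and `Context.Tokenize()` are assumed from the wider BatchedExecutor/LLamaContext API rather than shown in this diff, and `MaxLogit` is a hypothetical helper.

```csharp
using System;
using System.Threading.Tasks;
using LLama.Batched;

internal static class ForkDemo
{
    // Sketch only: Create(), Infer() and Context.Tokenize() are assumed from the
    // surrounding API; only Fork/Prompt/Sample appear in this diff.
    public static async Task RunAsync(BatchedExecutor executor)
    {
        var main = executor.Create();
        main.Prompt(executor.Context.Tokenize("The quick brown fox"));

        // Run one round of inference for everything queued in the batch.
        await executor.Infer();

        // Forking no longer throws while sampling is pending; the _forked flag makes
        // both conversations copy the logits before their next Sample() call.
        var branch = main.Fork();

        // Each Sample() call now returns an independent copy, so reading (or mutating)
        // one conversation's logits cannot corrupt the other's.
        Console.WriteLine($"main max logit:   {MaxLogit(main)}");
        Console.WriteLine($"branch max logit: {MaxLogit(branch)}");

        // Hypothetical helper. Span<float> locals are not allowed in async methods,
        // so the logits are inspected inside a synchronous local function.
        static float MaxLogit(Conversation c)
        {
            var max = float.NegativeInfinity;
            foreach (var logit in c.Sample())
                max = Math.Max(max, logit);
            return max;
        }
    }
}
```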
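The new `Prompt(List<LLamaToken>)` overload also demonstrates a reusable multi-targeting pattern: on .NET 6+ the list's backing array is viewed in place via `CollectionsMarshal.AsSpan` (no copy), while older targets copy the elements into a rented buffer and return it afterwards. Below is the same pattern in isolation, using `int` instead of `LLamaToken` so it compiles without the LLama types; `SpanAction` is a made-up delegate for this example only.

```csharp
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;

internal static class SpanBridge
{
    // ReadOnlySpan<T> cannot be used as a generic type argument (e.g. in Action<T>),
    // so a dedicated delegate type is required.
    public delegate void SpanAction(ReadOnlySpan<int> span);

    // Bridges a List<int> to a ReadOnlySpan<int> without allocating where possible.
    public static void Process(List<int> items, SpanAction action)
    {
#if NET6_0_OR_GREATER
        // Zero-copy view over the list's internal array. The list must not be
        // resized while the span is in use.
        action(CollectionsMarshal.AsSpan(items));
#else
        var arr = ArrayPool<int>.Shared.Rent(items.Count);
        try
        {
            items.CopyTo(arr);

            // Slice to items.Count: Rent may hand back a larger array.
            action(arr.AsSpan(0, items.Count));
        }
        finally
        {
            ArrayPool<int>.Shared.Return(arr);
        }
#endif
    }
}
```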