From 268f3a6b0775049cd38c4da858ae6ca566a7ce1c Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Wed, 20 Mar 2024 16:36:01 +0000
Subject: [PATCH] BatchedExecutor Fixed Forking (#621)

* Previously, when a conversation was forked, both the parent and the child shared exactly the same logits. Since sampling is allowed to modify logits this could lead to issues in sampling (e.g. one conversation is sampled and overwrites the logits to be all zero, then the second conversation is sampled and generates nonsense). Fixed this by setting a "forked" flag; logits are copied if this flag is set. The flag is cleared the next time the conversation is prompted, so this extra copying only happens once after a fork occurs.

* Removed the finalizer from `BatchedExecutor`. This class does not directly own any unmanaged resources, so a finalizer is not necessary.
---
 LLama/Batched/BatchedExecutor.cs | 12 +----
 LLama/Batched/Conversation.cs    | 90 +++++++++++++++++++++++++-------
 2 files changed, 72 insertions(+), 30 deletions(-)

diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index b15d445e..5d637c8e 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -55,14 +55,6 @@ public sealed class BatchedExecutor
         Epoch = 1;
     }
 
-    /// <summary>
-    /// Finalizer for BatchedExecutor
-    /// </summary>
-    ~BatchedExecutor()
-    {
-        Dispose();
-    }
-
     /// <summary>
     /// Start a new <see cref="Conversation"/> with the given prompt
     /// </summary>
@@ -89,7 +81,7 @@ public sealed class BatchedExecutor
         if (IsDisposed)
             throw new ObjectDisposedException(nameof(BatchedExecutor));
 
-        return new Conversation(this, GetNextSequenceId(), 0);
+        return new Conversation(this, GetNextSequenceId());
     }
 
     /// 
@@ -123,8 +115,6 @@ public sealed class BatchedExecutor
             return;
         IsDisposed = true;
 
-        GC.SuppressFinalize(this);
-
         Context.Dispose();
     }
 
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index 29759cf9..cfdb0a1f 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 
 namespace LLama.Batched;
@@ -14,6 +16,7 @@ public sealed class Conversation
     private LLamaPos _end;
     private int _batchIndex;
     private bool _disposed;
+    private bool _forked;
 
     /// <summary>
     /// The executor which this conversation belongs to
     /// </summary>
@@ -46,12 +49,10 @@ public sealed class Conversation
     public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
 
     #region construction/destruction
-    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
+    internal Conversation(BatchedExecutor batch, LLamaSeqId id)
     {
         ConversationId = id;
         Executor = batch;
-
-        _end = end;
     }
 
     /// 
@@ -98,16 +99,24 @@
     {
         AssertNotDisposed();
 
-        if (RequiresInference)
-            throw new CannotForkWhileRequiresInferenceException();
-
         // Create a new conversation which references the current position in this one
-        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
+        var c = new Conversation(Executor, Executor.GetNextSequenceId())
         {
-            _batchIndex = _batchIndex,
+            // Because these values are copied to the forked conversation it means that it will share the exact same output
+            // logits next time sampling is done. This is a problem, because the sampling process is allowed to modify those
+            // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
+            // they both copy the logits before the next sampling run, to fix this issue.
             _requiredEpoch = _requiredEpoch,
+            _batchIndex = _batchIndex,
+            _forked = true,
+
+            _end = _end,
         };
 
+        // Setting this flag means that logits will be copied next time sampling is called, ensuring that the forked
+        // conversation doesn't share logits with this one.
+        _forked = true;
+
         // Assign tokens to the new sequence
         NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
 
@@ -131,7 +140,14 @@ public sealed class Conversation
         if (_requiredEpoch > Executor.Epoch)
             throw new CannotSampleRequiresInferenceException();
 
-        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+
+        // If necessary copy the span, to protect it from modification. This is only done when
+        // this conversation has been forked in this epoch.
+        if (_forked)
+            span = span.ToArray();
+
+        return span;
     }
 
     #endregion
@@ -162,20 +178,56 @@ public sealed class Conversation
     /// 
     /// 
     /// 
-    public void Prompt(IReadOnlyList<LLamaToken> tokens)
+    /// 
+    public void Prompt(List<LLamaToken> tokens)
+    {
+        AssertCanBePrompted();
+
+#if NET6_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(tokens);
+        Prompt(span);
+#else
+        // Borrow an array and copy tokens into it
+        var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
+        try
+        {
+            for (var i = 0; i < tokens.Count; i++)
+                arr[i] = tokens[i];
+
+            Prompt(arr.AsSpan(0, tokens.Count));
+        }
+        finally
+        {
+            ArrayPool<LLamaToken>.Shared.Return(arr);
+        }
+#endif
+    }
+
+    /// <summary>
+    /// Add tokens to this conversation
+    /// </summary>
+    /// 
+    /// 
+    /// 
+    /// 
+    public void Prompt(ReadOnlySpan<LLamaToken> tokens)
     {
         AssertCanBePrompted();
 
         // No point doing anything if there is no actual prompt!
-        if (tokens.Count == 0)
+        if (tokens.Length == 0)
             return;
 
         // Add the prompt to the batch
-        for (var i = 0; i < tokens.Count; i++)
-            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
+        for (var i = 0; i < tokens.Length; i++)
+            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
 
         // Mark this conversation as needing inference/sampling
         _requiredEpoch = Executor.Epoch + 1;
+
+        // Unset the forked flag. Since this conversation has just been prompted it's no longer
+        // sharing anything with any other conversations.
+        _forked = false;
     }
 
     /// 
@@ -184,16 +236,16 @@ public sealed class Conversation
     /// 
     /// 
     /// 
-    /// 
+    /// 
     public void Prompt(LLamaToken token)
    {
         AssertCanBePrompted();
 
-        // Add this token as input
-        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
-
-        // Mark this conversation as needing inference/sampling
-        _requiredEpoch = Executor.Epoch + 1;
+        unsafe
+        {
+            Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
+            Prompt(span);
+        }
     }
 
     #endregion
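For context, a rough usage sketch of the scenario this patch addresses. It assumes the wider `BatchedExecutor` API surface: `Create()`, `Infer()` and `Conversation.Sample()` are taken from the surrounding codebase and are not shown in full in this diff, so treat the exact names and signatures as assumptions rather than part of the change itself.

```csharp
// A rough sketch (not part of the patch) of fork-then-sample with the fixed behaviour.
// Create(), Prompt(), Fork(), Infer() and Sample() are assumed from the BatchedExecutor API.
using System;
using System.Threading.Tasks;
using LLama.Batched;
using LLama.Native;

public static class ForkSamplingSketch
{
    public static async Task RunAsync(BatchedExecutor executor, LLamaToken[] prompt)
    {
        // Start a conversation and feed it the prompt
        var main = executor.Create();
        main.Prompt(prompt.AsSpan());

        // Run a batched inference step so the conversation has logits ready to sample
        await executor.Infer();

        // Fork; this sets the "forked" flag on both conversations
        var fork = main.Fork();

        SampleBoth(main, fork);
    }

    private static void SampleBoth(Conversation main, Conversation fork)
    {
        // Because of the "forked" flag, each call returns an independent copy of the logits
        Span<float> a = main.Sample();
        Span<float> b = fork.Sample();

        // Mutating one copy (as samplers are allowed to do) no longer affects the other
        a.Clear();
        // 'b' still holds the original logits for the forked conversation
    }
}
```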