diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index b15d445e..5d637c8e 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -55,14 +55,6 @@ public sealed class BatchedExecutor
Epoch = 1;
}
- /// &lt;summary&gt;
- /// Finalizer for BatchedExecutor
- /// &lt;/summary&gt;
- ~BatchedExecutor()
- {
- Dispose();
- }
-
/// &lt;summary&gt;
/// Start a new &lt;see cref="Conversation"/&gt; with the given prompt
/// &lt;/summary&gt;
@@ -89,7 +81,7 @@ public sealed class BatchedExecutor
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));
- return new Conversation(this, GetNextSequenceId(), 0);
+ return new Conversation(this, GetNextSequenceId());
}
/// &lt;summary&gt;
@@ -123,8 +115,6 @@ public sealed class BatchedExecutor
return;
IsDisposed = true;
- GC.SuppressFinalize(this);
-
Context.Dispose();
}
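
Note: dropping the finalizer makes disposal purely deterministic; the context is only freed when `Dispose()` runs, so callers should wrap the executor in `using`. A minimal sketch of the expected call pattern, assuming the usual `LLamaWeights`/`ModelParams` entry points from the wider library (they are not part of this diff):

```csharp
// Hedged sketch: BatchedExecutor no longer has a finalizer, so the context is only
// released when Dispose() is called. Model-loading calls here are assumed, not shown in this diff.
using LLama;
using LLama.Batched;
using LLama.Common;

var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);

using var executor = new BatchedExecutor(model, parameters);
// ... create conversations, prompt, infer ...
// Context.Dispose() runs exactly once, from executor.Dispose(), when the using scope ends.
```
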
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index 29759cf9..cfdb0a1f 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -1,5 +1,7 @@
using System;
+using System.Buffers;
using System.Collections.Generic;
+using System.Runtime.InteropServices;
using LLama.Native;
namespace LLama.Batched;
@@ -14,6 +16,7 @@ public sealed class Conversation
private LLamaPos _end;
private int _batchIndex;
private bool _disposed;
+ private bool _forked;
/// &lt;summary&gt;
/// The executor which this conversation belongs to
@@ -46,12 +49,10 @@ public sealed class Conversation
public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
#region construction/destruction
- internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
+ internal Conversation(BatchedExecutor batch, LLamaSeqId id)
{
ConversationId = id;
Executor = batch;
-
- _end = end;
}
/// &lt;summary&gt;
@@ -98,16 +99,24 @@ public sealed class Conversation
{
AssertNotDisposed();
- if (RequiresInference)
- throw new CannotForkWhileRequiresInferenceException();
-
// Create a new conversation which references the current position in this one
- var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
+ var c = new Conversation(Executor, Executor.GetNextSequenceId())
{
- _batchIndex = _batchIndex,
+ // Because these values are copied to the forked conversation it means that it will share the exact same output
+ // logits next time sampling is done. This is a problem, because the sampling process is allowed to modify those
+ // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
+ // they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
+ _batchIndex = _batchIndex,
+ _forked = true,
+
+ _end = _end,
};
+ // Setting this flag means that logits will be copied next time sampling is called, ensuring that the forked
+ // conversation doesn't share logits with this one.
+ _forked = true;
+
// Assign tokens to the new sequence
NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
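
Note: together, the removed guard and the `_forked` bookkeeping mean `Fork` can be called at any point, including when both branches are about to sample from the same batch index. A rough usage sketch; `Create()` and `Infer()` are assumed from the surrounding `BatchedExecutor` API and `ChooseToken` is a hypothetical sampling helper:

```csharp
// Hedged sketch. Create() and Infer() are assumed BatchedExecutor members (not in this hunk);
// ChooseToken is a hypothetical helper that picks a token id from a span of logits.
var promptTokens = new List<LLamaToken>();   // filled by tokenizing the prompt (omitted)

var left = executor.Create();
left.Prompt(promptTokens);
await executor.Infer();

// Fork() no longer throws while inference is pending. Here we fork after Infer(), where both
// conversations would otherwise alias the same logits; _forked is set on both, so each one
// copies its logits inside its next Sample() call.
var right = left.Fork();

// The branches can now diverge independently; Prompt() clears _forked on the prompted branch.
left.Prompt(ChooseToken(left.Sample()));
right.Prompt(ChooseToken(right.Sample()));
await executor.Infer();
```
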
@@ -131,7 +140,14 @@ public sealed class Conversation
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();
- return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+ var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+
+ // If necessary copy the span, to protect it from modification. This is only done when
+ // this conversation has been forked in this epoch.
+ if (_forked)
+ span = span.ToArray();
+
+ return span;
}
#endregion
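
Note: the copy in `Sample` exists because the spans returned by `GetLogitsIth` all point into the context's single logits buffer, and (as the comment in `Fork` says) sampling is allowed to mutate those values in place. A small standalone illustration of the aliasing, with a plain array standing in for the native buffer:

```csharp
// Standalone illustration of the aliasing problem; float[] stands in for the native logit buffer.
using System;

float[] sharedLogits = { 0.1f, 0.9f, 0.0f };

Span<float> original = sharedLogits;     // this conversation's view
Span<float> fork = sharedLogits;         // forked conversation's view: same memory

original[1] = float.NegativeInfinity;    // sampling mutates logits (e.g. banning a token)
Console.WriteLine(fork[1]);              // -Infinity: the fork sees the mutation

Span<float> copied = sharedLogits.ToArray();  // what Sample() does when _forked is set
original[0] = float.NegativeInfinity;
Console.WriteLine(copied[0]);            // 0.1: the copy is unaffected
```
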
@@ -162,20 +178,56 @@ public sealed class Conversation
///
///
///
- public void Prompt(IReadOnlyList&lt;LLamaToken&gt; tokens)
+ ///
+ public void Prompt(List&lt;LLamaToken&gt; tokens)
+ {
+ AssertCanBePrompted();
+
+#if NET6_0_OR_GREATER
+ var span = CollectionsMarshal.AsSpan(tokens);
+ Prompt(span);
+#else
+ // Borrow an array and copy tokens into it
+ var arr = ArrayPool&lt;LLamaToken&gt;.Shared.Rent(tokens.Count);
+ try
+ {
+ for (var i = 0; i < tokens.Count; i++)
+ arr[i] = tokens[i];
+
+ Prompt(arr.AsSpan(0, tokens.Count));
+ }
+ finally
+ {
+ ArrayPool&lt;LLamaToken&gt;.Shared.Return(arr);
+ }
+#endif
+ }
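
Note: the `#if` split above is a general "span or rent" pattern: on .NET 6+ `CollectionsMarshal.AsSpan` exposes the `List<T>`'s backing array with no copy, while older targets rent a scratch buffer and must slice it to the requested length (rented arrays are only guaranteed to be *at least* that long, which is also why the call above slices to `tokens.Count`). A generic sketch of the same pattern, detached from this class:

```csharp
using System;
using System.Buffers;
using System.Collections.Generic;
#if NET6_0_OR_GREATER
using System.Runtime.InteropServices;
#endif

internal static class ListSpanHelper
{
    public delegate void SpanHandler<T>(ReadOnlySpan<T> span);

    // Run 'handler' over the contents of 'items' without allocating a fresh array per call.
    public static void WithSpan<T>(List<T> items, SpanHandler<T> handler)
    {
#if NET6_0_OR_GREATER
        // Zero-copy view of the list's backing array. Do not add/remove items while the span is live.
        handler(CollectionsMarshal.AsSpan(items));
#else
        // Rent a buffer, copy, and always return it. Slice to Count: the rented array may be longer.
        var arr = ArrayPool<T>.Shared.Rent(items.Count);
        try
        {
            items.CopyTo(arr);
            handler(arr.AsSpan(0, items.Count));
        }
        finally
        {
            ArrayPool<T>.Shared.Return(arr);
        }
#endif
    }
}
```
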
+
+ /// &lt;summary&gt;
+ /// Add tokens to this conversation
+ /// &lt;/summary&gt;
+ ///
+ ///
+ ///
+ ///
+ public void Prompt(ReadOnlySpan&lt;LLamaToken&gt; tokens)
{
AssertCanBePrompted();
// No point doing anything if there is no actual prompt!
- if (tokens.Count == 0)
+ if (tokens.Length == 0)
return;
// Add the prompt to the batch
- for (var i = 0; i < tokens.Count; i++)
- _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
+ for (var i = 0; i < tokens.Length; i++)
+ _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
+
+ // Unset the forked flag. Since this conversation has just been prompted it's no longer
+ // sharing anything with any other conversations.
+ _forked = false;
}
/// &lt;summary&gt;
@@ -184,16 +236,16 @@ public sealed class Conversation
///
///
///
- ///
+ ///
public void Prompt(LLamaToken token)
{
AssertCanBePrompted();
- // Add this token as input
- _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
-
- // Mark this conversation as needing inference/sampling
- _requiredEpoch = Executor.Epoch + 1;
+ unsafe
+ {
+ Span&lt;LLamaToken&gt; span = stackalloc LLamaToken[1] { token };
+ Prompt(span);
+ }
}
#endregion
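
Note: funnelling the single-token overload through `Prompt(ReadOnlySpan<LLamaToken>)` keeps the epoch and `_forked` bookkeeping in one place. For orientation, a rough generation-loop sketch; `Infer()` is assumed from `BatchedExecutor` and `ChooseToken` is the same hypothetical helper as above:

```csharp
// Hedged sketch: Infer() is an assumed executor member, ChooseToken a hypothetical sampler.
for (var i = 0; i < 64; i++)
{
    await executor.Infer();                              // run the pending batch, advance Epoch
    var next = ChooseToken(conversation.Sample());       // Sample() copies logits if forked this epoch
    conversation.Prompt(next);                           // stackalloc'd single-token span internally
}
```
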