@@ -1,5 +1,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using LLama.Native;
namespace LLama.Batched;
@@ -14,6 +16,7 @@ public sealed class Conversation
private LLamaPos _end;
private int _batchIndex;
private bool _disposed;
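// Set when this conversation has been forked and may be sharing logits with another sequence; cleared the next time it is prompted.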
private bool _forked;
/// <summary>
/// The executor which this conversation belongs to
@@ -46,12 +49,10 @@ public sealed class Conversation
public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
#region construction/destruction
internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
internal Conversation(BatchedExecutor batch, LLamaSeqId id)
{
ConversationId = id;
Executor = batch;
_end = end;
}
/// <summary>
@@ -98,16 +99,24 @@ public sealed class Conversation
{
AssertNotDisposed();
if (RequiresInference)
throw new CannotForkWhileRequiresInferenceException();
// Create a new conversation which references the current position in this one
var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
var c = new Conversation(Executor, Executor.GetNextSequenceId())
{
_batchIndex = _batchIndex,
// Because these values are copied to the forked conversation, it will share the exact same output
// logits next time sampling is done. This is a problem, because the sampling process is allowed to modify those
// logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
// they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
_batchIndex = _batchIndex,
_forked = true,
_end = _end,
};
// Setting this flag means that logits will be copied next time sampling is called, ensuring that the forked
// conversation doesn't share logits with this one.
_forked = true;
// Assign tokens to the new sequence
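// Note: llama_kv_cache_seq_cp does not copy any KV cache memory, it simply assigns the tokens
// already cached for this sequence (positions 0.._end) to the new sequence as well.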
NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
@@ -131,7 +140,14 @@ public sealed class Conversation
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();
return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
// If necessary, copy the span to protect it from modification. This is only done when
// this conversation has been forked in this epoch.
if (_forked)
span = span.ToArray();
return span;
}
#endregion
@@ -162,20 +178,56 @@ public sealed class Conversation
/// <param name="tokens"></param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
public void Prompt(IReadOnlyList<LLamaToken> tokens)
/// <exception cref="AlreadyPromptedConversationException"></exception>
public void Prompt(List<LLamaToken> tokens)
{
AssertCanBePrompted();
#if NET6_0_OR_GREATER
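// CollectionsMarshal.AsSpan exposes the list's backing array directly, avoiding a copy.
// The list must not be modified while the span is in use.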
var span = CollectionsMarshal.AsSpan(tokens);
Prompt(span);
#else
// Borrow an array and copy tokens into it
var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
for (var i = 0; i < tokens.Count; i++)
arr[i] = tokens[i];
Prompt(arr.AsSpan());
}
finally
{
ArrayPool<LLamaToken>.Shared.Return(arr);
}
#endif
}
/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="AlreadyPromptedConversationException"></exception>
public void Prompt(ReadOnlySpan<LLamaToken> tokens)
{
AssertCanBePrompted();
// No point doing anything if there is no actual prompt!
if (tokens.Count == 0)
if (tokens.Length == 0)
return;
// Add the prompt to the batch
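// Only the final token is marked as requiring logits, since that is the position sampling will read from.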
for (var i = 0; i < tokens.Count; i++)
_batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
for (var i = 0; i < tokens.Length; i++)
_batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
// Unset the forked flag. Since this conversation has just been prompted, it's no longer
// sharing logits with any other conversation.
_forked = false;
}
/// <summary>
@@ -184,16 +236,16 @@ public sealed class Conversation
/// <param name="token"></param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="InvalidOper ationException"></exception>
/// <exception cref="AlreadyPromptedConvers ationException"></exception>
public void Prompt(LLamaToken token)
{
AssertCanBePrompted();
// Add this token as input
_batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
unsafe
{
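// Stack allocate a single element span to hold the token, avoiding a heap allocation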
Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
Prompt(span);
}
}
#endregion