using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;
namespace LLama.Batched;
/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
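/// <example>
/// A minimal usage sketch. It assumes <c>model</c> and <c>contextParams</c> have already been
/// created; sampling from each <see cref="Conversation"/> is handled by that class and not shown here.
/// <code>
/// using var executor = new BatchedExecutor(model, contextParams);
///
/// // Start two independent conversations; their tokens share one batch.
/// var left = executor.Prompt("Once upon a time");
/// var right = executor.Prompt("The capital of France is");
///
/// // Run one forward pass over everything queued in the batch.
/// var result = await executor.Infer();
/// </code>
/// </example>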
public sealed class BatchedExecutor
: IDisposable
{
private int _nextSequenceId;
internal LLamaBatch Batch { get; }
/// <summary>
/// Epoch is incremented every time <see cref="Infer"/> is called. Conversations can use this to keep track of
/// whether they're waiting for inference, or can be sampled.
/// </summary>
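/// <remarks>
/// A sketch of the intended bookkeeping; <c>_requiredEpoch</c> is an illustrative field name,
/// not a member of this class:
/// <code>
/// // When a conversation queues tokens it remembers the epoch it must wait for:
/// _requiredEpoch = executor.Epoch + 1;
///
/// // Once Infer() succeeds the epoch advances, so the conversation becomes sampleable:
/// bool canSample = executor.Epoch >= _requiredEpoch;
/// </code>
/// </remarks>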
internal ulong Epoch { get; private set; }
/// <summary>
/// The <see cref="LLamaContext"/> this executor is using
/// </summary>
public LLamaContext Context { get; }
/// <summary>
/// The <see cref="LLamaWeights"/> this executor is using
/// </summary>
public LLamaWeights Model { get; }
/// <summary>
/// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchedTokenCount => Batch.TokenCount;
/// <summary>
/// Check if this executor has been disposed.
/// </summary>
public bool IsDisposed { get; private set; }
/// <summary>
/// Create a new batched executor
/// </summary>
/// <param name="model">The model to use</param>
/// <param name="contextParams">Parameters to create a new context</param>
public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
{
Model = model;
Batch = new LLamaBatch();
Context = model.CreateContext(contextParams);
Epoch = 1;
}
/// <summary>
/// Finalizer for BatchedExecutor
/// </summary>
~BatchedExecutor()
{
Dispose();
}
/// <summary>
/// Start a new <see cref="Conversation"/> with the given prompt
/// </summary>
/// <param name="prompt">The prompt to evaluate at the start of the new conversation</param>
/// <returns>The new conversation, with its prompt queued for the next call to <see cref="Infer"/></returns>
public Conversation Prompt(string prompt)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));
var conversation = new Conversation(this, GetNextSequenceId(), 0);
conversation.Prompt(prompt);
return conversation;
}
/// <summary>
/// Run inference for all conversations in the batch which have pending tokens.
///
/// If the result is <see cref="DecodeResult.NoKvSlot"/> then there is not enough memory for inference;
/// try disposing some conversations and running inference again.
/// </summary>
/// <returns>The result of the decode operation</returns>
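/// <example>
/// A hedged sketch of recovering from <see cref="DecodeResult.NoKvSlot"/>. Which conversation to
/// dispose is an application decision; <c>oldest</c> is an illustrative name:
/// <code>
/// var status = await executor.Infer();
/// if (status == DecodeResult.NoKvSlot)
/// {
///     // Free KV cache space, then retry. The batch is not cleared on failure,
///     // so the queued tokens are still in place.
///     oldest.Dispose();
///     status = await executor.Infer();
/// }
/// </code>
/// </example>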
public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));
var status = await Context.DecodeAsync(Batch, cancellation);
// Only clear the batch if the result was Ok. Leaving all this state in place means that "Infer" can
// be called again after a recoverable failure (e.g. NoKvSlot).
if (status == DecodeResult.Ok)
{
Epoch++;
Batch.Clear();
}
return status;
}
/// <inheritdoc />
public void Dispose()
{
if (IsDisposed)
return;
IsDisposed = true;
GC.SuppressFinalize(this);
Context.Dispose();
}
internal LLamaSeqId GetNextSequenceId()
{
// Checked conversion: throws OverflowException instead of silently wrapping if sequence ids are ever exhausted
return checked((LLamaSeqId)_nextSequenceId++);
}
}