using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
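/// <example>
/// A minimal sketch of the intended flow. The <c>Prompt</c> and <c>Sample</c> calls on
/// <see cref="Conversation"/> are assumed from the wider batched API and may differ in detail:
/// <code>
/// using var executor = new BatchedExecutor(weights, contextParams);
///
/// // Two independent conversations queue their prompt tokens into the same batch
/// using var a = executor.Create();
/// a.Prompt(promptTokensA);
/// using var b = executor.Create();
/// b.Prompt(promptTokensB);
///
/// // A single Infer call evaluates the pending tokens for both conversations at once
/// if (await executor.Infer() == DecodeResult.Ok)
/// {
///     var nextA = a.Sample();
///     var nextB = b.Sample();
/// }
/// </code>
/// </example>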
public sealed class BatchedExecutor
    : IDisposable
{
    private int _nextSequenceId;

    internal LLamaBatch Batch { get; }

    /// <summary>
    /// The epoch is incremented every time <see cref="Infer"/> is called. Conversations use it to
    /// track whether they are still waiting for inference or are ready to be sampled.
    /// </summary>
    internal ulong Epoch { get; private set; }

    /// <summary>
    /// The <see cref="LLamaContext"/> this executor is using
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// The <see cref="LLamaWeights"/> this executor is using
    /// </summary>
    public LLamaWeights Model { get; }

    /// <summary>
    /// The number of tokens currently queued in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
    public int BatchedTokenCount => Batch.TokenCount;

    /// <summary>
    /// Check if this executor has been disposed.
    /// </summary>
    public bool IsDisposed { get; private set; }

    /// <summary>
    /// Create a new batched executor
    /// </summary>
    /// <param name="model">The model to use</param>
    /// <param name="contextParams">Parameters to create a new context</param>
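    /// <example>
    /// A sketch of typical construction, assuming <see cref="ModelParams"/> is used for both the
    /// model and the context parameters (it implements <see cref="IContextParams"/>):
    /// <code>
    /// var parameters = new ModelParams("path/to/model.gguf");
    /// using var weights = LLamaWeights.LoadFromFile(parameters);
    /// using var executor = new BatchedExecutor(weights, parameters);
    /// </code>
    /// </example>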
    public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
    {
        Model = model;
        Batch = new LLamaBatch();
        Context = model.CreateContext(contextParams);
        Epoch = 1;
    }

    /// <summary>
    /// Start a new <see cref="Conversation"/>
    /// </summary>
    /// <returns>The new conversation</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this executor has been disposed</exception>
    public Conversation Create()
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        return new Conversation(this, GetNextSequenceId());
    }

    /// <summary>
    /// Load a conversation that was previously saved to a file. Once loaded, the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="filepath">The file to load the conversation state from</param>
    /// <returns>The loaded conversation</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this executor has been disposed</exception>
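    /// <example>
    /// A sketch of a save/load round trip, assuming a matching <c>Conversation.Save(string)</c>
    /// method on the conversation side (<c>Prompt</c> as in the class-level example):
    /// <code>
    /// conversation.Save("conversation.bin");
    /// conversation.Dispose();
    ///
    /// // Later: restore the conversation and prompt it again
    /// using var restored = executor.Load("conversation.bin");
    /// restored.Prompt(nextTokens);
    /// </code>
    /// </example>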
    public Conversation Load(string filepath)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(filepath);
        return conversation;
    }

    /// <summary>
    /// Load a conversation that was previously saved in memory. Once loaded, the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="state">The state to load into a fresh conversation</param>
    /// <returns>The loaded conversation</returns>
    /// <exception cref="ObjectDisposedException">Thrown if this executor has been disposed</exception>
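    /// <example>
    /// A sketch of an in-memory round trip, assuming <c>Conversation.Save()</c> returns a
    /// <see cref="Conversation.State"/>:
    /// <code>
    /// var state = conversation.Save();
    /// conversation.Dispose();
    ///
    /// using var restored = executor.Load(state);
    /// restored.Prompt(nextTokens);
    /// </code>
    /// </example>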
    public Conversation Load(Conversation.State state)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(state);
        return conversation;
    }

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
    /// If the result is `NoKvSlot` then there is not enough memory for inference; try disposing some
    /// conversations and running inference again.
    /// </summary>
    /// <param name="cancellation">A token to cancel the decode operation</param>
    /// <returns>The result of the decode; <see cref="DecodeResult.Ok"/> on success</returns>
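    /// <example>
    /// A sketch of handling `NoKvSlot` by freeing space and retrying. The batch is left intact on
    /// failure, so <see cref="Infer"/> can simply be called again; <c>leastImportantConversation</c>
    /// is a placeholder for whatever eviction policy the application uses:
    /// <code>
    /// var result = await executor.Infer();
    /// if (result == DecodeResult.NoKvSlot)
    /// {
    ///     // Disposing a conversation releases its space in the KV cache
    ///     leastImportantConversation.Dispose();
    ///     result = await executor.Infer();
    /// }
    /// </code>
    /// </example>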
    public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var status = await Context.DecodeAsync(Batch, cancellation);

        // Only clear the batch if the result was ok. Leaving all this state in place means that "Infer"
        // can be called again after a warning (e.g. NoKvSlot).
        if (status == DecodeResult.Ok)
        {
            Epoch++;
            Batch.Clear();
        }

        return status;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        Context.Dispose();
    }

    internal LLamaSeqId GetNextSequenceId()
    {
        // 'checked' so that exhausting the sequence id space throws instead of silently
        // wrapping around and handing out duplicate ids
        return checked((LLamaSeqId)_nextSequenceId++);
    }
}