using System;
namespace LLama.Native;
using llama_token = Int32;
/// <summary>
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
/// </summary>
public sealed class LLamaBatchSafeHandle
    : SafeLLamaHandleBase
{
    // Embedding dimension this batch was created with. Zero means the batch
    // carries token ids; non-zero means it carries embeddings instead.
    private readonly int _embd;

    /// <summary>
    /// Get the native llama_batch struct
    /// </summary>
    public LLamaNativeBatch NativeBatch;

    /// <summary>
    /// the token ids of the input (used when embd is NULL)
    /// </summary>
    public Span<llama_token> Token
    {
        get
        {
            unsafe
            {
                // In embedding mode the native token buffer is not allocated,
                // so return an empty span rather than wrapping a null/dangling
                // pointer with a non-zero length.
                if (_embd != 0)
                    return new Span<llama_token>(null, 0);
                else
                    return new Span<llama_token>(NativeBatch.token, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
    /// </summary>
    public Span<float> Embed
    {
        get
        {
            unsafe
            {
                // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
                // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
                if (_embd != 0)
                    return new Span<float>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
                else
                    return new Span<float>(null, 0);
            }
        }
    }

    /// <summary>
    /// the positions of the respective token in the sequence
    /// </summary>
    public Span<LLamaPos> Pos
    {
        get
        {
            unsafe
            {
                return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// the sequence to which the respective token belongs
    /// </summary>
    public Span<LLamaSeqId> Sequence_ID
    {
        get
        {
            unsafe
            {
                // NOTE(review): LLamaBatchAdd indexes seq_id twice (seq_id[i][j]),
                // which implies the native field is a pointer-to-pointer; if so,
                // viewing it as a flat Span<LLamaSeqId> is suspect — confirm
                // against the LLamaNativeBatch declaration.
                return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// if zero, the logits for the respective token will not be output
    /// </summary>
    public Span<byte> Logits
    {
        get
        {
            unsafe
            {
                return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// Create a safe handle owning a `LLamaNativeBatch`
    /// </summary>
    /// <param name="batch">The native batch whose buffers this handle takes ownership of</param>
    /// <param name="embd">Embedding dimension the batch was created with (0 for token mode)</param>
    public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
        : base((nint)1)
    {
        // The base handle is a dummy non-zero value: the resource being guarded
        // is the native memory inside `batch`, not an OS handle. A zero handle
        // would be treated as invalid and ReleaseHandle would never run.
        _embd = embd;
        NativeBatch = batch;
    }

    /// <summary>
    /// Call `llama_batch_init` and create a new batch
    /// </summary>
    /// <param name="n_tokens">Maximum number of tokens the batch can hold</param>
    /// <param name="embd">If non-zero, allocate embedding buffers of this dimension instead of token buffers</param>
    /// <param name="n_seq_max">Maximum number of sequence ids per token</param>
    /// <returns>A safe handle owning the newly allocated batch</returns>
    public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
    {
        var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
        return new LLamaBatchSafeHandle(batch, embd);
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        // Free the native buffers, then clear the struct so stale pointers
        // cannot be dereferenced through this handle after release.
        NativeApi.llama_batch_free(NativeBatch);
        NativeBatch = default;
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Append one token to the batch.
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
    /// </summary>
    /// <param name="token">Token id to append</param>
    /// <param name="pos">Position of this token within its sequence</param>
    /// <param name="sequences">Sequences this token belongs to; length must not exceed the n_seq_max used at creation</param>
    /// <param name="logits">Whether logits should be produced for this token</param>
    public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        unsafe
        {
            // NOTE(review): no capacity check — the caller must ensure the batch
            // was created with enough room (the n_tokens passed to Create) and
            // that it is in token mode (embd == 0), otherwise this writes past
            // or into unallocated native memory.
            NativeBatch.token[NativeBatch.n_tokens] = token;
            NativeBatch.pos[NativeBatch.n_tokens] = pos;
            NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
            for (var i = 0; i < sequences.Length; i++)
                NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];
            NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);
            NativeBatch.n_tokens++;
        }
    }

    /// <summary>
    /// Set the token count back to zero, logically emptying the batch.
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
    /// </summary>
    public void LLamaBatchClear()
    {
        NativeBatch.n_tokens = 0;
    }
}