using System;

namespace LLama.Native;

using llama_token = Int32;

/// <summary>
/// Input data for <c>llama_decode</c>. A <c>llama_batch</c> object can contain input about one or many sequences.
/// </summary>
/// <remarks>
/// This handle owns the wrapped <see cref="LLamaNativeBatch"/> and hands it to
/// <c>NativeApi.llama_batch_free</c> when the handle is released. After release every span
/// property returns an empty span, because the stored struct is reset to its default value.
/// <para>
/// All span lengths are derived from <c>NativeBatch.n_tokens</c>, not from the <c>n_tokens</c>
/// value passed to <see cref="Create"/>. NOTE(review): if the native initializer returns a batch
/// whose <c>n_tokens</c> is 0, these spans stay empty until the native count is updated — confirm
/// against llama.cpp before relying on them to fill a freshly created batch.
/// </para>
/// </remarks>
public sealed class LLamaBatchSafeHandle
    : SafeLLamaHandleBase
{
    // Embedding width the batch was created with. 0 => the batch stores token ids (Token is used);
    // non-zero => the batch stores float embeddings (Embed is used).
    private readonly int _embd;

    /// <summary>
    /// Get the native llama_batch struct
    /// </summary>
    public LLamaNativeBatch NativeBatch;

    /// <summary>
    /// The token ids of the input (used when embd is NULL).
    /// </summary>
    /// <remarks>Empty when this batch was created with a non-zero embedding width.</remarks>
    public Span<llama_token> Token
    {
        get
        {
            unsafe
            {
                // An embedding batch has no token buffer to expose.
                if (_embd != 0)
                    return new Span<llama_token>(null, 0);

                return new Span<llama_token>(NativeBatch.token, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// Token embeddings (i.e. float vector of size n_embd per token) (used when token is NULL).
    /// </summary>
    /// <remarks>
    /// Contains <c>n_tokens * embd</c> floats laid out one embedding after another; empty when this
    /// batch was created with an embedding width of 0.
    /// </remarks>
    /// <exception cref="OverflowException">The element count does not fit in an <see cref="int"/>.</exception>
    public Span<float> Embed
    {
        get
        {
            unsafe
            {
                // If embd != 0, llama_batch.embd is allocated with size n_tokens * embd * sizeof(float);
                // otherwise llama_batch.token is allocated instead and there are no embeddings to expose.
                if (_embd == 0)
                    return new Span<float>(null, 0);

                // checked: a silently wrapped product would yield a span whose length does not match
                // the native buffer.
                return new Span<float>(NativeBatch.embd, checked(NativeBatch.n_tokens * _embd));
            }
        }
    }

    /// <summary>
    /// The positions of the respective token in the sequence.
    /// </summary>
    public Span<LLamaPos> Pos
    {
        get
        {
            unsafe
            {
                return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// The sequence to which the respective token belongs.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <see cref="Create"/> forwards <c>n_seq_max</c> to <c>llama_batch_init</c>, which
    /// suggests a token may carry several sequence ids. Confirm that <c>NativeBatch.seq_id</c> really
    /// points at a flat array of <c>n_tokens</c> ids (and not at per-token arrays) before writing through this span.
    /// </remarks>
    public Span<LLamaSeqId> Sequence_ID
    {
        get
        {
            unsafe
            {
                return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// If zero, the logits for the respective token will not be output.
    /// </summary>
    public Span<byte> Logits
    {
        get
        {
            unsafe
            {
                return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// Create a safe handle owning a <see cref="LLamaNativeBatch"/>.
    /// </summary>
    /// <param name="batch">Native batch to take ownership of; it is freed when this handle is released.</param>
    /// <param name="embd">
    /// Embedding width <paramref name="batch"/> was created with: 0 for a token-id batch, otherwise the
    /// number of floats per token.
    /// </param>
    public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
        // Dummy non-zero handle value: the owned resource is NativeBatch, not this pointer.
        // NOTE(review): presumably non-zero so the base class does not treat the handle as invalid.
        : base((nint)1)
    {
        _embd = embd;
        NativeBatch = batch;
    }

    /// <summary>
    /// Call <c>llama_batch_init</c> and create a new batch owned by the returned handle.
    /// </summary>
    /// <param name="n_tokens">Number of tokens to allocate storage for in the native batch.</param>
    /// <param name="embd">
    /// Embedding width. 0 allocates a token-id buffer; a non-zero value allocates
    /// <c>n_tokens * embd</c> floats instead.
    /// </param>
    /// <param name="n_seq_max">
    /// Forwarded to <c>llama_batch_init</c>; presumably the maximum number of sequence ids per token — TODO confirm.
    /// </param>
    /// <returns>A handle that frees the native batch when it is released.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Any argument is negative.</exception>
    public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
    {
        // Reject negative sizes here, where they can be reported as a managed error, instead of
        // letting them reach the native allocator.
        if (n_tokens < 0)
            throw new ArgumentOutOfRangeException(nameof(n_tokens), n_tokens, "Token count must not be negative.");
        if (embd < 0)
            throw new ArgumentOutOfRangeException(nameof(embd), embd, "Embedding width must not be negative.");
        if (n_seq_max < 0)
            throw new ArgumentOutOfRangeException(nameof(n_seq_max), n_seq_max, "Maximum sequence count must not be negative.");

        var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
        return new LLamaBatchSafeHandle(batch, embd);
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        NativeApi.llama_batch_free(NativeBatch);

        // Reset the struct so the span properties see null pointers and a zero count after release
        // instead of dangling pointers into freed memory.
        NativeBatch = default;
        SetHandle(IntPtr.Zero);
        return true;
    }
}