using System;

namespace LLama.Native;

using llama_token = Int32;

/// <summary>
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
/// </summary>
public sealed class LLamaBatchSafeHandle
    : SafeLLamaHandleBase
{
    // Embedding vector size for this batch. Zero means the batch carries token ids
    // (llama_batch.token) instead of embeddings (llama_batch.embd) — the two storage
    // modes are mutually exclusive, mirroring llama_batch_init's `embd` parameter.
    private readonly int _embd;

    /// <summary>
    /// Get the native llama_batch struct
    /// </summary>
    public LLamaNativeBatch NativeBatch;

    /// <summary>
    /// the token ids of the input (used when embd is NULL)
    /// </summary>
    public Span<llama_token> Token
    {
        get
        {
            unsafe
            {
                // In embeddings mode llama_batch.token is NULL, so expose an empty span
                // rather than wrapping a null pointer with a non-zero length.
                if (_embd != 0)
                    return new Span<llama_token>(null, 0);
                else
                    return new Span<llama_token>(NativeBatch.token, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
    /// </summary>
    public Span<float> Embed
    {
        get
        {
            unsafe
            {
                // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
                // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
                if (_embd != 0)
                    return new Span<float>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
                else
                    return new Span<float>(null, 0);
            }
        }
    }

    /// <summary>
    /// the positions of the respective token in the sequence
    /// </summary>
    public Span<LLamaPos> Pos
    {
        get
        {
            unsafe
            {
                return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// the sequence to which the respective token belongs
    /// </summary>
    public Span<LLamaSeqId> Sequence_ID
    {
        get
        {
            unsafe
            {
                // NOTE(review): LLamaBatchAdd double-indexes seq_id (seq_id[token][i]), which
                // suggests the native field is a pointer-to-pointer; a Span<LLamaSeqId> over it
                // would then reinterpret pointer slots as ids. Confirm the element type against
                // the LLamaNativeBatch declaration.
                return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// if zero, the logits for the respective token will not be output
    /// </summary>
    public Span<byte> Logits
    {
        get
        {
            unsafe
            {
                return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// Create a safe handle owning a `LLamaNativeBatch`
    /// </summary>
    /// <param name="batch">The native batch this handle takes ownership of</param>
    /// <param name="embd">Embedding size the batch was created with (0 for token mode)</param>
    public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
        : base((nint)1) // Dummy non-zero value so SafeHandle treats this as valid; the real
                        // native state lives in NativeBatch, not in the handle field itself.
    {
        _embd = embd;
        NativeBatch = batch;
    }

    /// <summary>
    /// Call `llama_batch_init` and create a new batch
    /// </summary>
    /// <param name="n_tokens">Maximum number of tokens this batch can hold</param>
    /// <param name="embd">Embedding size; pass 0 to store token ids instead of embeddings</param>
    /// <param name="n_seq_max">Maximum number of sequences a single token can belong to</param>
    /// <returns>A safe handle owning the freshly initialized batch</returns>
    public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
    {
        var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
        return new LLamaBatchSafeHandle(batch, embd);
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        // Free the native allocations, then clear our copy so stale pointers
        // cannot be dereferenced through the Span properties after disposal.
        NativeApi.llama_batch_free(NativeBatch);
        NativeBatch = default;
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Add a single token to the batch, belonging to the given sequences.
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
    /// </summary>
    /// <param name="token">Token id to append</param>
    /// <param name="pos">Position of this token within its sequence</param>
    /// <param name="sequences">Sequences this token belongs to (must not exceed the n_seq_max used at creation)</param>
    /// <param name="logits">Whether logits should be computed for this token</param>
    public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        unsafe
        {
            // No capacity check is performed (matching upstream llama_batch_add):
            // the caller must not append more than the n_tokens passed to Create().
            // NOTE(review): token is written unconditionally; in embeddings mode
            // (_embd != 0) NativeBatch.token is NULL — confirm this method is only
            // used on token-mode batches.
            NativeBatch.token[NativeBatch.n_tokens] = token;
            NativeBatch.pos[NativeBatch.n_tokens] = pos;
            NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;

            for (var i = 0; i < sequences.Length; i++)
                NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];

            NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);

            NativeBatch.n_tokens++;
        }
    }

    /// <summary>
    /// Reset the batch to empty (does not free or zero the underlying buffers).
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
    /// </summary>
    public void LLamaBatchClear()
    {
        NativeBatch.n_tokens = 0;
    }
}