@@ -52,18 +52,11 @@ public class BatchedDecoding
            return;
        }

        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
        var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);

        // evaluate the initial prompt
        for (var i = 0; i < prompt_tokens.Length; i++)
            batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
        Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);

        // llama_decode will output logits only for the last token of the prompt
        unsafe
        {
            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
        }
            batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

        if (context.NativeHandle.Decode(batch) != 0)
        {
@@ -75,7 +68,7 @@ public class BatchedDecoding
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
        }

        if (n_parallel > 1)
@@ -88,9 +81,9 @@ public class BatchedDecoding
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.NativeBatch.n_tokens - 1);
            i_batch.Add(batch.TokenCount - 1);

        var n_cur = batch.NativeBatch.n_tokens;
        var n_cur = batch.TokenCount;
        var n_decode = 0;

        var streams = new List<LLamaToken>[n_parallel];
@@ -133,7 +126,7 @@ public class BatchedDecoding
                streams[i].Add(new_token_id);

                i_batch[i] = batch.NativeBatch.n_tokens;
                i_batch[i] = batch.TokenCount;

                // push this new token for next evaluation
                batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
@@ -142,7 +135,7 @@ public class BatchedDecoding
            }

            // all streams are finished
            if (batch.NativeBatch.n_tokens == 0)
            if (batch.TokenCount == 0)
            {
                break;
            }
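Taken together, these hunks move the sample's generation loop entirely onto LLamaBatch. The following is only a rough sketch of one step of that loop under stated assumptions: the context, streams and i_batch names come from the surrounding sample, SampleNextToken is a hypothetical stand-in for the sampling code not shown in this diff, and a negative i_batch entry is assumed to mark a finished stream.

// Sketch: one step of the migrated generation loop (not the actual sample code)
static bool DecodeStep(LLamaContext context, LLamaBatch batch, List<LLamaToken>[] streams, List<int> i_batch, ref int n_cur)
{
    batch.LLamaBatchClear();

    for (var i = 0; i < streams.Length; i++)
    {
        if (i_batch[i] < 0)
            continue; // this stream has already finished (assumed sentinel)

        var new_token_id = SampleNextToken(context, i_batch[i]); // hypothetical sampling helper
        streams[i].Add(new_token_id);

        // remember which logits to sample from on the next pass
        i_batch[i] = batch.TokenCount;

        // queue the token for the next decode, requesting logits for it
        batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
    }

    // nothing queued means every stream has finished
    if (batch.TokenCount == 0)
        return false;

    n_cur++;
    return context.NativeHandle.Decode(batch) == 0;
}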
@@ -0,0 +1,121 @@
using System;

namespace LLama.Native;

/// <summary>
/// A batch allows submitting multiple tokens to multiple sequences simultaneously
/// </summary>
public class LLamaBatch
{
    private readonly byte[] _logits;
    private readonly LLamaToken[] _tokens;
    private readonly LLamaPos[] _positions;
    private readonly int[] _sequenceIdCount;
    private readonly LLamaSeqId[][] _sequenceIds;
    private readonly IntPtr[] _sequenceIdsPtrs;

    /// <summary>
    /// The number of tokens in this batch
    /// </summary>
    public int TokenCount { get; private set; }

    /// <summary>
    /// Create a new batch for submitting inputs to llama.cpp
    /// </summary>
| /// <param name="n_tokens"></param> | |||
| /// <param name="n_seq_max"></param> | |||
    public LLamaBatch(int n_tokens, int n_seq_max)
    {
        _logits = new byte[n_tokens];
        _tokens = new LLamaToken[n_tokens];
        _positions = new LLamaPos[n_tokens];

        _sequenceIdCount = new int[n_tokens];
        _sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];

        _sequenceIds = new LLamaSeqId[n_tokens][];
        for (var i = 0; i < _sequenceIds.Length; i++)
            _sequenceIds[i] = new LLamaSeqId[n_seq_max];
    }

    internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
    {
        // This group holds all of the memory pins
        var group = new GroupDisposable();

        unsafe
        {
            batch = new LLamaNativeBatch
            {
                n_tokens = TokenCount,
                logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer,

                n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer,
                pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer,
                seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer,

                // embd is not currently supported, so this is always null!
                embd = null,

                // Note that if embd is **not null** then this will be null!
                tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer,
            };
            // Create pointers to each of the arrays in turn
            for (var i = 0; i < _sequenceIdsPtrs.Length; i++)
                _sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer;
        }

        return group;
    }
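The GroupDisposable returned by ToNativeBatch owns all of the memory pins, so the native view is only valid until it is disposed. A minimal sketch of the intended usage inside the library, mirroring the Decode change further down in this PR (the contextHandle name, assumed to be a SafeLLamaContextHandle, is illustrative):

// Sketch: pin the managed arrays, hand the native view to llama.cpp, then release the pins
using (batch.ToNativeBatch(out LLamaNativeBatch native))
{
    // 'native' points into pinned managed memory and must not be used after disposal
    var result = NativeApi.llama_decode(contextHandle, native);
}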
    /// <summary>
    /// Add a single token to the batch at the same position in several sequences
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
| /// <param name="pos">The position to add it att</param> | |||
| /// <param name="sequences">The set of sequences to add this token to</param> | |||
| /// <param name="logits"></param> | |||
    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        _tokens[TokenCount] = token;
        _positions[TokenCount] = pos;

        _sequenceIdCount[TokenCount] = sequences.Length;
        for (var i = 0; i < sequences.Length; i++)
            _sequenceIds[TokenCount][i] = sequences[i];

        _logits[TokenCount] = Convert.ToByte(logits);

        TokenCount++;
    }
    /// <summary>
    /// Add a single token to the batch at a certain position for a single sequence
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequence">The sequence to add this token to</param>
    /// <param name="logits">If true, generate logits for this token</param>
    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
    {
        // Create a temporary span to contain 1 item without allocating
        Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
        sequences[0] = sequence;

        // Add it
        LLamaBatchAdd(token, pos, sequences, logits);
    }

    /// <summary>
    /// Set TokenCount to zero for this batch
    /// </summary>
    public void LLamaBatchClear()
    {
        TokenCount = 0;
    }
}
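For reference, a minimal sketch of how the new class replaces the old handle-based code in the batched decoding sample above (the context and prompt_tokens names are assumed from that sample):

// Size the batch for the prompt, with one sequence id slot per token
var batch = new LLamaBatch(prompt_tokens.Length, 1);

// Queue the whole prompt into sequence 0, requesting logits only for the last token
for (var i = 0; i < prompt_tokens.Length; i++)
    batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

// Submit the batch to llama.cpp through the context handle
if (context.NativeHandle.Decode(batch) != 0)
    throw new InvalidOperationException("llama_decode failed");

// The same batch can then be cleared and reused for the next round of tokens
batch.LLamaBatchClear();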
@@ -1,158 +0,0 @@
using System;

namespace LLama.Native;

/// <summary>
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
/// </summary>
public sealed class LLamaBatchSafeHandle
    : SafeLLamaHandleBase
{
    private readonly int _embd;

    /// <summary>
    /// Get the native llama_batch struct
    /// </summary>
    public LLamaNativeBatch NativeBatch;

    /// <summary>
    /// the token ids of the input (used when embd is NULL)
    /// </summary>
    public Span<LLamaToken> Token
    {
        get
        {
            unsafe
            {
                if (_embd != 0)
                    return new Span<LLamaToken>(null, 0);
                else
                    return new Span<LLamaToken>(NativeBatch.token, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
    /// </summary>
    public Span<LLamaToken> Embed
    {
        get
        {
            unsafe
            {
                // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
                // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
                if (_embd != 0)
                    return new Span<LLamaToken>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
                else
                    return new Span<LLamaToken>(null, 0);
            }
        }
    }

    /// <summary>
    /// the positions of the respective token in the sequence
    /// </summary>
    public Span<LLamaPos> Pos
    {
        get
        {
            unsafe
            {
                return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// the sequence to which the respective token belongs
    /// </summary>
    public Span<LLamaSeqId> Sequence_ID
    {
        get
        {
            unsafe
            {
                return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// if zero, the logits for the respective token will not be output
    /// </summary>
    public Span<byte> Logits
    {
        get
        {
            unsafe
            {
                return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
            }
        }
    }

    /// <summary>
    /// Create a safe handle owning a `LLamaNativeBatch`
    /// </summary>
    /// <param name="batch"></param>
    /// <param name="embd"></param>
    public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
        : base((nint)1)
    {
        _embd = embd;
        NativeBatch = batch;
    }

    /// <summary>
    /// Call `llama_batch_init` and create a new batch
    /// </summary>
    /// <param name="n_tokens"></param>
    /// <param name="embd"></param>
    /// <param name="n_seq_max"></param>
    /// <returns></returns>
    public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
    {
        var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
        return new LLamaBatchSafeHandle(batch, embd);
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        NativeApi.llama_batch_free(NativeBatch);
        NativeBatch = default;
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
    /// </summary>
    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        unsafe
        {
            NativeBatch.token[NativeBatch.n_tokens] = token;
            NativeBatch.pos[NativeBatch.n_tokens] = pos;

            NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
            for (var i = 0; i < sequences.Length; i++)
                NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];

            NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);

            NativeBatch.n_tokens++;
        }
    }

    /// <summary>
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
    /// </summary>
    public void LLamaBatchClear()
    {
        NativeBatch.n_tokens = 0;
    }
}
@@ -18,7 +18,7 @@ public unsafe struct LLamaNativeBatch
    /// <summary>
    /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
    /// </summary>
    public LLamaToken* token;
    public LLamaToken* tokens;

    /// <summary>
    /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
@@ -8,6 +8,11 @@ namespace LLama.Native;
[StructLayout(LayoutKind.Sequential)]
public record struct LLamaSeqId
{
    /// <summary>
    /// LLamaSeqId with value 0
    /// </summary>
    public static readonly LLamaSeqId Zero = new LLamaSeqId(0);

    /// <summary>
    /// The raw value
    /// </summary>
@@ -1,5 +1,4 @@
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -198,9 +197,10 @@ namespace LLama.Native
        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
        /// - < 0: error<br />
        /// </returns>
        public int Decode(LLamaBatchSafeHandle batch)
        public int Decode(LLamaBatch batch)
        {
            return NativeApi.llama_decode(this, batch.NativeBatch);
            using (batch.ToNativeBatch(out var nb))
                return NativeApi.llama_decode(this, nb);
        }
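Because the return value is passed straight through from llama_decode, callers still need to check it against the codes documented above. A hedged sketch of that check (the contextHandle and batch variable names are illustrative, and the exception type is a placeholder):

var result = contextHandle.Decode(batch);
if (result == 1)
{
    // 1: no KV slot was available for the whole batch; the caller can shrink the batch
    // (or create the context with a larger KV cache) and try again
}
else if (result < 0)
{
    // negative codes indicate a hard failure inside llama.cpp
    throw new InvalidOperationException($"llama_decode failed with code {result}");
}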
        #region state