@@ -52,18 +52,11 @@ public class BatchedDecoding
 return;
 }
-using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
+var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
 // evaluate the initial prompt
 for (var i = 0; i < prompt_tokens.Length; i++)
-    batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
-Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);
-// llama_decode will output logits only for the last token of the prompt
-unsafe
-{
-    batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
-}
+    batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
 if (context.NativeHandle.Decode(batch) != 0)
 {
@@ -75,7 +68,7 @@ public class BatchedDecoding
 // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
 for (var i = 1; i < n_parallel; ++i)
 {
-    NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
+    NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
 }
 if (n_parallel > 1)
@@ -88,9 +81,9 @@ public class BatchedDecoding
 // we need this to determine which logits to sample from
 List<int> i_batch = new();
 for (var i = 0; i < n_parallel; i++)
-    i_batch.Add(batch.NativeBatch.n_tokens - 1);
+    i_batch.Add(batch.TokenCount - 1);
-var n_cur = batch.NativeBatch.n_tokens;
+var n_cur = batch.TokenCount;
 var n_decode = 0;
 var streams = new List<LLamaToken>[n_parallel];
@@ -133,7 +126,7 @@ public class BatchedDecoding
 streams[i].Add(new_token_id);
-i_batch[i] = batch.NativeBatch.n_tokens;
+i_batch[i] = batch.TokenCount;
 // push this new token for next evaluation
 batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
@@ -142,7 +135,7 @@ public class BatchedDecoding
 }
 // all streams are finished
-if (batch.NativeBatch.n_tokens == 0)
+if (batch.TokenCount == 0)
 {
 break;
 }
@@ -0,0 +1,121 @@
+using System;
+namespace LLama.Native;
+/// <summary>
+/// A batch allows submitting multiple tokens to multiple sequences simultaneously
+/// </summary>
+public class LLamaBatch
+{
+    private readonly byte[] _logits;
+    private readonly LLamaToken[] _tokens;
+    private readonly LLamaPos[] _positions;
+    private readonly int[] _sequenceIdCount;
+    private readonly LLamaSeqId[][] _sequenceIds;
+    private readonly IntPtr[] _sequenceIdsPtrs;
+    /// <summary>
+    /// The number of tokens in this batch
+    /// </summary>
+    public int TokenCount { get; private set; }
+    /// <summary>
+    /// Create a new batch for submitting inputs to llama.cpp
+    /// </summary>
+    /// <param name="n_tokens"></param>
+    /// <param name="n_seq_max"></param>
+    public LLamaBatch(int n_tokens, int n_seq_max)
+    {
+        _logits = new byte[n_tokens];
+        _tokens = new LLamaToken[n_tokens];
+        _positions = new LLamaPos[n_tokens];
+        _sequenceIdCount = new int[n_tokens];
+        _sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];
+        _sequenceIds = new LLamaSeqId[n_tokens][];
+        for (var i = 0; i < _sequenceIds.Length; i++)
+            _sequenceIds[i] = new LLamaSeqId[n_seq_max];
+    }
+    internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
+    {
+        // This group holds all of the memory pins
+        var group = new GroupDisposable();
+        unsafe
+        {
+            batch = new LLamaNativeBatch
+            {
+                n_tokens = TokenCount,
+                logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer,
+                n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer,
+                pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer,
+                seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer,
+                // embd is not currently supported, so this is always null!
+                embd = null,
+                // Note that if embd is **not null** then this will be null!
+                tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer,
+            };
+            // Create pointers to each of the arrays in turn
+            for (var i = 0; i < _sequenceIdsPtrs.Length; i++)
+                _sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer;
+        }
+        return group;
+    }
+    /// <summary>
+    /// Add a single token to the batch at the same position in several sequences
+    /// </summary>
+    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
+    /// <param name="token">The token to add</param>
+    /// <param name="pos">The position to add it at</param>
+    /// <param name="sequences">The set of sequences to add this token to</param>
+    /// <param name="logits"></param>
+    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+    {
+        _tokens[TokenCount] = token;
+        _positions[TokenCount] = pos;
+        _sequenceIdCount[TokenCount] = sequences.Length;
+        for (var i = 0; i < sequences.Length; i++)
+            _sequenceIds[TokenCount][i] = sequences[i];
+        _logits[TokenCount] = Convert.ToByte(logits);
+        TokenCount++;
+    }
+    /// <summary>
+    /// Add a single token to the batch at a certain position for a single sequence
+    /// </summary>
+    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
+    /// <param name="token">The token to add</param>
+    /// <param name="pos">The position to add it at</param>
+    /// <param name="sequence">The sequence to add this token to</param>
+    /// <param name="logits"></param>
+    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+    {
+        // Create a temporary span to contain 1 item without allocating
+        Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
+        sequences[0] = sequence;
+        // Add it
+        LLamaBatchAdd(token, pos, sequences, logits);
+    }
+    /// <summary>
+    /// Set TokenCount to zero for this batch
+    /// </summary>
+    public void LLamaBatchClear()
+    {
+        TokenCount = 0;
+    }
+}
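A minimal usage sketch of the class added above; `context`, `someToken` and `nextToken` are placeholders rather than names from this PR, and the token counts and positions are illustrative only.

// Room for up to 512 tokens, each belonging to at most 2 sequences.
var batch = new LLamaBatch(n_tokens: 512, n_seq_max: 2);

// One token shared by two sequences at position 0, with no logits requested.
Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[2];
sequences[0] = LLamaSeqId.Zero;
sequences[1] = (LLamaSeqId)1;
batch.LLamaBatchAdd(someToken, 0, sequences, logits: false);

// A follow-up token for a single sequence, with logits so it can be sampled from.
batch.LLamaBatchAdd(nextToken, 1, (LLamaSeqId)1, logits: true);

// Decode pins the managed arrays (via ToNativeBatch) only for the native call.
var result = context.NativeHandle.Decode(batch);

// Reset TokenCount so the same buffers can be refilled for the next step.
batch.LLamaBatchClear();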
@@ -1,158 +0,0 @@
-using System;
-namespace LLama.Native;
-/// <summary>
-/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
-/// </summary>
-public sealed class LLamaBatchSafeHandle
-    : SafeLLamaHandleBase
-{
-    private readonly int _embd;
-    /// <summary>
-    /// Get the native llama_batch struct
-    /// </summary>
-    public LLamaNativeBatch NativeBatch;
-    /// <summary>
-    /// the token ids of the input (used when embd is NULL)
-    /// </summary>
-    public Span<LLamaToken> Token
-    {
-        get
-        {
-            unsafe
-            {
-                if (_embd != 0)
-                    return new Span<LLamaToken>(null, 0);
-                else
-                    return new Span<LLamaToken>(NativeBatch.token, NativeBatch.n_tokens);
-            }
-        }
-    }
-    /// <summary>
-    /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
-    /// </summary>
-    public Span<LLamaToken> Embed
-    {
-        get
-        {
-            unsafe
-            {
-                // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
-                // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
-                if (_embd != 0)
-                    return new Span<LLamaToken>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
-                else
-                    return new Span<LLamaToken>(null, 0);
-            }
-        }
-    }
-    /// <summary>
-    /// the positions of the respective token in the sequence
-    /// </summary>
-    public Span<LLamaPos> Pos
-    {
-        get
-        {
-            unsafe
-            {
-                return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
-            }
-        }
-    }
-    /// <summary>
-    /// the sequence to which the respective token belongs
-    /// </summary>
-    public Span<LLamaSeqId> Sequence_ID
-    {
-        get
-        {
-            unsafe
-            {
-                return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
-            }
-        }
-    }
-    /// <summary>
-    /// if zero, the logits for the respective token will not be output
-    /// </summary>
-    public Span<byte> Logits
-    {
-        get
-        {
-            unsafe
-            {
-                return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
-            }
-        }
-    }
-    /// <summary>
-    /// Create a safe handle owning a `LLamaNativeBatch`
-    /// </summary>
-    /// <param name="batch"></param>
-    /// <param name="embd"></param>
-    public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
-        : base((nint)1)
-    {
-        _embd = embd;
-        NativeBatch = batch;
-    }
-    /// <summary>
-    /// Call `llama_batch_init` and create a new batch
-    /// </summary>
-    /// <param name="n_tokens"></param>
-    /// <param name="embd"></param>
-    /// <param name="n_seq_max"></param>
-    /// <returns></returns>
-    public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
-    {
-        var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
-        return new LLamaBatchSafeHandle(batch, embd);
-    }
-    /// <inheritdoc />
-    protected override bool ReleaseHandle()
-    {
-        NativeApi.llama_batch_free(NativeBatch);
-        NativeBatch = default;
-        SetHandle(IntPtr.Zero);
-        return true;
-    }
-    /// <summary>
-    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
-    /// </summary>
-    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
-    {
-        unsafe
-        {
-            NativeBatch.token[NativeBatch.n_tokens] = token;
-            NativeBatch.pos[NativeBatch.n_tokens] = pos;
-            NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
-            for (var i = 0; i < sequences.Length; i++)
-                NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];
-            NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);
-            NativeBatch.n_tokens++;
-        }
-    }
-    /// <summary>
-    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
-    /// </summary>
-    public void LLamaBatchClear()
-    {
-        NativeBatch.n_tokens = 0;
-    }
-}
@@ -18,7 +18,7 @@ public unsafe struct LLamaNativeBatch
 /// <summary>
 /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
 /// </summary>
-public LLamaToken* token;
+public LLamaToken* tokens;
 /// <summary>
 /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
@@ -8,6 +8,11 @@ namespace LLama.Native;
 [StructLayout(LayoutKind.Sequential)]
 public record struct LLamaSeqId
 {
+    /// <summary>
+    /// LLamaSeqId with value 0
+    /// </summary>
+    public static readonly LLamaSeqId Zero = new LLamaSeqId(0);
 /// <summary>
 /// The raw value
 /// </summary>
@@ -1,5 +1,4 @@
 using System;
-using System.Buffers;
 using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;
@@ -198,9 +197,10 @@ namespace LLama.Native
 /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
 /// - < 0: error<br />
 /// </returns>
-public int Decode(LLamaBatchSafeHandle batch)
+public int Decode(LLamaBatch batch)
 {
-    return NativeApi.llama_decode(this, batch.NativeBatch);
+    using (batch.ToNativeBatch(out var nb))
+        return NativeApi.llama_decode(this, nb);
 }
 #region state
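Taken together, these changes allow a per-step generation loop to reuse a single managed LLamaBatch instead of holding a native llama_batch allocation. A rough sketch under the assumption of single-sequence generation; `context`, `Sample`, `maxTokens`, `promptLength` and `firstToken` are illustrative placeholders, not names from this diff.

var batch = new LLamaBatch(1, 1);
var pos = promptLength;      // next position in the KV cache
var token = firstToken;      // assumed to come from sampling the prompt logits

for (var i = 0; i < maxTokens; i++)
{
    batch.LLamaBatchClear();                                   // reuse the same managed buffers
    batch.LLamaBatchAdd(token, pos++, LLamaSeqId.Zero, true);  // request logits for sampling

    // The batch arrays are pinned only for the duration of this call.
    if (context.NativeHandle.Decode(batch) != 0)
        break;

    token = Sample(context, batch.TokenCount - 1);             // assumed sampling helper
}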