using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// A batch allows submitting multiple tokens to multiple sequences simultaneously
/// </summary>
public class LLamaBatch
{
    // Parallel arrays, one slot per token in the batch. All are resized together
    // by GrowTokenCapacity.
    private byte[] _logits;
    private LLamaToken[] _tokens;
    private LLamaPos[] _positions;
    private int[] _sequenceIdCount;
    private LLamaSeqId[][] _sequenceIds;
    private IntPtr[] _sequenceIdsPtrs;

    /// <summary>
    /// Keep track of the index of existing token/position combos in the batch
    /// </summary>
    private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new();

    /// <summary>
    /// Keep a list of where logits can be sampled from
    /// </summary>
    private readonly List<(LLamaSeqId, int)> _logitPositions = new();

    /// <summary>
    /// Get the number of logit positions that will be generated from this batch
    /// </summary>
    internal int LogitPositionCount => _logitPositions.Count;

    /// <summary>
    /// The number of tokens in this batch
    /// </summary>
    public int TokenCount { get; private set; }

    /// <summary>
    /// Maximum number of tokens that can be added to this batch (automatically grows if exceeded)
    /// </summary>
    private int TokenCapacity { get; set; }

    /// <summary>
    /// Maximum number of sequences a token can be assigned to (automatically grows if exceeded)
    /// </summary>
    public int SequenceCapacity { get; private set; }

    /// <summary>
    /// Create a new batch for submitting inputs to llama.cpp
    /// </summary>
    public LLamaBatch()
    {
        // These can both be grown later, start off with reasonable numbers.
        const int n_tokens = 128;
        const int n_seq_max = 1;

        SequenceCapacity = n_seq_max;
        TokenCapacity = n_tokens;

        _logits = new byte[n_tokens];
        _tokens = new LLamaToken[n_tokens];
        _positions = new LLamaPos[n_tokens];

        _sequenceIdCount = new int[n_tokens];
        _sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];

        _sequenceIds = new LLamaSeqId[n_tokens][];
        for (var i = 0; i < _sequenceIds.Length; i++)
            _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
    }

    #region grow
    /// <summary>
    /// Double the token capacity of this batch, resizing every parallel array.
    /// </summary>
    private void GrowTokenCapacity()
    {
        // Only called when TokenCount == TokenCapacity, so this doubles capacity.
        var n_tokens = TokenCount * 2;
        TokenCapacity = n_tokens;

        Array.Resize(ref _logits, n_tokens);
        Array.Resize(ref _tokens, n_tokens);
        Array.Resize(ref _positions, n_tokens);

        Array.Resize(ref _sequenceIdCount, n_tokens);
        Array.Resize(ref _sequenceIdsPtrs, n_tokens);

        Array.Resize(ref _sequenceIds, n_tokens);
        for (int i = 0; i < _sequenceIds.Length; i++)
        {
            // Growing the array filled elements with null, temporarily violating the nullability contract!
            // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
            if (_sequenceIds[i] == null)
                _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
        }
    }

    /// <summary>
    /// Grow the per-token sequence capacity to at least <paramref name="atLeast"/>.
    /// </summary>
    /// <param name="atLeast">Minimum number of sequences a single token must be assignable to</param>
    private void GrowMaxSequences(int atLeast)
    {
        // Grow geometrically, but never less than the caller's requirement.
        var n_seq = Math.Max(SequenceCapacity * 2, atLeast);
        SequenceCapacity = n_seq;

        for (var i = 0; i < _sequenceIds.Length; i++)
            Array.Resize(ref _sequenceIds[i], SequenceCapacity);
    }
    #endregion

    /// <summary>
    /// Pin all internal arrays and build a <see cref="LLamaNativeBatch"/> pointing into them,
    /// suitable for passing to llama.cpp.
    /// </summary>
    /// <param name="batch">Receives the native batch view over this managed batch</param>
    /// <returns>
    /// A disposable group holding all the memory pins. Dispose it only after the native
    /// call has completed, otherwise the pointers in <paramref name="batch"/> dangle.
    /// </returns>
    internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
    {
        // This group holds all of the memory pins
        var group = new GroupDisposable();

        unsafe
        {
            batch = new LLamaNativeBatch
            {
                n_tokens = TokenCount,
                logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer,

                n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer,
                pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer,
                seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer,

                // embd is not currently supported, so this is always null!
                embd = null,

                // Note that if embd is **not null** then this will be null!
                tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer,
            };

            // Create pointers to each of the arrays in turns
            for (var i = 0; i < _sequenceIdsPtrs.Length; i++)
                _sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer;
        }

        return group;
    }

    #region add
    /// <summary>
    /// Add a single token to the batch at the same position in several sequences
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequences">The set of sequences to add this token to</param>
    /// <param name="logits">Whether logits should be generated for this token</param>
    /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
    public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
        // sequence ID to the list.
        // Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
        if (!logits && _index.TryGetValue((token, pos), out var existingIndex))
        {
            if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
                GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);

            foreach (var sequence in sequences)
            {
                _sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence;
                _sequenceIdCount[existingIndex]++;
            }

            return existingIndex;
        }

        // Couldn't find this token/position combo anywhere in the batch. Add a new item.

        // Grow capacity as necessary
        if (TokenCount == TokenCapacity)
            GrowTokenCapacity();
        if (sequences.Length > SequenceCapacity)
            GrowMaxSequences(sequences.Length);

        // Store the position in the index, so it can be found later.
        // We need to check that it's not already there in case we skipped the check above (because logits is true).
        if (!_index.ContainsKey((token, pos)))
            _index.Add((token, pos), TokenCount);

        // Add the items to the arrays
        _tokens[TokenCount] = token;
        _positions[TokenCount] = pos;
        _sequenceIdCount[TokenCount] = sequences.Length;
        for (var i = 0; i < sequences.Length; i++)
            _sequenceIds[TokenCount][i] = sequences[i];
        _logits[TokenCount] = Convert.ToByte(logits);

        // Store this position in the logits lookup if necessary
        if (logits)
        {
            foreach (var sequence in sequences)
                _logitPositions.Add((sequence, TokenCount));
        }

        return TokenCount++;
    }

    /// <summary>
    /// Add a single token to the batch at the same position in several sequences
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequences">The set of sequences to add this token to</param>
    /// <param name="logits">Whether logits should be generated for this token</param>
    /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
    public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
    {
#if NET5_0_OR_GREATER
        var seqSpan = CollectionsMarshal.AsSpan(sequences);
        return Add(token, pos, seqSpan, logits);
#else
        // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
        // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
        // avoid the copying.
        var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
        try
        {
            sequences.CopyTo(rented, 0);
            return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
        }
        finally
        {
            System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
        }
#endif
    }

    /// <summary>
    /// Add a single token to the batch at a certain position for a single sequences
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequence">The sequence to add this token to</param>
    /// <param name="logits">Whether logits should be generated for this token</param>
    /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
    public int Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
    {
        // Create a temporary span to contain 1 item without allocating
        Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
        sequences[0] = sequence;

        // Add it
        return Add(token, pos, sequences, logits);
    }

    /// <summary>
    /// Add a range of tokens to a single sequence, start at the given position.
    /// </summary>
    /// <param name="tokens">The tokens to add</param>
    /// <param name="start">The starting position to add tokens at</param>
    /// <param name="sequence">The sequence to add this token to</param>
    /// <param name="logitsLast">Whether the final token should generate logits</param>
    /// <returns>The index that the final token was added at. Use this for GetLogitsIth</returns>
    public int AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
    {
        var last = -1;
        for (var i = 0; i < tokens.Length; i++)
        {
            // Only the final token of the range can request logits.
            var logits = (i == tokens.Length - 1) && logitsLast;
            last = Add(tokens[i], start.Value + i, sequence, logits);
        }

        return last;
    }
    #endregion

    /// <summary>
    /// Set TokenCount to zero for this batch
    /// </summary>
    public void Clear()
    {
        TokenCount = 0;
        _index.Clear();
        _logitPositions.Clear();
    }

    /// <summary>
    /// Get the positions where logits can be sampled from
    /// </summary>
    /// <param name="dest">Destination buffer; must be at least <see cref="LogitPositionCount"/> long</param>
    /// <returns>A slice of <paramref name="dest"/> containing the (sequence, index) logit positions</returns>
    internal Span<(LLamaSeqId, int)> GetLogitPositions(Span<(LLamaSeqId, int)> dest)
    {
        for (var i = 0; i < _logitPositions.Count; i++)
            dest[i] = _logitPositions[i];

        return dest.Slice(0, _logitPositions.Count);
    }
}