@@ -138,8 +138,10 @@ public class LLamaBatch
/// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
// Try to find this (token, position) combo somewhere in the batch to re-use it
if (_index.TryGetValue((token, pos), out var existingIndex))
// Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
// sequence ID to the list.
// Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
if (!logits && _index.TryGetValue((token, pos), out var existingIndex))
{
if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);
@@ -153,16 +155,18 @@ public class LLamaBatch
return existingIndex;
}
// Couldn't find this it in the batch, add a new item
// Couldn't find this token/position combo anywhere in the batch. Add a new item.
// F row capacity as necessary
// G row capacity as necessary
if (TokenCount == TokenCapacity)
GrowTokenCapacity();
if (sequences.Length > SequenceCapacity)
GrowMaxSequences(sequences.Length);
// Store the position in the index, so it can be found later
_index.Add((token, pos), TokenCount);
// Store the position in the index, so it can be found later.
// We need to check that it's not already there in case we skipped the check above (because logits is true).
if (!_index.ContainsKey((token, pos)))
_index.Add((token, pos), TokenCount);
// Add the items to the arrays
_tokens[TokenCount] = token;