diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 1d55ff12..c51c49be 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -52,7 +52,7 @@ public class BatchedDecoding
             return;
         }
 
-        var batch = new LLamaBatch(1);
+        var batch = new LLamaBatch();
 
         // evaluate the initial prompt
         for (var i = 0; i < prompt_tokens.Length; i++)
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index a7ca923b..20e14530 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -38,7 +38,7 @@ public class LLamaBatch
     {
        // These can both be grown later, start off with reasonable numbers.
        const int n_tokens = 128;
-       const int n_seq_max = 4;
+       const int n_seq_max = 1;
 
        MaxSequences = n_seq_max;
        TokenCapacity = n_tokens;
@@ -77,9 +77,9 @@ public class LLamaBatch
         }
     }
 
-    private void GrowMaxSequences()
+    private void GrowMaxSequences(int atLeast)
     {
-        var n_seq = MaxSequences * 2;
+        var n_seq = Math.Max(MaxSequences * 2, atLeast);
         MaxSequences = n_seq;
 
         for (var i = 0; i < _sequenceIds.Length; i++)
@@ -130,7 +130,7 @@ public class LLamaBatch
         if (TokenCount == TokenCapacity)
             GrowTokenCapacity();
         if (sequences.Length > MaxSequences)
-            GrowMaxSequences();
+            GrowMaxSequences(sequences.Length);
 
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;
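
Note: the key behavioral change here is that GrowMaxSequences now grows geometrically but never below the caller's requirement, so a single call always reaches the requested sequence count even when that request is more than double the current capacity. Below is a minimal standalone sketch of the same doubling-with-floor pattern; the class and method names are illustrative only, not part of the library's API.

    using System;

    class SequenceCapacityDemo
    {
        // Doubling-with-floor growth, mirroring GrowMaxSequences(int atLeast):
        // capacity at least doubles, but is also clamped up to the requested
        // minimum, so one grow step is always sufficient.
        static int Grow(int current, int atLeast) => Math.Max(current * 2, atLeast);

        static void Main()
        {
            var capacity = 1; // matches the new n_seq_max starting value
            foreach (var required in new[] { 3, 4, 16 })
            {
                if (required > capacity)
                    capacity = Grow(capacity, required);
                Console.WriteLine($"required={required} -> capacity={capacity}");
            }
            // Prints:
            //   required=3 -> capacity=3   (plain doubling would only give 2)
            //   required=4 -> capacity=6   (doubling wins: 3 * 2 > 4)
            //   required=16 -> capacity=16 (the floor wins: 6 * 2 < 16)
        }
    }

This is why the starting n_seq_max can safely drop from 4 to 1 and the LLamaBatch(1) constructor argument can be removed from the example: callers no longer need to guess a sequence count up front, since AddToken grows the batch on demand to exactly what the sequences span requires.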