- Added a `DecodeAsync` overload which runs the work in a task.
- Replaced some `NativeHandle` usage in `BatchedDecoding` with higher-level equivalents.
- Made the `LLamaBatch` grow when token capacity is exceeded, removing the need to manage token capacity externally.
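Taken together, the changes look roughly like this in use. This is a minimal sketch, not code from the diff below: the `EvaluatePromptAsync` wrapper is hypothetical, but the calls themselves (`new LLamaBatch(1)`, `Add`, `DecodeAsync`, `Clear`) are the ones introduced here.

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Native;

public static class BatchedDecodingSketch
{
    // Hypothetical wrapper showing the new surface: no up-front token capacity,
    // Add() to queue tokens, DecodeAsync() to run the evaluation in a task.
    public static async Task EvaluatePromptAsync(LLamaContext context, LLamaToken[] promptTokens)
    {
        // Token capacity grows automatically; only the max sequences per token is specified.
        var batch = new LLamaBatch(1);

        // Queue the prompt, requesting logits only for the last token.
        for (var i = 0; i < promptTokens.Length; i++)
            batch.Add(promptTokens[i], i, LLamaSeqId.Zero, i == promptTokens.Length - 1);

        // Non-zero return values are warnings or errors (see the Decode doc comments in the diff).
        if (await context.DecodeAsync(batch) != 0)
            throw new Exception("llama_decode failed");

        // The batch can be cleared and reused for the next decode step.
        batch.Clear();
    }
}
```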
@@ -52,13 +52,13 @@ public class BatchedDecoding
             return;
         }
 
-        var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
+        var batch = new LLamaBatch(1);
 
         // evaluate the initial prompt
         for (var i = 0; i < prompt_tokens.Length; i++)
-            batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+            batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
 
-        if (context.NativeHandle.Decode(batch) != 0)
+        if (await context.DecodeAsync(batch) != 0)
         {
             await Console.Error.WriteLineAsync("llama_decode failed");
             return;
@@ -97,7 +97,7 @@ public class BatchedDecoding
         timer.Start();
         while (n_cur <= n_len)
         {
-            batch.LLamaBatchClear();
+            batch.Clear();
 
             for (var i = 0; i < n_parallel; i++)
             {
@@ -129,7 +129,7 @@ public class BatchedDecoding
                 i_batch[i] = batch.TokenCount;
 
                 // push this new token for next evaluation
-                batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+                batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
 
                 n_decode++;
             }
@@ -143,7 +143,7 @@ public class BatchedDecoding
             n_cur++;
 
             // evaluate the current batch with the transformer model
-            if (context.NativeHandle.Decode(batch) != 0)
+            if (await context.DecodeAsync(batch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
@@ -40,7 +40,7 @@ public sealed class BeamTests
         var initial_tokens = context.Tokenize(prompt);
         result.Append(prompt);
-        context.Eval(initial_tokens, 0);
+        context.Eval(initial_tokens.AsSpan(), 0);
 
         NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
         {
@@ -36,7 +36,7 @@ namespace LLama.Unittest
            var executor = new StatelessExecutor(_weights, _params);
 
-           const string question = "Question. what is a cat?\nAnswer: ";
+           const string question = "Question. what is a cat?\nAnswer:";
 
            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
            var timer = new Stopwatch();
@@ -8,10 +8,12 @@ using System.IO;
 using System.IO.MemoryMappedFiles;
 using LLama.Common;
 using System.Runtime.InteropServices;
+using System.Threading.Tasks;
 using LLama.Extensions;
 using LLama.Abstractions;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
+using System.Threading;
 
 namespace LLama
 {
@@ -344,16 +346,30 @@ namespace LLama
        #region eval overloads
        /// <summary>
-       ///
        /// </summary>
-       /// <param name="tokens"></param>
-       /// <param name="pastTokensCount"></param>
-       /// <returns>The updated `pastTokensCount`.</returns>
-       /// <exception cref="RuntimeError"></exception>
-       [Obsolete("use llama_decode() instead")]
-       public int Eval(LLamaToken[] tokens, int pastTokensCount)
+       /// <param name="batch"></param>
+       /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+       ///  - 0: success<br />
+       ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+       ///  - < 0: error<br />
+       /// </returns>
+       public int Decode(LLamaBatch batch)
+       {
+           return NativeHandle.Decode(batch);
+       }
+
+       /// <summary>
+       /// </summary>
+       /// <param name="batch"></param>
+       /// <param name="cancellationToken"></param>
+       /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+       ///  - 0: success<br />
+       ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+       ///  - < 0: error<br />
+       /// </returns>
+       public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
        {
-           return Eval(tokens.AsSpan(), pastTokensCount);
+           return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
        }
 
        /// <summary>
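`Decode` and `DecodeAsync` surface llama.cpp's raw return codes, so callers still have to interpret them. A minimal sketch of one way to do that is below; the `DecodeOrThrowAsync` helper name is purely illustrative, and only the return-code meanings come from the doc comments above.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Native;

internal static class LLamaContextDecodeExtensions
{
    // Hypothetical helper: await a decode and map the documented return codes
    // (0 = success, 1 = no KV slot available, negative = error) onto exceptions.
    public static async Task DecodeOrThrowAsync(this LLamaContext context, LLamaBatch batch, CancellationToken cancellationToken = default)
    {
        var result = await context.DecodeAsync(batch, cancellationToken);

        if (result == 1)
            throw new InvalidOperationException("No KV slot for this batch; reduce the batch size or increase the context size.");
        if (result < 0)
            throw new InvalidOperationException($"llama_decode failed with code {result}.");
    }
}
```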
@@ -363,7 +379,7 @@ namespace LLama
        /// <param name="pastTokensCount"></param>
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError"></exception>
-       [Obsolete("use llama_decode() instead")]
+       [Obsolete("use Decode() instead")]
        public int Eval(List<LLamaToken> tokens, int pastTokensCount)
        {
 #if NET5_0_OR_GREATER
@@ -394,20 +410,7 @@ namespace LLama
        /// <param name="pastTokensCount"></param>
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError"></exception>
-       [Obsolete("use llama_decode() instead")]
-       public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
-       {
-           return Eval(tokens.Span, pastTokensCount);
-       }
-
-       /// <summary>
-       ///
-       /// </summary>
-       /// <param name="tokens"></param>
-       /// <param name="pastTokensCount"></param>
-       /// <returns>The updated `pastTokensCount`.</returns>
-       /// <exception cref="RuntimeError"></exception>
-       [Obsolete("use llama_decode() instead")]
+       [Obsolete("use Decode() instead")]
        public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
        {
            var total = tokens.Length;
@@ -75,7 +75,7 @@ namespace LLama
            // TODO(Rinne): deal with log of prompt
            if (embed_inp_array.Length > 0)
-               Context.Eval(embed_inp_array, 0);
+               Context.Eval(embed_inp_array.AsSpan(), 0);
 
            var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
            if (embeddings == null)
@@ -1,4 +1,5 @@
 using System;
+using System.Collections.Generic;
 
 namespace LLama.Native;
 
@@ -7,27 +8,42 @@ namespace LLama.Native;
 /// </summary>
 public class LLamaBatch
 {
-    private readonly byte[] _logits;
+    private byte[] _logits;
 
-    private readonly LLamaToken[] _tokens;
-    private readonly LLamaPos[] _positions;
+    private LLamaToken[] _tokens;
+    private LLamaPos[] _positions;
 
-    private readonly int[] _sequenceIdCount;
-    private readonly LLamaSeqId[][] _sequenceIds;
-    private readonly IntPtr[] _sequenceIdsPtrs;
+    private int[] _sequenceIdCount;
+    private LLamaSeqId[][] _sequenceIds;
+    private IntPtr[] _sequenceIdsPtrs;
 
     /// <summary>
     /// The number of tokens in this batch
     /// </summary>
     public int TokenCount { get; private set; }
 
+    /// <summary>
+    /// Maximum number of tokens that can be added to this batch
+    /// </summary>
+    private int TokenCapacity { get; set; }
+
+    /// <summary>
+    /// Maximum number of sequences a token can be assigned to
+    /// </summary>
+    public int MaxSequences { get; private set; }
+
     /// <summary>
     /// Create a new batch for submitting inputs to llama.cpp
     /// </summary>
-    /// <param name="n_tokens"></param>
-    /// <param name="n_seq_max"></param>
-    public LLamaBatch(int n_tokens, int n_seq_max)
+    /// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
+    public LLamaBatch(int n_seq_max)
     {
+        // The number of tokens can be grown later, start off with a reasonable guess.
+        const int n_tokens = 64;
+
+        MaxSequences = n_seq_max;
+        TokenCapacity = n_tokens;
+
         _logits = new byte[n_tokens];
         _tokens = new LLamaToken[n_tokens];
         _positions = new LLamaPos[n_tokens];
@@ -37,7 +53,29 @@ public class LLamaBatch
         _sequenceIds = new LLamaSeqId[n_tokens][];
         for (var i = 0; i < _sequenceIds.Length; i++)
-            _sequenceIds[i] = new LLamaSeqId[n_seq_max];
+            _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+    }
+
+    private void Grow()
+    {
+        var n_tokens = TokenCount * 2;
+        TokenCapacity = n_tokens;
+
+        Array.Resize(ref _logits, n_tokens);
+        Array.Resize(ref _tokens, n_tokens);
+        Array.Resize(ref _positions, n_tokens);
+
+        Array.Resize(ref _sequenceIdCount, n_tokens);
+        Array.Resize(ref _sequenceIdsPtrs, n_tokens);
+
+        Array.Resize(ref _sequenceIds, n_tokens);
+        for (int i = 0; i < _sequenceIds.Length; i++)
+        {
+            // Growing the array filled elements with null, temporarily violating the nullability contract!
+            // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
+            if (_sequenceIds[i] == null)
+                _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+        }
     }
 
     internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
@@ -79,8 +117,11 @@ public class LLamaBatch
     /// <param name="pos">The position to add it att</param>
     /// <param name="sequences">The set of sequences to add this token to</param>
     /// <param name="logits"></param>
-    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+    public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
+        if (TokenCount == TokenCapacity)
+            Grow();
+
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;
@@ -101,20 +142,20 @@ public class LLamaBatch
     /// <param name="pos">The position to add it att</param>
     /// <param name="sequence">The sequence to add this token to</param>
     /// <param name="logits"></param>
-    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+    public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
     {
         // Create a temporary span to contain 1 item without allocating
         Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
         sequences[0] = sequence;
 
         // Add it
-        LLamaBatchAdd(token, pos, sequences, logits);
+        Add(token, pos, sequences, logits);
     }
 
     /// <summary>
     /// Set TokenCount to zero for this batch
     /// </summary>
-    public void LLamaBatchClear()
+    public void Clear()
     {
         TokenCount = 0;
     }