- Added a `DecodeAsync` overload which runs the work in a task.
- Replaced some `NativeHandle` usage in `BatchedDecoding` with higher-level equivalents.
- Made the `LLamaBatch` grow when its token capacity is exceeded, removing the need to manage token capacity externally.
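Roughly, the new pieces fit together like this (a minimal sketch, assuming an existing `LLamaContext` named `context` and a `prompt` string; it is not code from this commit):

```csharp
// Only the max number of sequences per token is passed in now;
// the batch's token capacity starts small and grows on demand.
var batch = new LLamaBatch(1);

// Tokenize a prompt and add every token to sequence zero,
// requesting logits only for the final token.
var tokens = context.Tokenize(prompt);
for (var i = 0; i < tokens.Length; i++)
    batch.Add(tokens[i], i, LLamaSeqId.Zero, i == tokens.Length - 1);

// DecodeAsync wraps NativeHandle.Decode in a Task.
// Return values: 0 = success, 1 = no KV slot found, < 0 = error.
if (await context.DecodeAsync(batch) != 0)
    await Console.Error.WriteLineAsync("llama_decode failed");

// The same batch can be cleared and reused for the next decode step.
batch.Clear();
```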
@@ -52,13 +52,13 @@ public class BatchedDecoding
 return;
 }
-var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
+var batch = new LLamaBatch(1);
 // evaluate the initial prompt
 for (var i = 0; i < prompt_tokens.Length; i++)
-batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
-if (context.NativeHandle.Decode(batch) != 0)
+if (await context.DecodeAsync(batch) != 0)
 {
 await Console.Error.WriteLineAsync("llama_decode failed");
 return;
@@ -97,7 +97,7 @@ public class BatchedDecoding
 timer.Start();
 while (n_cur <= n_len)
 {
-batch.LLamaBatchClear();
+batch.Clear();
 for (var i = 0; i < n_parallel; i++)
 {
@@ -129,7 +129,7 @@ public class BatchedDecoding
 i_batch[i] = batch.TokenCount;
 // push this new token for next evaluation
-batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
 n_decode++;
 }
@@ -143,7 +143,7 @@ public class BatchedDecoding
 n_cur++;
 // evaluate the current batch with the transformer model
-if (context.NativeHandle.Decode(batch) != 0)
+if (await context.DecodeAsync(batch) != 0)
 {
 await Console.Error.WriteLineAsync("failed to eval");
 return;
@@ -40,7 +40,7 @@ public sealed class BeamTests
 var initial_tokens = context.Tokenize(prompt);
 result.Append(prompt);
-context.Eval(initial_tokens, 0);
+context.Eval(initial_tokens.AsSpan(), 0);
 NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
 {
@@ -36,7 +36,7 @@ namespace LLama.Unittest
 var executor = new StatelessExecutor(_weights, _params);
-const string question = "Question. what is a cat?\nAnswer: ";
+const string question = "Question. what is a cat?\nAnswer:";
 var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
 var timer = new Stopwatch();
@@ -8,10 +8,12 @@ using System.IO;
 using System.IO.MemoryMappedFiles;
 using LLama.Common;
 using System.Runtime.InteropServices;
+using System.Threading.Tasks;
 using LLama.Extensions;
 using LLama.Abstractions;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
+using System.Threading;
 namespace LLama
 {
@@ -344,16 +346,30 @@ namespace LLama
 #region eval overloads
 /// <summary>
-///
 /// </summary>
-/// <param name="tokens"></param>
-/// <param name="pastTokensCount"></param>
-/// <returns>The updated `pastTokensCount`.</returns>
-/// <exception cref="RuntimeError"></exception>
-[Obsolete("use llama_decode() instead")]
-public int Eval(LLamaToken[] tokens, int pastTokensCount)
+/// <param name="batch"></param>
+/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+/// - 0: success<br />
+/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+/// - < 0: error<br />
+/// </returns>
+public int Decode(LLamaBatch batch)
+{
+return NativeHandle.Decode(batch);
+}
+/// <summary>
+/// </summary>
+/// <param name="batch"></param>
+/// <param name="cancellationToken"></param>
+/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+/// - 0: success<br />
+/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+/// - < 0: error<br />
+/// </returns>
+public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
 {
-return Eval(tokens.AsSpan(), pastTokensCount);
+return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
 }
 /// <summary>
@@ -363,7 +379,7 @@ namespace LLama
 /// <param name="pastTokensCount"></param>
 /// <returns>The updated `pastTokensCount`.</returns>
 /// <exception cref="RuntimeError"></exception>
-[Obsolete("use llama_decode() instead")]
+[Obsolete("use Decode() instead")]
 public int Eval(List<LLamaToken> tokens, int pastTokensCount)
 {
 #if NET5_0_OR_GREATER
@@ -394,20 +410,7 @@ namespace LLama
 /// <param name="pastTokensCount"></param>
 /// <returns>The updated `pastTokensCount`.</returns>
 /// <exception cref="RuntimeError"></exception>
-[Obsolete("use llama_decode() instead")]
-public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
-{
-return Eval(tokens.Span, pastTokensCount);
-}
-/// <summary>
-///
-/// </summary>
-/// <param name="tokens"></param>
-/// <param name="pastTokensCount"></param>
-/// <returns>The updated `pastTokensCount`.</returns>
-/// <exception cref="RuntimeError"></exception>
-[Obsolete("use llama_decode() instead")]
+[Obsolete("use Decode() instead")]
 public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
 {
 var total = tokens.Length;
@@ -75,7 +75,7 @@ namespace LLama
 // TODO(Rinne): deal with log of prompt
 if (embed_inp_array.Length > 0)
-Context.Eval(embed_inp_array, 0);
+Context.Eval(embed_inp_array.AsSpan(), 0);
 var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
 if (embeddings == null)
@@ -1,4 +1,5 @@
+using System;
 using System.Collections.Generic;
 namespace LLama.Native;
@@ -7,27 +8,42 @@ namespace LLama.Native;
 /// </summary>
 public class LLamaBatch
 {
-private readonly byte[] _logits;
+private byte[] _logits;
-private readonly LLamaToken[] _tokens;
-private readonly LLamaPos[] _positions;
+private LLamaToken[] _tokens;
+private LLamaPos[] _positions;
-private readonly int[] _sequenceIdCount;
-private readonly LLamaSeqId[][] _sequenceIds;
-private readonly IntPtr[] _sequenceIdsPtrs;
+private int[] _sequenceIdCount;
+private LLamaSeqId[][] _sequenceIds;
+private IntPtr[] _sequenceIdsPtrs;
 /// <summary>
 /// The number of tokens in this batch
 /// </summary>
 public int TokenCount { get; private set; }
+/// <summary>
+/// Maximum number of tokens that can be added to this batch
+/// </summary>
+private int TokenCapacity { get; set; }
+/// <summary>
+/// Maximum number of sequences a token can be assigned to
+/// </summary>
+public int MaxSequences { get; private set; }
 /// <summary>
 /// Create a new batch for submitting inputs to llama.cpp
 /// </summary>
-/// <param name="n_tokens"></param>
-/// <param name="n_seq_max"></param>
-public LLamaBatch(int n_tokens, int n_seq_max)
+/// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
+public LLamaBatch(int n_seq_max)
 {
+// The number of tokens can be grown later, start off with a reasonable guess.
+const int n_tokens = 64;
+MaxSequences = n_seq_max;
+TokenCapacity = n_tokens;
 _logits = new byte[n_tokens];
 _tokens = new LLamaToken[n_tokens];
 _positions = new LLamaPos[n_tokens];
@@ -37,7 +53,29 @@ public class LLamaBatch
 _sequenceIds = new LLamaSeqId[n_tokens][];
 for (var i = 0; i < _sequenceIds.Length; i++)
-_sequenceIds[i] = new LLamaSeqId[n_seq_max];
+_sequenceIds[i] = new LLamaSeqId[MaxSequences];
 }
+private void Grow()
+{
+var n_tokens = TokenCount * 2;
+TokenCapacity = n_tokens;
+Array.Resize(ref _logits, n_tokens);
+Array.Resize(ref _tokens, n_tokens);
+Array.Resize(ref _positions, n_tokens);
+Array.Resize(ref _sequenceIdCount, n_tokens);
+Array.Resize(ref _sequenceIdsPtrs, n_tokens);
+Array.Resize(ref _sequenceIds, n_tokens);
+for (int i = 0; i < _sequenceIds.Length; i++)
+{
+// Growing the array filled elements with null, temporarily violating the nullability contract!
+// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
+if (_sequenceIds[i] == null)
+_sequenceIds[i] = new LLamaSeqId[MaxSequences];
+}
+}
 internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
@@ -79,8 +117,11 @@ public class LLamaBatch
 /// <param name="pos">The position to add it att</param>
 /// <param name="sequences">The set of sequences to add this token to</param>
 /// <param name="logits"></param>
-public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
 {
+if (TokenCount == TokenCapacity)
+Grow();
 _tokens[TokenCount] = token;
 _positions[TokenCount] = pos;
@@ -101,20 +142,20 @@ public class LLamaBatch
 /// <param name="pos">The position to add it att</param>
 /// <param name="sequence">The sequence to add this token to</param>
 /// <param name="logits"></param>
-public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
 {
 // Create a temporary span to contain 1 item without allocating
 Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
 sequences[0] = sequence;
 // Add it
-LLamaBatchAdd(token, pos, sequences, logits);
+Add(token, pos, sequences, logits);
 }
 /// <summary>
 /// Set TokenCount to zero for this batch
 /// </summary>
-public void LLamaBatchClear()
+public void Clear()
 {
 TokenCount = 0;
 }