
- Removed some unused `eval` methods.

- Added a `DecodeAsync` overload which runs the work in a task.
- Replaced some `NativeHandle` usage in `BatchedDecoding` with higher level equivalents.
- Made the `LLamaBatch` grow when token capacity is exceeded, removing the need to manage token capacity externally.
Tag: v0.10.0
Author: Martin Evans, 1 year ago
Parent commit: 99969e538e
6 changed files with 90 additions and 46 deletions
1. LLama.Examples/Examples/BatchedDecoding.cs (+6 / -6)
2. LLama.Unittest/BeamTests.cs (+1 / -1)
3. LLama.Unittest/StatelessExecutorTest.cs (+1 / -1)
4. LLama/LLamaContext.cs (+26 / -23)
5. LLama/LLamaEmbedder.cs (+1 / -1)
6. LLama/Native/LLamaBatch.cs (+55 / -14)

LLama.Examples/Examples/BatchedDecoding.cs (+6 / -6)

```diff
@@ -52,13 +52,13 @@ public class BatchedDecoding
             return;
         }
 
-        var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
+        var batch = new LLamaBatch(1);
 
         // evaluate the initial prompt
         for (var i = 0; i < prompt_tokens.Length; i++)
-            batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+            batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
 
-        if (context.NativeHandle.Decode(batch) != 0)
+        if (await context.DecodeAsync(batch) != 0)
         {
             await Console.Error.WriteLineAsync("llama_decode failed");
             return;
@@ -97,7 +97,7 @@ public class BatchedDecoding
         timer.Start();
         while (n_cur <= n_len)
         {
-            batch.LLamaBatchClear();
+            batch.Clear();
 
             for (var i = 0; i < n_parallel; i++)
             {
@@ -129,7 +129,7 @@ public class BatchedDecoding
                 i_batch[i] = batch.TokenCount;
 
                 // push this new token for next evaluation
-                batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+                batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
 
                 n_decode++;
             }
@@ -143,7 +143,7 @@ public class BatchedDecoding
             n_cur++;
 
             // evaluate the current batch with the transformer model
-            if (context.NativeHandle.Decode(batch) != 0)
+            if (await context.DecodeAsync(batch) != 0)
             {
                 await Console.Error.WriteLineAsync("failed to eval");
                 return;
```
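The net effect is that the example no longer touches `NativeHandle` at all. A minimal sketch of the new flow for prompt evaluation (setup elided; `context` and `prompt_tokens` are assumed to already exist as in the example above):

```csharp
// Sketch only: assumes an existing LLamaContext `context` and LLamaToken[] `prompt_tokens`.
var batch = new LLamaBatch(1);   // only n_seq_max is needed now; token capacity grows on demand

// Queue the whole prompt, requesting logits only for the final token.
for (var i = 0; i < prompt_tokens.Length; i++)
    batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

// DecodeAsync runs llama_decode on a background task and returns its status code.
if (await context.DecodeAsync(batch) != 0)
    throw new Exception("llama_decode failed");

// Clear() resets TokenCount so the same batch can be refilled for the next step.
batch.Clear();
```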


LLama.Unittest/BeamTests.cs (+1 / -1)

```diff
@@ -40,7 +40,7 @@ public sealed class BeamTests
 
         var initial_tokens = context.Tokenize(prompt);
         result.Append(prompt);
-        context.Eval(initial_tokens, 0);
+        context.Eval(initial_tokens.AsSpan(), 0);
 
         NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
         {
```


LLama.Unittest/StatelessExecutorTest.cs (+1 / -1)

```diff
@@ -36,7 +36,7 @@ namespace LLama.Unittest
 
             var executor = new StatelessExecutor(_weights, _params);
 
-            const string question = "Question. what is a cat?\nAnswer: ";
+            const string question = "Question. what is a cat?\nAnswer:";
             var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
 
             var timer = new Stopwatch();
```


LLama/LLamaContext.cs (+26 / -23)

```diff
@@ -8,10 +8,12 @@ using System.IO;
 using System.IO.MemoryMappedFiles;
 using LLama.Common;
 using System.Runtime.InteropServices;
+using System.Threading.Tasks;
 using LLama.Extensions;
 using LLama.Abstractions;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
+using System.Threading;
 
 namespace LLama
 {
@@ -344,16 +346,30 @@ namespace LLama
 
         #region eval overloads
         /// <summary>
-        ///
         /// </summary>
-        /// <param name="tokens"></param>
-        /// <param name="pastTokensCount"></param>
-        /// <returns>The updated `pastTokensCount`.</returns>
-        /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
-        public int Eval(LLamaToken[] tokens, int pastTokensCount)
+        /// <param name="batch"></param>
+        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+        ///  - 0: success<br />
+        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+        ///  - &lt; 0: error<br />
+        /// </returns>
+        public int Decode(LLamaBatch batch)
+        {
+            return NativeHandle.Decode(batch);
+        }
+
+        /// <summary>
+        /// </summary>
+        /// <param name="batch"></param>
+        /// <param name="cancellationToken"></param>
+        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+        ///  - 0: success<br />
+        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+        ///  - &lt; 0: error<br />
+        /// </returns>
+        public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
         {
-            return Eval(tokens.AsSpan(), pastTokensCount);
+            return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
         }
 
         /// <summary>
@@ -363,7 +379,7 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
+        [Obsolete("use Decode() instead")]
         public int Eval(List<LLamaToken> tokens, int pastTokensCount)
         {
 #if NET5_0_OR_GREATER
@@ -394,20 +410,7 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
-        public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
-        {
-            return Eval(tokens.Span, pastTokensCount);
-        }
-
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <param name="pastTokensCount"></param>
-        /// <returns>The updated `pastTokensCount`.</returns>
-        /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
+        [Obsolete("use Decode() instead")]
         public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
         {
             var total = tokens.Length;
```
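`DecodeAsync` is a thin wrapper that pushes the blocking `NativeHandle.Decode` call onto the thread pool via `Task.Run`, so the cancellation token is only observed before the native call starts. A hedged sketch of how a caller might use it (the `TryDecodeAsync` helper name is hypothetical, not part of the library):

```csharp
// Hypothetical helper: decode a batch off the calling thread and interpret the status code.
static async Task<bool> TryDecodeAsync(LLamaContext context, LLamaBatch batch, CancellationToken token)
{
    // Task.Run schedules NativeHandle.Decode(batch) on the thread pool; llama_decode itself
    // cannot be interrupted, so the token only prevents the work from starting.
    var result = await context.DecodeAsync(batch, token);

    if (result == 0)
        return true;    // success
    if (result == 1)
        return false;   // no KV slot: reduce the batch size or increase the context size
    throw new InvalidOperationException($"llama_decode failed with code {result}");
}
```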


LLama/LLamaEmbedder.cs (+1 / -1)

```diff
@@ -75,7 +75,7 @@ namespace LLama
             // TODO(Rinne): deal with log of prompt
 
             if (embed_inp_array.Length > 0)
-                Context.Eval(embed_inp_array, 0);
+                Context.Eval(embed_inp_array.AsSpan(), 0);
 
             var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
             if (embeddings == null)
```


LLama/Native/LLamaBatch.cs (+55 / -14)

```diff
@@ -1,4 +1,5 @@
 using System;
+using System.Collections.Generic;
 
 namespace LLama.Native;
 
@@ -7,27 +8,42 @@ namespace LLama.Native;
 /// </summary>
 public class LLamaBatch
 {
-    private readonly byte[] _logits;
+    private byte[] _logits;
 
-    private readonly LLamaToken[] _tokens;
-    private readonly LLamaPos[] _positions;
+    private LLamaToken[] _tokens;
+    private LLamaPos[] _positions;
 
-    private readonly int[] _sequenceIdCount;
-    private readonly LLamaSeqId[][] _sequenceIds;
-    private readonly IntPtr[] _sequenceIdsPtrs;
+    private int[] _sequenceIdCount;
+    private LLamaSeqId[][] _sequenceIds;
+    private IntPtr[] _sequenceIdsPtrs;
 
     /// <summary>
     /// The number of tokens in this batch
     /// </summary>
     public int TokenCount { get; private set; }
 
+    /// <summary>
+    /// Maximum number of tokens that can be added to this batch
+    /// </summary>
+    private int TokenCapacity { get; set; }
+
+    /// <summary>
+    /// Maximum number of sequences a token can be assigned to
+    /// </summary>
+    public int MaxSequences { get; private set; }
+
     /// <summary>
     /// Create a new batch for submitting inputs to llama.cpp
     /// </summary>
-    /// <param name="n_tokens"></param>
-    /// <param name="n_seq_max"></param>
-    public LLamaBatch(int n_tokens, int n_seq_max)
+    /// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
+    public LLamaBatch(int n_seq_max)
     {
+        // The number of tokens can be grown later, start off with a reasonable guess.
+        const int n_tokens = 64;
+
+        MaxSequences = n_seq_max;
+        TokenCapacity = n_tokens;
+
         _logits = new byte[n_tokens];
         _tokens = new LLamaToken[n_tokens];
         _positions = new LLamaPos[n_tokens];
@@ -37,7 +53,29 @@ public class LLamaBatch
 
         _sequenceIds = new LLamaSeqId[n_tokens][];
         for (var i = 0; i < _sequenceIds.Length; i++)
-            _sequenceIds[i] = new LLamaSeqId[n_seq_max];
+            _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+    }
+
+    private void Grow()
+    {
+        var n_tokens = TokenCount * 2;
+        TokenCapacity = n_tokens;
+
+        Array.Resize(ref _logits, n_tokens);
+        Array.Resize(ref _tokens, n_tokens);
+        Array.Resize(ref _positions, n_tokens);
+
+        Array.Resize(ref _sequenceIdCount, n_tokens);
+        Array.Resize(ref _sequenceIdsPtrs, n_tokens);
+
+        Array.Resize(ref _sequenceIds, n_tokens);
+        for (int i = 0; i < _sequenceIds.Length; i++)
+        {
+            // Growing the array filled elements with null, temporarily violating the nullability contract!
+            // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
+            if (_sequenceIds[i] == null)
+                _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+        }
     }
 
     internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
@@ -79,8 +117,11 @@ public class LLamaBatch
     /// <param name="pos">The position to add it att</param>
    /// <param name="sequences">The set of sequences to add this token to</param>
     /// <param name="logits"></param>
-    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+    public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
+        if (TokenCount == TokenCapacity)
+            Grow();
+
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;
 
@@ -101,20 +142,20 @@ public class LLamaBatch
     /// <param name="pos">The position to add it att</param>
     /// <param name="sequence">The sequence to add this token to</param>
     /// <param name="logits"></param>
-    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+    public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
     {
         // Create a temporary span to contain 1 item without allocating
         Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
         sequences[0] = sequence;
 
         // Add it
-        LLamaBatchAdd(token, pos, sequences, logits);
+        Add(token, pos, sequences, logits);
     }
 
     /// <summary>
     /// Set TokenCount to zero for this batch
     /// </summary>
-    public void LLamaBatchClear()
+    public void Clear()
     {
         TokenCount = 0;
     }
```
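In practice the growth logic makes batch sizing invisible to callers: the arrays start at 64 tokens and roughly double (`TokenCount * 2`) each time `Add` finds the batch full. A rough illustration, with purely illustrative token values and count (the explicit `(LLamaToken)0` cast is assumed, in the same spirit as the `(LLamaSeqId)i` cast used in the example above):

```csharp
// Illustrative only: exceed the initial 64-token capacity and let the batch grow itself.
var batch = new LLamaBatch(1);

for (var i = 0; i < 200; i++)
{
    // Adding the 65th token triggers Grow() (64 -> 128), and again when full at 128 (-> 256).
    batch.Add((LLamaToken)0, i, LLamaSeqId.Zero, logits: i == 199);
}

Console.WriteLine(batch.TokenCount);    // 200, with no external capacity management required
```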
