
- Removed some unused `eval` methods.
- Added a `DecodeAsync` overload which runs the work in a task.
- Replaced some `NativeHandle` usage in `BatchedDecoding` with higher-level equivalents.
- Made the `LLamaBatch` grow when token capacity is exceeded, removing the need to manage token capacity externally (see the usage sketch below).
Martin Evans, 1 year ago · commit 99969e538e · tags/v0.10.0
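
As a rough usage sketch (not part of the commit itself), the core decode loop from the updated `BatchedDecoding` example below now reads roughly as follows; `context` and `prompt_tokens` are assumed to be set up as in that example file:

// Sketch based on the BatchedDecoding changes in this commit; `context` is an LLamaContext
// and `prompt_tokens` is the tokenized prompt, both prepared as in the example file.
var batch = new LLamaBatch(1);   // no up-front token capacity needed any more

// Queue the prompt into sequence 0, requesting logits only for the final token
for (var i = 0; i < prompt_tokens.Length; i++)
    batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

// Run the decode on a background task instead of calling context.NativeHandle.Decode directly
if (await context.DecodeAsync(batch) != 0)
{
    await Console.Error.WriteLineAsync("llama_decode failed");
    return;
}

// The same batch can then be cleared and reused for the next step
batch.Clear();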
6 changed files with 90 additions and 46 deletions:

1. LLama.Examples/Examples/BatchedDecoding.cs (+6 -6)
2. LLama.Unittest/BeamTests.cs (+1 -1)
3. LLama.Unittest/StatelessExecutorTest.cs (+1 -1)
4. LLama/LLamaContext.cs (+26 -23)
5. LLama/LLamaEmbedder.cs (+1 -1)
6. LLama/Native/LLamaBatch.cs (+55 -14)

LLama.Examples/Examples/BatchedDecoding.cs (+6 -6)

@@ -52,13 +52,13 @@ public class BatchedDecoding
return;
}

- var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
+ var batch = new LLamaBatch(1);

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
- batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+ batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

- if (context.NativeHandle.Decode(batch) != 0)
+ if (await context.DecodeAsync(batch) != 0)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
@@ -97,7 +97,7 @@ public class BatchedDecoding
timer.Start();
while (n_cur <= n_len)
{
- batch.LLamaBatchClear();
+ batch.Clear();

for (var i = 0; i < n_parallel; i++)
{
@@ -129,7 +129,7 @@ public class BatchedDecoding
i_batch[i] = batch.TokenCount;

// push this new token for next evaluation
- batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+ batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);

n_decode++;
}
@@ -143,7 +143,7 @@ public class BatchedDecoding
n_cur++;

// evaluate the current batch with the transformer model
- if (context.NativeHandle.Decode(batch) != 0)
+ if (await context.DecodeAsync(batch) != 0)
{
await Console.Error.WriteLineAsync("failed to eval");
return;


LLama.Unittest/BeamTests.cs (+1 -1)

@@ -40,7 +40,7 @@ public sealed class BeamTests

var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
- context.Eval(initial_tokens, 0);
+ context.Eval(initial_tokens.AsSpan(), 0);

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{


LLama.Unittest/StatelessExecutorTest.cs (+1 -1)

@@ -36,7 +36,7 @@ namespace LLama.Unittest

var executor = new StatelessExecutor(_weights, _params);

- const string question = "Question. what is a cat?\nAnswer: ";
+ const string question = "Question. what is a cat?\nAnswer:";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

var timer = new Stopwatch();


LLama/LLamaContext.cs (+26 -23)

@@ -8,10 +8,12 @@ using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Common;
using System.Runtime.InteropServices;
+ using System.Threading.Tasks;
using LLama.Extensions;
using LLama.Abstractions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
+ using System.Threading;

namespace LLama
{
@@ -344,16 +346,30 @@ namespace LLama

#region eval overloads
/// <summary>
- ///
/// </summary>
- /// <param name="tokens"></param>
- /// <param name="pastTokensCount"></param>
- /// <returns>The updated `pastTokensCount`.</returns>
- /// <exception cref="RuntimeError"></exception>
- [Obsolete("use llama_decode() instead")]
- public int Eval(LLamaToken[] tokens, int pastTokensCount)
+ /// <param name="batch"></param>
+ /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+ /// - 0: success<br />
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+ /// - &lt; 0: error<br />
+ /// </returns>
+ public int Decode(LLamaBatch batch)
{
+ return NativeHandle.Decode(batch);
+ }
+
+ /// <summary>
+ /// </summary>
+ /// <param name="batch"></param>
+ /// <param name="cancellationToken"></param>
+ /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+ /// - 0: success<br />
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+ /// - &lt; 0: error<br />
+ /// </returns>
+ public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+ {
- return Eval(tokens.AsSpan(), pastTokensCount);
+ return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
}

/// <summary>
@@ -363,7 +379,7 @@ namespace LLama
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
- [Obsolete("use llama_decode() instead")]
+ [Obsolete("use Decode() instead")]
public int Eval(List<LLamaToken> tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
@@ -394,20 +410,7 @@ namespace LLama
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
- [Obsolete("use llama_decode() instead")]
- public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
- {
- return Eval(tokens.Span, pastTokensCount);
- }
-
- /// <summary>
- ///
- /// </summary>
- /// <param name="tokens"></param>
- /// <param name="pastTokensCount"></param>
- /// <returns>The updated `pastTokensCount`.</returns>
- /// <exception cref="RuntimeError"></exception>
- [Obsolete("use llama_decode() instead")]
+ [Obsolete("use Decode() instead")]
public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
{
var total = tokens.Length;
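
To make the documented return contract of the new `Decode`/`DecodeAsync` methods concrete, here is a hedged sketch of how a caller might handle the status codes; the `cts` token source, the timeout value, and the exception type are illustrative assumptions, not part of the commit:

// Sketch only: `context` and `batch` as in the example above.
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));   // assumed timeout, for illustration
var status = await context.DecodeAsync(batch, cts.Token);

if (status == 0)
{
    // success: read the logits and sample the next token as usual
}
else if (status > 0)
{
    // warning, e.g. 1 = no KV slot found for the batch:
    // try reducing the batch size or increasing the context size
}
else
{
    // negative values are errors reported by llama_decode
    throw new Exception($"llama_decode failed with status {status}");
}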


LLama/LLamaEmbedder.cs (+1 -1)

@@ -75,7 +75,7 @@ namespace LLama
// TODO(Rinne): deal with log of prompt

if (embed_inp_array.Length > 0)
- Context.Eval(embed_inp_array, 0);
+ Context.Eval(embed_inp_array.AsSpan(), 0);

var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)


LLama/Native/LLamaBatch.cs (+55 -14)

@@ -1,4 +1,5 @@
using System;
+ using System.Collections.Generic;

namespace LLama.Native;

@@ -7,27 +8,42 @@ namespace LLama.Native;
/// </summary>
public class LLamaBatch
{
- private readonly byte[] _logits;
+ private byte[] _logits;

- private readonly LLamaToken[] _tokens;
- private readonly LLamaPos[] _positions;
+ private LLamaToken[] _tokens;
+ private LLamaPos[] _positions;

- private readonly int[] _sequenceIdCount;
- private readonly LLamaSeqId[][] _sequenceIds;
- private readonly IntPtr[] _sequenceIdsPtrs;
+ private int[] _sequenceIdCount;
+ private LLamaSeqId[][] _sequenceIds;
+ private IntPtr[] _sequenceIdsPtrs;

/// <summary>
/// The number of tokens in this batch
/// </summary>
public int TokenCount { get; private set; }

+ /// <summary>
+ /// Maximum number of tokens that can be added to this batch
+ /// </summary>
+ private int TokenCapacity { get; set; }
+
+ /// <summary>
+ /// Maximum number of sequences a token can be assigned to
+ /// </summary>
+ public int MaxSequences { get; private set; }
+
/// <summary>
/// Create a new batch for submitting inputs to llama.cpp
/// </summary>
- /// <param name="n_tokens"></param>
- /// <param name="n_seq_max"></param>
- public LLamaBatch(int n_tokens, int n_seq_max)
+ /// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
+ public LLamaBatch(int n_seq_max)
{
+ // The number of tokens can be grown later, start off with a reasonable guess.
+ const int n_tokens = 64;
+
+ MaxSequences = n_seq_max;
+ TokenCapacity = n_tokens;
+
_logits = new byte[n_tokens];
_tokens = new LLamaToken[n_tokens];
_positions = new LLamaPos[n_tokens];
@@ -37,7 +53,29 @@ public class LLamaBatch

_sequenceIds = new LLamaSeqId[n_tokens][];
for (var i = 0; i < _sequenceIds.Length; i++)
- _sequenceIds[i] = new LLamaSeqId[n_seq_max];
+ _sequenceIds[i] = new LLamaSeqId[MaxSequences];
}

+ private void Grow()
+ {
+ var n_tokens = TokenCount * 2;
+ TokenCapacity = n_tokens;
+
+ Array.Resize(ref _logits, n_tokens);
+ Array.Resize(ref _tokens, n_tokens);
+ Array.Resize(ref _positions, n_tokens);
+
+ Array.Resize(ref _sequenceIdCount, n_tokens);
+ Array.Resize(ref _sequenceIdsPtrs, n_tokens);
+
+ Array.Resize(ref _sequenceIds, n_tokens);
+ for (int i = 0; i < _sequenceIds.Length; i++)
+ {
+ // Growing the array filled elements with null, temporarily violating the nullability contract!
+ // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
+ if (_sequenceIds[i] == null)
+ _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+ }
+ }

internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
@@ -79,8 +117,11 @@ public class LLamaBatch
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
- public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
+ if (TokenCount == TokenCapacity)
+ Grow();
+
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;

@@ -101,20 +142,20 @@ public class LLamaBatch
/// <param name="pos">The position to add it att</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logits"></param>
- public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+ public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
{
// Create a temporary span to contain 1 item without allocating
Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
sequences[0] = sequence;

// Add it
- LLamaBatchAdd(token, pos, sequences, logits);
+ Add(token, pos, sequences, logits);
}

/// <summary>
/// Set TokenCount to zero for this batch
/// </summary>
- public void LLamaBatchClear()
+ public void Clear()
{
TokenCount = 0;
}
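
For clarity, a small sketch (not from the commit) of what the new growth behaviour means for callers; `prompt_tokens` is assumed as in the `BatchedDecoding` example, and the loop deliberately adds more tokens than the initial internal capacity:

// The batch now starts with a small internal guess (64 tokens) and Add() calls Grow(),
// doubling the backing arrays, whenever TokenCount reaches TokenCapacity.
var batch = new LLamaBatch(1);
for (var i = 0; i < 200; i++)
    batch.Add(prompt_tokens[i % prompt_tokens.Length], i, LLamaSeqId.Zero, false);

This is why the `BatchedDecoding` example above can simply construct `new LLamaBatch(1)` instead of computing a capacity from the prompt length and `n_parallel`.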
