
Merge pull request #503 from martindevans/batched_executor_again

Introduced a new `BatchedExecutor`
tags/v0.10.0
Martin Evans (GitHub) · 1 year ago
parent commit d03c1a9201
17 changed files with 912 additions and 204 deletions
  1. LLama.Examples/Examples/BatchedDecoding.cs (+0, -172)
  2. LLama.Examples/Examples/BatchedExecutorFork.cs (+138, -0)
  3. LLama.Examples/Examples/BatchedExecutorRewind.cs (+121, -0)
  4. LLama.Examples/Examples/Runner.cs (+2, -1)
  5. LLama/Batched/BatchedExecutor.cs (+119, -0)
  6. LLama/Batched/Conversation.cs (+294, -0)
  7. LLama/Batched/ConversationExtensions.cs (+59, -0)
  8. LLama/Batched/Exceptions.cs (+81, -0)
  9. LLama/LLamaContext.cs (+3, -1)
  10. LLama/LLamaInstructExecutor.cs (+1, -0)
  11. LLama/LLamaInteractExecutor.cs (+1, -0)
  12. LLama/Native/LLamaBatch.cs (+44, -11)
  13. LLama/Native/NativeApi.cs (+1, -1)
  14. LLama/Native/SafeLLamaContextHandle.cs (+4, -4)
  15. LLama/Sampling/BaseSamplingPipeline.cs (+7, -12)
  16. LLama/Sampling/DefaultSamplingPipeline.cs (+24, -2)
  17. LLama/Sampling/ISamplingPipeline.cs (+13, -0)

LLama.Examples/Examples/BatchedDecoding.cs (+0, -172)

@@ -1,172 +0,0 @@
using System.Diagnostics;
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
private const int n_parallel = 8;
private const int n_len = 32;

public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

Console.WriteLine("Prompt (leave blank to select automatically):");
var prompt = Console.ReadLine();
if (string.IsNullOrWhiteSpace(prompt))
prompt = "Not many people know that";

// Load model
var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);

// Tokenize prompt
var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

// Create a context
parameters.ContextSize = (uint)model.ContextSize;
parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
using var context = model.CreateContext(parameters);

var n_ctx = context.ContextSize;

// make sure the KV cache is big enough to hold all the prompt and generated tokens
if (n_kv_req > n_ctx)
{
await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
return;
}

var batch = new LLamaBatch();

// evaluate the initial prompt
batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);

if (await context.DecodeAsync(batch) != DecodeResult.Ok)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
}

// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i)
{
context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
}

if (n_parallel > 1)
{
Console.WriteLine();
Console.WriteLine($"generating {n_parallel} sequences...");
}

// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
List<int> i_batch = new();
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.TokenCount - 1);

// Create per-stream decoder and sampler
var decoders = new StreamingTokenDecoder[n_parallel];
var samplers = new ISamplingPipeline[n_parallel];
for (var i = 0; i < n_parallel; i++)
{
decoders[i] = new StreamingTokenDecoder(context);
samplers[i] = new DefaultSamplingPipeline
{
Temperature = 0.1f + (float)i / n_parallel,
MinP = 0.25f,
};
}

var n_cur = batch.TokenCount;
var n_decode = 0;

var timer = new Stopwatch();
timer.Start();
while (n_cur <= n_len)
{
batch.Clear();

for (var i = 0; i < n_parallel; i++)
{
// Skip completed streams
if (i_batch[i] < 0)
continue;

// Use the sampling pipeline to select a token
var new_token_id = samplers[i].Sample(
context.NativeHandle,
context.NativeHandle.GetLogitsIth(i_batch[i]),
Array.Empty<LLamaToken>()
);

// Finish this stream early if necessary
if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
{
i_batch[i] = -1;
Console.WriteLine($"Completed Stream {i} early");
continue;
}

// Add this token to the decoder, so it will be turned into text
decoders[i].Add(new_token_id);

i_batch[i] = batch.TokenCount;

// push this new token for next evaluation
batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);

n_decode++;
}

// Check if all streams are finished
if (batch.TokenCount == 0)
{
break;
}

n_cur++;

// evaluate the current batch with the transformer model
if (await context.DecodeAsync(batch) != 0)
{
await Console.Error.WriteLineAsync("failed to eval");
return;
}
}

timer.Stop();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine();
Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

var index = 0;
foreach (var stream in decoders)
{
var text = stream.Read();

Console.ForegroundColor = ConsoleColor.Green;
Console.Write($"{index++}. {prompt}");
Console.ForegroundColor = ConsoleColor.Red;
Console.WriteLine(text);
}

Console.WriteLine("Press any key to exit demo");
Console.ReadKey(true);
}
}

LLama.Examples/Examples/BatchedExecutorFork.cs (+138, -0)

@@ -0,0 +1,138 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
public class BatchedExecutorFork
{
private const int n_split = 16;
private const int n_len = 64;

public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

Console.WriteLine("Prompt (leave blank to select automatically):");
var prompt = Console.ReadLine();
if (string.IsNullOrWhiteSpace(prompt))
prompt = "Not many people know that";

// Create an executor that can evaluate a batch of conversations together
var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Evaluate the initial prompt to create one conversation
var start = executor.Prompt(prompt);
await executor.Infer();

// Create the root node of the tree
var root = new Node(start);

// Run inference loop
for (var i = 0; i < n_len; i++)
{
if (i != 0)
await executor.Infer();

// Occasionally fork all the active conversations
if (i != 0 && i % n_split == 0)
root.Split();

// Sample all active conversations
root.Sample();
}

Console.WriteLine($"{prompt}...");
root.Print(1);

Console.WriteLine("Press any key to exit demo");
Console.ReadKey(true);
}

class Node
{
private readonly StreamingTokenDecoder _decoder;

private readonly DefaultSamplingPipeline _sampler;
private Conversation? _conversation;

private Node? _left;
private Node? _right;

public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

public Node(Conversation conversation)
{
_sampler = new DefaultSamplingPipeline();
_conversation = conversation;
_decoder = new StreamingTokenDecoder(conversation.Executor.Context);
}

public void Sample()
{
if (_conversation == null)
{
_left?.Sample();
_right?.Sample();
return;
}

if (_conversation.RequiresInference)
return;

// Sample one token
var ctx = _conversation.Executor.Context.NativeHandle;
var logitsCopy = _conversation.Sample().ToArray();
var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
_sampler.Accept(ctx, token);
_decoder.Add(token);

// Prompt the conversation with this token, to continue generating from there
_conversation.Prompt(token);
}

public void Split()
{
if (_conversation != null)
{
_left = new Node(_conversation.Fork());
_right = new Node(_conversation.Fork());

_conversation.Dispose();
_conversation = null;
}
else
{
_left?.Split();
_right?.Split();
}
}

public void Print(int indendation)
{
var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
Console.ForegroundColor = colors[indendation % colors.Length];

var message = _decoder.Read().ReplaceLineEndings("");

var prefix = new string(' ', indendation * 3);
var suffix = _conversation == null ? "..." : "";
Console.WriteLine($"{prefix}...{message}{suffix}");

_left?.Print(indendation + 2);
_right?.Print(indendation + 2);
}
}
}

LLama.Examples/Examples/BatchedExecutorRewind.cs (+121, -0)

@@ -0,0 +1,121 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating tokens and then rewinding to an earlier state
/// </summary>
public class BatchedExecutorRewind
{
private const int n_generate = 24;
private const int n_rewind = 12;
private const int n_repeats = 6;

public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

Console.WriteLine("Prompt (leave blank to select automatically):");
var prompt = Console.ReadLine();
if (string.IsNullOrWhiteSpace(prompt))
prompt = "Not many people know that";

// Create an executor that can evaluate a batch of conversations together
var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Evaluate the initial prompt to create one conversation
var conversation = executor.Prompt(prompt);
// Create the start node, which will sample and record tokens for the conversation
var node = new Node(executor.Context);

// Print the prompt
Console.ForegroundColor = ConsoleColor.Green;
Console.WriteLine(prompt);

for (var i = 0; i < n_repeats; i++)
{
for (var j = 0; j < n_generate; j++)
{
// Run inference
await executor.Infer();

// Sample a token
var token = node.Sample(conversation);

// Continue conversation with this token
if (j != n_generate - 1)
conversation.Prompt(token);
}

// Write out what we generated
node.Write(n_rewind, i + 1);

// Rewind back a few tokens
conversation.Rewind(n_rewind + 1);

// Prompt with a token
conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));

// Create a new node around the rewound conversation
node = new Node(executor.Context);
}

Console.WriteLine("Press any key to exit demo");
Console.ReadKey(true);
}

private class Node
{
private readonly LLamaContext _context;

private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
private readonly DefaultSamplingPipeline Sampler;

public Node(LLamaContext context)
{
_context = context;
Sampler = new DefaultSamplingPipeline();
}

public LLamaToken Sample(Conversation conversation)
{
var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
_tokens.Add(token);
return token;
}

public void Write(int n_rewind, int depth)
{
var decoder = new StreamingTokenDecoder(_context);

for (var i = 0; i < _tokens.Count - n_rewind; i++)
decoder.Add(_tokens[i]);

Console.ForegroundColor = ConsoleColor.Green;
Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" "));

for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
decoder.Add(_tokens[i]);

Console.ForegroundColor = ConsoleColor.DarkRed;
Console.WriteLine(decoder.Read().ReplaceLineEndings(" "));
}

public LLamaToken GetToken(int index)
{
return _tokens[index];
}
}
}

LLama.Examples/Examples/Runner.cs (+2, -1)

@@ -23,7 +23,8 @@ public class Runner
{ "Semantic Kernel Chat.", SemanticKernelChat.Run },
{ "Semantic Kernel Memory.", SemanticKernelMemory.Run },
{ "Coding Assistant.", CodingAssistant.Run },
{ "Batch Decoding.", BatchedDecoding.Run },
{ "Batched Executor (Fork)", BatchedExecutorFork.Run },
{ "Batched Executor (Rewind)", BatchedExecutorRewind.Run },
{ "SK Kernel Memory.", KernelMemory.Run },
{ "Exit", async () => Environment.Exit(0) }
};


LLama/Batched/BatchedExecutor.cs (+119, -0)

@@ -0,0 +1,119 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
public sealed class BatchedExecutor
: IDisposable
{
private int _nextSequenceId;

internal LLamaBatch Batch { get; }

/// <summary>
/// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
/// whether they're waiting for inference, or can be sampled.
/// </summary>
internal ulong Epoch { get; private set; }

/// <summary>
/// The <see cref="LLamaContext"/> this executor is using
/// </summary>
public LLamaContext Context { get; }

/// <summary>
/// The <see cref="LLamaWeights"/> this executor is using
/// </summary>
public LLamaWeights Model { get; }

/// <summary>
/// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchedTokenCount => Batch.TokenCount;

/// <summary>
/// Check if this executor has been disposed.
/// </summary>
public bool IsDisposed { get; private set; }

/// <summary>
/// Create a new batched executor
/// </summary>
/// <param name="model">The model to use</param>
/// <param name="contextParams">Parameters to create a new context</param>
public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
{
Model = model;
Batch = new LLamaBatch();
Context = model.CreateContext(contextParams);
Epoch = 1;
}

~BatchedExecutor()
{
Dispose();
}

/// <summary>
/// Start a new <see cref="Conversation"/> with the given prompt
/// </summary>
/// <param name="prompt"></param>
/// <returns></returns>
public Conversation Prompt(string prompt)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

var conversation = new Conversation(this, GetNextSequenceId(), 0);
conversation.Prompt(prompt);

return conversation;
}

/// <summary>
/// Run inference for all conversations in the batch which have pending tokens.
///
/// If the result is `NoKvSlot` then there is not enough memory for inference; try disposing some conversation
/// threads and running inference again.
/// </summary>
public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

var status = await Context.DecodeAsync(Batch, cancellation);

// Only clear the batch if the result was ok. Leaving all this state in place means that "Infer" can
// be called again after a recoverable failure (e.g. NoKvSlot).
if (status == DecodeResult.Ok)
{
Epoch++;
Batch.Clear();
}

return status;
}

/// <inheritdoc />
public void Dispose()
{
if (IsDisposed)
return;
IsDisposed = true;

GC.SuppressFinalize(this);

Context.Dispose();
}

internal LLamaSeqId GetNextSequenceId()
{
return checked((LLamaSeqId)_nextSequenceId++);
}
}
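
Taken together with the two new example files above, the intended call pattern is fairly small. Below is a minimal sketch, pieced together from those examples; the model path, prompt and generation length are placeholders, and error handling around the DecodeResult is omitted:

using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// Load weights and create a batched executor (path and parameters are illustrative)
var parameters = new ModelParams("path/to/model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Start a conversation by queueing up the prompt tokens
var conversation = executor.Prompt("Not many people know that");

var sampler = new DefaultSamplingPipeline();
var decoder = new StreamingTokenDecoder(executor.Context);

for (var i = 0; i < 32; i++)
{
    // Evaluate everything queued in the batch (just this one conversation here)
    await executor.Infer();

    // Pull the logits for this conversation, pick a token, and feed it back in
    var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
    sampler.Accept(executor.Context.NativeHandle, token);
    decoder.Add(token);
    conversation.Prompt(token);
}

Console.WriteLine(decoder.Read());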

LLama/Batched/Conversation.cs (+294, -0)

@@ -0,0 +1,294 @@
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
/// </summary>
public sealed class Conversation
: IDisposable
{
private ulong _requiredEpoch;
private LLamaPos _end;
private int _batchIndex;
private bool _disposed;

/// <summary>
/// The executor which this conversation belongs to
/// </summary>
public BatchedExecutor Executor { get; }

/// <summary>
/// Unique ID for this conversation
/// </summary>
public LLamaSeqId ConversationId { get; }

/// <summary>
/// Total number of tokens in this conversation, cannot exceed the context length.
/// </summary>
public int TokenCount => _end.Value;

/// <summary>
/// Indicates if this conversation has been disposed, nothing can be done with a disposed conversation
/// </summary>
public bool IsDisposed => _disposed || Executor.IsDisposed;

/// <summary>
/// Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true.
/// </summary>
public bool RequiresInference => _requiredEpoch > Executor.Epoch;

/// <summary>
/// Indicates that this conversation should be sampled.
/// </summary>
public bool RequiresSampling => _requiredEpoch == Executor.Epoch;

#region construction/destruction
internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
{
ConversationId = id;
Executor = batch;

_end = end;
}

~Conversation()
{
Dispose();
}

/// <summary>
/// End this conversation, freeing all resources used by it
/// </summary>
/// <exception cref="ObjectDisposedException"></exception>
public void Dispose()
{
if (IsDisposed)
return;
_disposed = true;

// Remove this conversation from the KV cache
Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);

// Prevent finalizer from running
GC.SuppressFinalize(this);
}

private void AssertNotDisposed()
{
if (Executor.IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));
if (IsDisposed)
throw new ObjectDisposedException(nameof(Conversation));
}
#endregion

/// <summary>
/// Create a copy of the current conversation
/// </summary>
/// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
public Conversation Fork()
{
AssertNotDisposed();

if (RequiresInference)
throw new CannotForkWhileRequiresInference();

// Create a new conversation which references the current position in this one
var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
{
_batchIndex = _batchIndex,
_requiredEpoch = _requiredEpoch,
};

// Assign tokens to the new sequence
NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);

return c;
}

#region sample
/// <summary>
/// Get the logits from this conversation, ready for sampling
/// </summary>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
/// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
public ReadOnlySpan<float> Sample()
{
AssertNotDisposed();

if (_requiredEpoch < Executor.Epoch)
throw new CannotSampleRequiresPromptException();
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();

return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
}
#endregion

#region prompt
private void AssertCanBePrompted()
{
AssertNotDisposed();

if (RequiresInference)
throw new AlreadyPromptedConversationException();
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="input"></param>
/// <returns></returns>
public void Prompt(string input)
{
AssertCanBePrompted();

Prompt(Executor.Context.Tokenize(input));
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
public void Prompt(IReadOnlyList<LLamaToken> tokens)
{
AssertCanBePrompted();

// Add the prompt to the batch
for (var i = 0; i < tokens.Count; i++)
_batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);

// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
}

/// <summary>
/// Add a single token to this conversation
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="InvalidOperationException"></exception>
public void Prompt(LLamaToken token)
{
AssertCanBePrompted();

// Add this token as input
_batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);

// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
}
#endregion

#region modify
/// <summary>
/// Directly modify the KV cache of this conversation
/// </summary>
/// <param name="modifier"></param>
/// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
public void Modify(ModifyKvCache modifier)
{
AssertNotDisposed();

if (RequiresInference)
throw new CannotModifyWhileRequiresInference();

// do whatever the modification is
_end = modifier.Invoke(_end, new KvAccessor(this));

// Set the epoch down to zero, this ensures that this conversation
// cannot be sampled until it is prompted again.
_requiredEpoch = 0;
}

/// <summary>
/// Provides direct access to the KV cache of a <see cref="Conversation"/>.
/// See <see cref="Modify"/> for how to use this.
/// </summary>
public readonly ref struct KvAccessor
{
private readonly Conversation _conversation;

internal KvAccessor(Conversation conversation)
{
_conversation = conversation;
}

#region remove
/// <summary>
/// Removes all tokens that have positions in [start, end)
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
public void Remove(LLamaPos start, LLamaPos end)
{
_conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
}

/// <summary>
/// Removes "count" tokens starting from the given position
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="count">Number of tokens</param>
public void Remove(LLamaPos start, int count)
{
if (count <= 0)
return;

var end = start.Value + count;
_conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
}
#endregion

#region shift
/// <summary>
/// Adds relative position "delta" to all tokens that have positions in [p0, p1).
/// If the KV cache is RoPEd, the KV data is updated
/// accordingly
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
/// <param name="delta">Amount to add on to each token position</param>
public void Shift(LLamaPos start, LLamaPos end, int delta)
{
_conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
}
#endregion

#region divide
/// <summary>
/// Integer division of the positions by factor of `d > 1`.
/// If the KV cache is RoPEd, the KV data is updated accordingly.
/// </summary>
/// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
/// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
/// <param name="divisor">Amount to divide each position by.</param>
public void Divide(LLamaPos start, LLamaPos end, int divisor)
{
if (divisor <= 0)
throw new ArgumentOutOfRangeException(nameof(divisor));

_conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
}
#endregion
}

/// <summary>
/// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
/// </summary>
/// <param name="end">The current end token of this conversation</param>
/// <param name="kv">An <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
/// <returns>The new end token position</returns>
public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
#endregion
}
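
For comparison with the manual KvCacheSequenceCopy loop in the deleted BatchedDecoding example, here is a hedged sketch of how forking is meant to be driven from the outside. The `executor` is assumed to exist, and `tokenA`/`tokenB` are placeholders for tokens chosen by two separate sampling pipelines; see BatchedExecutorFork.cs above for the full version:

// Evaluate the shared prompt once
var root = executor.Prompt("Not many people know that");
await executor.Infer();

// Fork twice; both children reference the prompt's KV cache entries rather than copying them
var left = root.Fork();
var right = root.Fork();
root.Dispose();

// From here each fork can be prompted with a different token, and the KV cache
// only grows where the two conversations diverge.
left.Prompt(tokenA);
right.Prompt(tokenB);
await executor.Infer();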

LLama/Batched/ConversationExtensions.cs (+59, -0)

@@ -0,0 +1,59 @@
using System;

namespace LLama.Batched;

/// <summary>
/// Extension methods for <see cref="Conversation"/>
/// </summary>
public static class ConversationExtensions
{
/// <summary>
/// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end
/// </summary>
/// <param name="conversation">The conversation to rewind</param>
/// <param name="tokens">The number of tokens to rewind</param>
/// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than TokenCount</exception>
public static void Rewind(this Conversation conversation, int tokens)
{
if (tokens > conversation.TokenCount)
throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens");

conversation.Modify((end, kv) =>
{
// Remove those tokens from KV
kv.Remove(end.Value - tokens, tokens);

// Return adjusted end position
return end.Value - tokens;
});
}

/// <summary>
/// Shift tokens to the left, removing "count" tokens immediately after the first "keep" tokens and shifting
/// everything else over. The first "keep" tokens are left completely untouched. This can be used to free up
/// space when the context gets full, keeping the prompt at the start intact.
/// </summary>
/// <param name="conversation">The conversation to shift</param>
/// <param name="count">The number of tokens to remove, which is also how far the remaining tokens are shifted</param>
/// <param name="keep">The number of tokens at the start which should <b>not</b> be shifted</param>
public static void ShiftLeft(this Conversation conversation, int count, int keep)
{
// Given a setup like this (count=5, keep=3):
//
// AAABBBBBCCCCCCCCC...
//
// We want to remove all the B's, shift all the C's and leave all the A's untouched

conversation.Modify((end, kv) =>
{
// Remove the B's
kv.Remove(keep, count);

// Shift the C's
kv.Shift(keep + count, end, -count);

// Update total count
return end.Value - count;
});
}
}
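
A hedged sketch of how ShiftLeft might be used to keep a long-running conversation inside the context window. Here `executor`, `conversation` (from executor.Prompt), `promptTokenCount` and `lastToken` are assumed to already exist, and the margin of 64 tokens and the chunk of 256 are arbitrary:

// Once the conversation gets close to the context limit, drop a chunk of old
// tokens; everything after the prompt is shifted left to fill the gap.
var contextSize = (int)executor.Context.ContextSize;
if (conversation.TokenCount > contextSize - 64)
    conversation.ShiftLeft(count: 256, keep: promptTokenCount);

// ShiftLeft goes through Modify, which resets the conversation's epoch, so it
// must be prompted again before the next Sample.
conversation.Prompt(lastToken);
await executor.Infer();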

LLama/Batched/Exceptions.cs (+81, -0)

@@ -0,0 +1,81 @@
using System;

namespace LLama.Batched;

/// <summary>
/// Base class for exceptions thrown from <see cref="BatchedExecutor"/>
/// </summary>
public class ExperimentalBatchedExecutorException
: Exception
{
internal ExperimentalBatchedExecutorException(string message)
: base(message)
{
}
}

/// <summary>
/// This exception is thrown when "Prompt()" is called on a <see cref="Conversation"/> which has
/// already been prompted and before "Infer()" has been called on the associated
/// <see cref="BatchedExecutor"/>.
/// </summary>
public class AlreadyPromptedConversationException
: ExperimentalBatchedExecutorException
{
internal AlreadyPromptedConversationException()
: base("Must call `Infer()` before prompting this Conversation again")
{
}
}

/// <summary>
/// This exception is thrown when "Sample()" is called on a <see cref="Conversation"/> which has
/// already been prompted and before "Infer()" has been called on the associated
/// <see cref="BatchedExecutor"/>.
/// </summary>
public class CannotSampleRequiresInferenceException
: ExperimentalBatchedExecutorException
{
internal CannotSampleRequiresInferenceException()
: base("Must call `Infer()` before sampling from this Conversation")
{
}
}

/// <summary>
/// This exception is thrown when "Sample()" is called on a <see cref="Conversation"/> which was not
/// first prompted.
/// <see cref="BatchedExecutor"/>.
/// </summary>
public class CannotSampleRequiresPromptException
: ExperimentalBatchedExecutorException
{
internal CannotSampleRequiresPromptException()
: base("Must call `Prompt()` and then `Infer()` before sampling from this Conversation")
{
}
}

/// <summary>
/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>
public class CannotForkWhileRequiresInference
: ExperimentalBatchedExecutorException
{
internal CannotForkWhileRequiresInference()
: base("Cannot `Fork()` a conversation while RequiresInference is true")
{
}
}

/// <summary>
/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>
public class CannotModifyWhileRequiresInference
: ExperimentalBatchedExecutorException
{
internal CannotModifyWhileRequiresInference()
: base("Cannot `Modify()` a conversation while RequiresInference is true")
{
}
}

LLama/LLamaContext.cs (+3, -1)

@@ -221,7 +221,9 @@ namespace LLama
/// <returns>The selected token</returns>
public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
{
return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
var token = pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
pipeline.Accept(NativeHandle, token);
return token;
}

/// <summary>


LLama/LLamaInstructExecutor.cs (+1, -0)

@@ -213,6 +213,7 @@ namespace LLama
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{


LLama/LLamaInteractExecutor.cs (+1, -0)

@@ -192,6 +192,7 @@ namespace LLama
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{


LLama/Native/LLamaBatch.cs (+44, -11)

@@ -18,6 +18,11 @@ public class LLamaBatch
private LLamaSeqId[][] _sequenceIds;
private IntPtr[] _sequenceIdsPtrs;

/// <summary>
/// Keep track of the index of existing token/position combos in the batch
/// </summary>
private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new();

/// <summary>
/// The number of tokens in this batch
/// </summary>
@@ -130,23 +135,44 @@ public class LLamaBatch
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
/// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
// Try to find this (token, position) combo somewhere in the batch to re-use it
if (_index.TryGetValue((token, pos), out var existingIndex))
{
if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);

foreach (var sequence in sequences)
{
_sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence;
_sequenceIdCount[existingIndex]++;
}

return existingIndex;
}

// Couldn't find it in the batch, add a new item

// Grow capacity as necessary
if (TokenCount == TokenCapacity)
GrowTokenCapacity();
if (sequences.Length > SequenceCapacity)
GrowMaxSequences(sequences.Length);

// Store the index of this (token, position) combo, so it can be found later
_index.Add((token, pos), TokenCount);

// Add the items to the arrays
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;

_sequenceIdCount[TokenCount] = sequences.Length;
for (var i = 0; i < sequences.Length; i++)
_sequenceIds[TokenCount][i] = sequences[i];

_logits[TokenCount] = Convert.ToByte(logits);

TokenCount++;
return TokenCount++;
}

/// <summary>
@@ -157,11 +183,12 @@ public class LLamaBatch
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
public void Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
/// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
{
#if NET5_0_OR_GREATER
var seqSpan = CollectionsMarshal.AsSpan(sequences);
Add(token, pos, seqSpan, logits);
return Add(token, pos, seqSpan, logits);
#else
// on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
@@ -171,7 +198,7 @@ public class LLamaBatch
try
{
sequences.CopyTo(rented, 0);
Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
}
finally
{
@@ -188,14 +215,15 @@ public class LLamaBatch
/// <param name="pos">The position to add it att</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logits"></param>
public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
/// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
public int Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
{
// Create a temporary span to contain 1 item without allocating
Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
sequences[0] = sequence;

// Add it
Add(token, pos, sequences, logits);
return Add(token, pos, sequences, logits);
}

/// <summary>
@@ -205,13 +233,17 @@ public class LLamaBatch
/// <param name="start">The starting position to add tokens at</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logitsLast">Whether the final token should generate logits</param>
public void AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
/// <returns>The index that the final token was added at. Use this for GetLogitsIth</returns>
public int AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
{
var last = -1;
for (var i = 0; i < tokens.Length; i++)
{
var logits = (i == tokens.Length - 1) & logitsLast;
Add(tokens[i], start.Value + i, sequence, logits);
last = Add(tokens[i], start.Value + i, sequence, logits);
}

return last;
}
#endregion

@@ -221,5 +253,6 @@ public class LLamaBatch
public void Clear()
{
TokenCount = 0;
_index.Clear();
}
}
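
The new return value removes the need for callers to track logit indices by hand (the deleted BatchedDecoding example kept an i_batch list for exactly that). A minimal sketch of the pairing with GetLogitsIth, assuming a `promptTokens` array and a `context` already exist:

var batch = new LLamaBatch();

// Queue the prompt for sequence 0; only the final token requests logits,
// and AddRange reports the batch index of that final token.
var lastIndex = batch.AddRange(promptTokens, 0, LLamaSeqId.Zero, true);

// Evaluate the batch, then read the logits back using the remembered index
await context.DecodeAsync(batch);
var logits = context.NativeHandle.GetLogitsIth(lastIndex);

Note that if two sequences add the same (token, position) pair, the existing batch entry is shared and the same index is returned, which is what keeps prompting forked conversations with a common token cheap.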

LLama/Native/NativeApi.cs (+1, -1)

@@ -388,7 +388,7 @@ namespace LLama.Native
/// <param name="p1"></param>
/// <param name="delta"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);
public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);

/// <summary>
/// Integer division of the positions by factor of `d > 1`


LLama/Native/SafeLLamaContextHandle.cs (+4, -4)

@@ -369,15 +369,15 @@ namespace LLama.Native
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="delta"></param>
public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
{
NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
}

/// <summary>
/// Integer division of the positions by factor of `d > 1`
/// If the KV cache is RoPEd, the KV data is updated accordingly
/// p0 &lt; 0 : [0, p1]
/// Integer division of the positions by factor of `d > 1`.
/// If the KV cache is RoPEd, the KV data is updated accordingly.<br />
/// p0 &lt; 0 : [0, p1]<br />
/// p1 &lt; 0 : [p0, inf)
/// </summary>
/// <param name="seq"></param>


LLama/Sampling/BaseSamplingPipeline.cs (+7, -12)

@@ -40,10 +40,7 @@ public abstract class BaseSamplingPipeline
var candidates = LLamaTokenDataArray.Create(logits);

// Process token data array
ProcessTokenDataArray(ctx, candidates, lastTokens);

// Choose the final value
return ChooseToken(ctx, candidates);
return ProcessTokenDataArray(ctx, candidates, lastTokens);
}
finally
{
@@ -53,6 +50,9 @@ public abstract class BaseSamplingPipeline
}
}

/// <inheritdoc />
public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token);

#region protected tokens
/// <summary>
/// Get all of the "protected" tokens that cannot be changed by ProcessLogits
@@ -107,19 +107,14 @@ public abstract class BaseSamplingPipeline
/// <returns></returns>
protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);

/// <summary>
/// Choose the final token from the candidates
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <returns></returns>
protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

/// <inheritdoc/>
public virtual void Reset()
{
}

/// <inheritdoc />
public abstract ISamplingPipeline Clone();

/// <inheritdoc/>
public virtual void Dispose()
{


LLama/Sampling/DefaultSamplingPipeline.cs (+24, -2)

@@ -141,9 +141,31 @@ public sealed class DefaultSamplingPipeline
return id;
}

public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
{
Grammar?.AcceptToken(ctx, token);
}

/// <inheritdoc />
protected override LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
public override ISamplingPipeline Clone()
{
return candidates.SampleToken(ctx);
var clone = new DefaultSamplingPipeline();

foreach (var (k, v) in LogitBias)
clone.LogitBias.Add(k, v);

clone.Grammar = Grammar?.Clone();
clone.RepeatPenalty = RepeatPenalty;
clone.AlphaFrequency = AlphaFrequency;
clone.AlphaPresence = AlphaPresence;
clone.Temperature = Temperature;
clone.TopK = TopK;
clone.TailFreeZ = TailFreeZ;
clone.TypicalP = TypicalP;
clone.TopP = TopP;
clone.MinP = MinP;
clone.PenalizeNewline = PenalizeNewline;

return clone;
}
}

LLama/Sampling/ISamplingPipeline.cs (+13, -0)

@@ -21,10 +21,23 @@ public interface ISamplingPipeline
/// <returns></returns>
LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

/// <summary>
/// Update the pipeline, with knowledge that a particular token was just accepted
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
void Accept(SafeLLamaContextHandle ctx, LLamaToken token);

/// <summary>
/// Reset all internal state of the sampling pipeline
/// </summary>
void Reset();

/// <summary>
/// Create a copy of this sampling pipeline
/// </summary>
/// <returns></returns>
ISamplingPipeline Clone();
}
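
With Accept split out of Sample, the caller is now responsible for feeding the chosen token back into the pipeline (LLamaContext.Sample and both executors above do this). A short sketch of the calling contract, assuming `ctx`, `logits` and `lastTokens` are already in scope:

// `pipeline` can be any ISamplingPipeline, e.g. the DefaultSamplingPipeline above
ISamplingPipeline pipeline = new DefaultSamplingPipeline();

// Pick a token from the logits...
LLamaToken token = pipeline.Sample(ctx, logits, lastTokens);

// ...then tell the pipeline which token was actually used, so stateful parts
// (such as a grammar in DefaultSamplingPipeline) can advance.
pipeline.Accept(ctx, token);

// Clone duplicates the pipeline, which is useful for giving each forked
// conversation its own independent sampling state.
ISamplingPipeline copy = pipeline.Clone();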

/// <summary>

