Introduced a new `BatchedExecutor` (tags/v0.10.0)
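The changeset below removes the old low-level `BatchedDecoding` example and replaces it with a higher-level `BatchedExecutor`/`Conversation` API plus two new examples (fork and rewind). As a rough orientation, a minimal usage sketch built only from the API shown in this diff might look like the following; the model path and generation length are placeholders and error handling is omitted.

// Illustrative sketch only (not part of the diff). Assumes a valid GGUF model path.
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

var parameters = new ModelParams("path/to/model.gguf");   // placeholder path
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Start a conversation and evaluate the prompt
var conversation = executor.Prompt("Not many people know that");
await executor.Infer();

var sampler = new DefaultSamplingPipeline();
var decoder = new StreamingTokenDecoder(executor.Context);

for (var i = 0; i < 32; i++)
{
    // Sample from the logits produced for this conversation, then feed the token back in
    var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
    sampler.Accept(executor.Context.NativeHandle, token);
    decoder.Add(token);

    conversation.Prompt(token);
    await executor.Infer();
}

Console.WriteLine(decoder.Read());

Conversations created this way can also be cheaply forked (`Conversation.Fork`) or rewound (`ConversationExtensions.Rewind`), which is what the two new examples demonstrate.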
@@ -1,172 +0,0 @@
using System.Diagnostics;
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
    private const int n_parallel = 8;
    private const int n_len = 32;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);

        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
            await Console.Error.WriteLineAsync("        either reduce n_parallel or increase n_ctx\n");
            return;
        }

        var batch = new LLamaBatch();

        // evaluate the initial prompt
        batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);

        if (await context.DecodeAsync(batch) != DecodeResult.Ok)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }

        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
            context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.TokenCount - 1);

        // Create per-stream decoder and sampler
        var decoders = new StreamingTokenDecoder[n_parallel];
        var samplers = new ISamplingPipeline[n_parallel];
        for (var i = 0; i < n_parallel; i++)
        {
            decoders[i] = new StreamingTokenDecoder(context);
            samplers[i] = new DefaultSamplingPipeline
            {
                Temperature = 0.1f + (float)i / n_parallel,
                MinP = 0.25f,
            };
        }

        var n_cur = batch.TokenCount;
        var n_decode = 0;

        var timer = new Stopwatch();
        timer.Start();

        while (n_cur <= n_len)
        {
            batch.Clear();

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                    continue;

                // Use the sampling pipeline to select a token
                var new_token_id = samplers[i].Sample(
                    context.NativeHandle,
                    context.NativeHandle.GetLogitsIth(i_batch[i]),
                    Array.Empty<LLamaToken>()
                );

                // Finish this stream early if necessary
                if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
                {
                    i_batch[i] = -1;
                    Console.WriteLine($"Completed Stream {i} early");
                    continue;
                }

                // Add this token to the decoder, so it will be turned into text
                decoders[i].Add(new_token_id);

                i_batch[i] = batch.TokenCount;

                // push this new token for next evaluation
                batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);

                n_decode++;
            }

            // Check if all streams are finished
            if (batch.TokenCount == 0)
            {
                break;
            }

            n_cur++;

            // evaluate the current batch with the transformer model
            if (await context.DecodeAsync(batch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }

        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

        var index = 0;
        foreach (var stream in decoders)
        {
            var text = stream.Read();

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }
}
@@ -0,0 +1,138 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
public class BatchedExecutorFork
{
    private const int n_split = 16;
    private const int n_len = 64;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Create an executor that can evaluate a batch of conversations together
        var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        var start = executor.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree
        var root = new Node(start);

        // Run inference loop
        for (var i = 0; i < n_len; i++)
        {
            if (i != 0)
                await executor.Infer();

            // Occasionally fork all the active conversations
            if (i != 0 && i % n_split == 0)
                root.Split();

            // Sample all active conversations
            root.Sample();
        }

        Console.WriteLine($"{prompt}...");
        root.Print(1);

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }

    class Node
    {
        private readonly StreamingTokenDecoder _decoder;
        private readonly DefaultSamplingPipeline _sampler;

        private Conversation? _conversation;

        private Node? _left;
        private Node? _right;

        public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        public void Sample()
        {
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }

            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var logitsCopy = _conversation.Sample().ToArray();
            var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        public void Split()
        {
            if (_conversation != null)
            {
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());

                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        public void Print(int indentation)
        {
            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
            Console.ForegroundColor = colors[indentation % colors.Length];

            var message = _decoder.Read().ReplaceLineEndings("");

            var prefix = new string(' ', indentation * 3);
            var suffix = _conversation == null ? "..." : "";
            Console.WriteLine($"{prefix}...{message}{suffix}");

            _left?.Print(indentation + 2);
            _right?.Print(indentation + 2);
        }
    }
}
@@ -0,0 +1,121 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating tokens and then rewinding to an earlier state
/// </summary>
public class BatchedExecutorRewind
{
    private const int n_generate = 24;
    private const int n_rewind = 12;
    private const int n_repeats = 6;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Create an executor that can evaluate a batch of conversations together
        var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        var conversation = executor.Prompt(prompt);

        // Create the start node wrapping the conversation
        var node = new Node(executor.Context);

        // Print the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        Console.WriteLine(prompt);

        for (var i = 0; i < n_repeats; i++)
        {
            for (var j = 0; j < n_generate; j++)
            {
                // Run inference
                await executor.Infer();

                // Sample a token
                var token = node.Sample(conversation);

                // Continue conversation with this token
                if (j != n_generate - 1)
                    conversation.Prompt(token);
            }

            // Write out what we generated
            node.Write(n_rewind, i + 1);

            // Rewind back a few tokens
            conversation.Rewind(n_rewind + 1);

            // Prompt with a token
            conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));

            // Create a new node around the rewound conversation
            node = new Node(executor.Context);
        }

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }

    private class Node
    {
        private readonly LLamaContext _context;
        private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
        private readonly DefaultSamplingPipeline Sampler;

        public Node(LLamaContext context)
        {
            _context = context;
            Sampler = new DefaultSamplingPipeline();
        }

        public LLamaToken Sample(Conversation conversation)
        {
            var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
            _tokens.Add(token);
            return token;
        }

        public void Write(int n_rewind, int depth)
        {
            var decoder = new StreamingTokenDecoder(_context);

            for (var i = 0; i < _tokens.Count - n_rewind; i++)
                decoder.Add(_tokens[i]);

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" "));

            for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
                decoder.Add(_tokens[i]);

            Console.ForegroundColor = ConsoleColor.DarkRed;
            Console.WriteLine(decoder.Read().ReplaceLineEndings(" "));
        }

        public LLamaToken GetToken(int index)
        {
            return _tokens[index];
        }
    }
}
@@ -23,7 +23,8 @@ public class Runner
         { "Semantic Kernel Chat.", SemanticKernelChat.Run },
         { "Semantic Kernel Memory.", SemanticKernelMemory.Run },
         { "Coding Assistant.", CodingAssistant.Run },
-        { "Batch Decoding.", BatchedDecoding.Run },
+        { "Batched Executor (Fork)", BatchedExecutorFork.Run },
+        { "Batched Executor (Rewind)", BatchedExecutorRewind.Run },
         { "SK Kernel Memory.", KernelMemory.Run },
         { "Exit", async () => Environment.Exit(0) }
     };
@@ -0,0 +1,119 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A batched executor that can infer multiple separate "conversations" simultaneously.
/// </summary>
public sealed class BatchedExecutor
    : IDisposable
{
    private int _nextSequenceId;

    internal LLamaBatch Batch { get; }

    /// <summary>
    /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
    /// whether they're waiting for inference, or can be sampled.
    /// </summary>
    internal ulong Epoch { get; private set; }

    /// <summary>
    /// The <see cref="LLamaContext"/> this executor is using
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// The <see cref="LLamaWeights"/> this executor is using
    /// </summary>
    public LLamaWeights Model { get; }

    /// <summary>
    /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
    public int BatchedTokenCount => Batch.TokenCount;

    /// <summary>
    /// Check if this executor has been disposed.
    /// </summary>
    public bool IsDisposed { get; private set; }

    /// <summary>
    /// Create a new batched executor
    /// </summary>
    /// <param name="model">The model to use</param>
    /// <param name="contextParams">Parameters to create a new context</param>
    public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
    {
        Model = model;
        Batch = new LLamaBatch();
        Context = model.CreateContext(contextParams);
        Epoch = 1;
    }

    ~BatchedExecutor()
    {
        Dispose();
    }

    /// <summary>
    /// Start a new <see cref="Conversation"/> with the given prompt
    /// </summary>
    /// <param name="prompt"></param>
    /// <returns></returns>
    public Conversation Prompt(string prompt)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = new Conversation(this, GetNextSequenceId(), 0);
        conversation.Prompt(prompt);

        return conversation;
    }

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
    /// If the result is `NoKvSlot` then there is not enough memory for inference, try disposing some conversation
    /// threads and running inference again.
    /// </summary>
    public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var status = await Context.DecodeAsync(Batch, cancellation);

        // Only clear the batch if the result was OK. Leaving all this state in place means that "Infer" can
        // be called again after a warning (e.g. NoKvSlot).
        if (status == DecodeResult.Ok)
        {
            Epoch++;
            Batch.Clear();
        }

        return status;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        GC.SuppressFinalize(this);

        Context.Dispose();
    }

    internal LLamaSeqId GetNextSequenceId()
    {
        return checked((LLamaSeqId)_nextSequenceId++);
    }
}
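The `Infer` documentation above suggests recovering from a `NoKvSlot` result by disposing conversations and retrying. A hedged sketch of that pattern (the `conversations` list and its ordering are assumptions, not part of this changeset):

// Illustrative sketch only (not part of the diff).
var status = await executor.Infer();
while (status == DecodeResult.NoKvSlot && conversations.Count > 0)
{
    // Assumption: the last conversation in the list is the least important one
    var victim = conversations[^1];
    victim.Dispose();                                // frees its slots in the KV cache
    conversations.RemoveAt(conversations.Count - 1);

    // The batch is only cleared on DecodeResult.Ok, so it is safe to retry
    status = await executor.Infer();
}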
@@ -0,0 +1,294 @@
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
/// </summary>
public sealed class Conversation
    : IDisposable
{
    private ulong _requiredEpoch;
    private LLamaPos _end;
    private int _batchIndex;
    private bool _disposed;

    /// <summary>
    /// The executor which this conversation belongs to
    /// </summary>
    public BatchedExecutor Executor { get; }

    /// <summary>
    /// Unique ID for this conversation
    /// </summary>
    public LLamaSeqId ConversationId { get; }

    /// <summary>
    /// Total number of tokens in this conversation, cannot exceed the context length.
    /// </summary>
    public int TokenCount => _end.Value;

    /// <summary>
    /// Indicates if this conversation has been disposed, nothing can be done with a disposed conversation
    /// </summary>
    public bool IsDisposed => _disposed || Executor.IsDisposed;

    /// <summary>
    /// Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true.
    /// </summary>
    public bool RequiresInference => _requiredEpoch > Executor.Epoch;

    /// <summary>
    /// Indicates that this conversation should be sampled.
    /// </summary>
    public bool RequiresSampling => _requiredEpoch == Executor.Epoch;

    #region construction/destruction
    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
    {
        ConversationId = id;
        Executor = batch;

        _end = end;
    }

    ~Conversation()
    {
        Dispose();
    }

    /// <summary>
    /// End this conversation, freeing all resources used by it
    /// </summary>
    /// <exception cref="ObjectDisposedException"></exception>
    public void Dispose()
    {
        if (IsDisposed)
            return;
        _disposed = true;

        // Remove this conversation from the KV cache
        Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);

        // Prevent finalizer from running
        GC.SuppressFinalize(this);
    }

    private void AssertNotDisposed()
    {
        if (Executor.IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(Conversation));
    }
    #endregion

    /// <summary>
    /// Create a copy of the current conversation
    /// </summary>
    /// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Fork()
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new CannotForkWhileRequiresInference();

        // Create a new conversation which references the current position in this one
        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
        {
            _batchIndex = _batchIndex,
            _requiredEpoch = _requiredEpoch,
        };

        // Assign tokens to the new sequence
        NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);

        return c;
    }

    #region sample
    /// <summary>
    /// Get the logits from this conversation, ready for sampling
    /// </summary>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
    /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
    public ReadOnlySpan<float> Sample()
    {
        AssertNotDisposed();

        if (_requiredEpoch < Executor.Epoch)
            throw new CannotSampleRequiresPromptException();
        if (_requiredEpoch > Executor.Epoch)
            throw new CannotSampleRequiresInferenceException();

        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
    }
    #endregion

    #region prompt
    private void AssertCanBePrompted()
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new AlreadyPromptedConversationException();
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="input"></param>
    /// <returns></returns>
    public void Prompt(string input)
    {
        AssertCanBePrompted();

        Prompt(Executor.Context.Tokenize(input));
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="tokens"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public void Prompt(IReadOnlyList<LLamaToken> tokens)
    {
        AssertCanBePrompted();

        // Add the prompt to the batch
        for (var i = 0; i < tokens.Count; i++)
            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
    }

    /// <summary>
    /// Add a single token to this conversation
    /// </summary>
    /// <param name="token"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="InvalidOperationException"></exception>
    public void Prompt(LLamaToken token)
    {
        AssertCanBePrompted();

        // Add this token as input
        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
    }
    #endregion

    #region modify
    /// <summary>
    /// Directly modify the KV cache of this conversation
    /// </summary>
    /// <param name="modifier"></param>
    /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
    public void Modify(ModifyKvCache modifier)
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new CannotModifyWhileRequiresInference();

        // do whatever the modification is
        _end = modifier.Invoke(_end, new KvAccessor(this));

        // Set the epoch down to zero, this ensures that this conversation
        // cannot be sampled until it is prompted again.
        _requiredEpoch = 0;
    }

    /// <summary>
    /// Provides direct access to the KV cache of a <see cref="Conversation"/>.
    /// See <see cref="Modify"/> for how to use this.
    /// </summary>
    public readonly ref struct KvAccessor
    {
        private readonly Conversation _conversation;

        internal KvAccessor(Conversation conversation)
        {
            _conversation = conversation;
        }

        #region remove
        /// <summary>
        /// Removes all tokens that have positions in [start, end)
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        public void Remove(LLamaPos start, LLamaPos end)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }

        /// <summary>
        /// Removes all tokens starting from the given position
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="count">Number of tokens</param>
        public void Remove(LLamaPos start, int count)
        {
            if (count <= 0)
                return;

            var end = start.Value + count;
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }
        #endregion

        #region shift
        /// <summary>
        /// Adds relative position "delta" to all tokens that have positions in [p0, p1).
        /// If the KV cache is RoPEd, the KV data is updated accordingly
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        /// <param name="delta">Amount to add on to each token position</param>
        public void Shift(LLamaPos start, LLamaPos end, int delta)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
        }
        #endregion

        #region divide
        /// <summary>
        /// Integer division of the positions by factor of `d > 1`.
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        /// </summary>
        /// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
        /// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
        /// <param name="divisor">Amount to divide each position by.</param>
        public void Divide(LLamaPos start, LLamaPos end, int divisor)
        {
            if (divisor <= 0)
                throw new ArgumentOutOfRangeException(nameof(divisor));

            _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
        }
        #endregion
    }

    /// <summary>
    /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
    /// </summary>
    /// <param name="end">The current end token of this conversation</param>
    /// <param name="kv">An <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
    /// <returns>The new end token position</returns>
    public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
    #endregion
}
@@ -0,0 +1,59 @@
using System;

namespace LLama.Batched;

/// <summary>
/// Extension methods for <see cref="Conversation"/>
/// </summary>
public static class ConversationExtensions
{
    /// <summary>
    /// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end
    /// </summary>
    /// <param name="conversation">The conversation to rewind</param>
    /// <param name="tokens">The number of tokens to rewind</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than TokenCount</exception>
    public static void Rewind(this Conversation conversation, int tokens)
    {
        if (tokens > conversation.TokenCount)
            throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens");

        conversation.Modify((end, kv) =>
        {
            // Remove those tokens from KV
            kv.Remove(end.Value - tokens, tokens);

            // Return adjusted end position
            return end.Value - tokens;
        });
    }

    /// <summary>
    /// Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over.
    /// Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context
    /// gets full, keeping the prompt at the start intact.
    /// </summary>
    /// <param name="conversation">The conversation to shift</param>
    /// <param name="count">How much to shift tokens over by</param>
    /// <param name="keep">The number of tokens at the start which should <b>not</b> be shifted</param>
    public static void ShiftLeft(this Conversation conversation, int count, int keep)
    {
        // Given a setup like this (count=5, keep=3):
        //
        // AAABBBBBCCCCCCCCC...
        //
        // We want to remove all the B's, shift all the C's and leave all the A's untouched
        conversation.Modify((end, kv) =>
        {
            // Remove the B's
            kv.Remove(keep, count);

            // Shift the C's
            kv.Shift(keep + count, end, -count);

            // Update total count
            return end.Value - count;
        });
    }
}
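A hedged usage sketch for these two extensions, following the rule established above that a `Modify`-based operation resets the epoch, so the conversation must be prompted (and inferred) again before sampling; `conversation`, `someToken` and the counts are placeholders:

// Illustrative sketch only (not part of the diff).
conversation.Rewind(8);                      // forget the last 8 tokens
conversation.Prompt(someToken);              // must re-prompt (and Infer) before sampling again

// Or free space when the context is nearly full, keeping the first 32 prompt tokens intact:
conversation.ShiftLeft(count: 64, keep: 32);
conversation.Prompt(someToken);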
@@ -0,0 +1,81 @@
using System;

namespace LLama.Batched;

/// <summary>
/// Base class for exceptions thrown from <see cref="BatchedExecutor"/>
/// </summary>
public class ExperimentalBatchedExecutorException
    : Exception
{
    internal ExperimentalBatchedExecutorException(string message)
        : base(message)
    {
    }
}

/// <summary>
/// This exception is thrown when "Prompt()" is called on a <see cref="Conversation"/> which has
/// already been prompted and before "Infer()" has been called on the associated
/// <see cref="BatchedExecutor"/>.
/// </summary>
public class AlreadyPromptedConversationException
    : ExperimentalBatchedExecutorException
{
    internal AlreadyPromptedConversationException()
        : base("Must call `Infer()` before prompting this Conversation again")
    {
    }
}

/// <summary>
/// This exception is thrown when "Sample()" is called on a <see cref="Conversation"/> which has
/// already been prompted and before "Infer()" has been called on the associated
/// <see cref="BatchedExecutor"/>.
/// </summary>
public class CannotSampleRequiresInferenceException
    : ExperimentalBatchedExecutorException
{
    internal CannotSampleRequiresInferenceException()
        : base("Must call `Infer()` before sampling from this Conversation")
    {
    }
}

/// <summary>
/// This exception is thrown when "Sample()" is called on a <see cref="Conversation"/> which was not
/// first prompted and then inferred with the associated <see cref="BatchedExecutor"/>.
/// </summary>
public class CannotSampleRequiresPromptException
    : ExperimentalBatchedExecutorException
{
    internal CannotSampleRequiresPromptException()
        : base("Must call `Prompt()` and then `Infer()` before sampling from this Conversation")
    {
    }
}

/// <summary>
/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>
public class CannotForkWhileRequiresInference
    : ExperimentalBatchedExecutorException
{
    internal CannotForkWhileRequiresInference()
        : base("Cannot `Fork()` a conversation while RequiresInference is true")
    {
    }
}

/// <summary>
/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>
public class CannotModifyWhileRequiresInference
    : ExperimentalBatchedExecutorException
{
    internal CannotModifyWhileRequiresInference()
        : base("Cannot `Modify()` a conversation while RequiresInference is true")
    {
    }
}
@@ -221,7 +221,9 @@ namespace LLama
     /// <returns>The selected token</returns>
     public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        var token = pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        pipeline.Accept(NativeHandle, token);
+        return token;
     }

     /// <summary>
@@ -213,6 +213,7 @@ namespace LLama
     if (inferenceParams.SamplingPipeline is not null)
     {
         id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+        inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
     }
     else
     {
@@ -192,6 +192,7 @@ namespace LLama
     if (inferenceParams.SamplingPipeline is not null)
     {
         id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+        inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
     }
     else
     {
@@ -18,6 +18,11 @@ public class LLamaBatch
     private LLamaSeqId[][] _sequenceIds;
     private IntPtr[] _sequenceIdsPtrs;

+    /// <summary>
+    /// Keep track of the index of existing token/position combos in the batch
+    /// </summary>
+    private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new();
+
     /// <summary>
     /// The number of tokens in this batch
     /// </summary>
@@ -130,23 +135,44 @@ public class LLamaBatch
     /// <param name="pos">The position to add it at</param>
     /// <param name="sequences">The set of sequences to add this token to</param>
     /// <param name="logits"></param>
-    public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+    /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
+    public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
+        // Try to find this (token, position) combo somewhere in the batch to re-use it
+        if (_index.TryGetValue((token, pos), out var existingIndex))
+        {
+            if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
+                GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);
+
+            foreach (var sequence in sequences)
+            {
+                _sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence;
+                _sequenceIdCount[existingIndex]++;
+            }
+
+            return existingIndex;
+        }
+
+        // Couldn't find it in the batch, add a new item
+
+        // Grow capacity as necessary
         if (TokenCount == TokenCapacity)
             GrowTokenCapacity();
         if (sequences.Length > SequenceCapacity)
             GrowMaxSequences(sequences.Length);

+        // Store the position in the index, so it can be found later
+        _index.Add((token, pos), TokenCount);
+
+        // Add the items to the arrays
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;

         _sequenceIdCount[TokenCount] = sequences.Length;
         for (var i = 0; i < sequences.Length; i++)
             _sequenceIds[TokenCount][i] = sequences[i];

         _logits[TokenCount] = Convert.ToByte(logits);

-        TokenCount++;
+        return TokenCount++;
     }

     /// <summary>
@@ -157,11 +183,12 @@ public class LLamaBatch
     /// <param name="pos">The position to add it at</param>
     /// <param name="sequences">The set of sequences to add this token to</param>
     /// <param name="logits"></param>
-    public void Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
+    /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
+    public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
     {
 #if NET5_0_OR_GREATER
         var seqSpan = CollectionsMarshal.AsSpan(sequences);
-        Add(token, pos, seqSpan, logits);
+        return Add(token, pos, seqSpan, logits);
 #else
         // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
         // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
@@ -171,7 +198,7 @@ public class LLamaBatch
         try
         {
             sequences.CopyTo(rented, 0);
-            Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
+            return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
         }
         finally
         {
@@ -188,14 +215,15 @@ public class LLamaBatch
     /// <param name="pos">The position to add it at</param>
     /// <param name="sequence">The sequence to add this token to</param>
     /// <param name="logits"></param>
-    public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+    /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
+    public int Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
     {
         // Create a temporary span to contain 1 item without allocating
         Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
         sequences[0] = sequence;

         // Add it
-        Add(token, pos, sequences, logits);
+        return Add(token, pos, sequences, logits);
     }

     /// <summary>
@@ -205,13 +233,17 @@ public class LLamaBatch
     /// <param name="start">The starting position to add tokens at</param>
     /// <param name="sequence">The sequence to add this token to</param>
    /// <param name="logitsLast">Whether the final token should generate logits</param>
-    public void AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
+    /// <returns>The index that the final token was added at. Use this for GetLogitsIth</returns>
+    public int AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
     {
+        var last = -1;
         for (var i = 0; i < tokens.Length; i++)
         {
             var logits = (i == tokens.Length - 1) & logitsLast;
-            Add(tokens[i], start.Value + i, sequence, logits);
+            last = Add(tokens[i], start.Value + i, sequence, logits);
         }
+
+        return last;
     }
     #endregion

@@ -221,5 +253,6 @@ public class LLamaBatch
     public void Clear()
     {
         TokenCount = 0;
+        _index.Clear();
     }
 }
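The `Add`/`AddRange` overloads now return the batch index of the (final) token, which is the value to pass to `GetLogitsIth` after decoding. A small sketch of that flow, mirroring the removed `BatchedDecoding` example (`context` and `promptTokens` are placeholders):

// Illustrative sketch only (not part of the diff).
var batch = new LLamaBatch();
var lastIndex = batch.AddRange(promptTokens, 0, LLamaSeqId.Zero, true);

if (await context.DecodeAsync(batch) == DecodeResult.Ok)
{
    // Logits for the final prompt token, ready to hand to a sampling pipeline
    var logits = context.NativeHandle.GetLogitsIth(lastIndex);
}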
@@ -388,7 +388,7 @@ namespace LLama.Native
        /// <param name="p1"></param>
        /// <param name="delta"></param>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);
+       public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);

        /// <summary>
        /// Integer division of the positions by factor of `d > 1`
@@ -369,15 +369,15 @@ namespace LLama.Native
        /// <param name="p0"></param>
        /// <param name="p1"></param>
        /// <param name="delta"></param>
-       public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
+       public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
        {
            NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
        }

        /// <summary>
-       /// Integer division of the positions by factor of `d > 1`
-       /// If the KV cache is RoPEd, the KV data is updated accordingly
-       /// p0 < 0 : [0, p1]
+       /// Integer division of the positions by factor of `d > 1`.
+       /// If the KV cache is RoPEd, the KV data is updated accordingly.<br />
+       /// p0 < 0 : [0, p1]<br />
        /// p1 < 0 : [p0, inf)
        /// </summary>
        /// <param name="seq"></param>
@@ -40,10 +40,7 @@ public abstract class BaseSamplingPipeline
            var candidates = LLamaTokenDataArray.Create(logits);

            // Process token data array
-           ProcessTokenDataArray(ctx, candidates, lastTokens);
-
-           // Choose the final value
-           return ChooseToken(ctx, candidates);
+           return ProcessTokenDataArray(ctx, candidates, lastTokens);
        }
        finally
        {
@@ -53,6 +50,9 @@ public abstract class BaseSamplingPipeline
        }
    }

+   /// <inheritdoc />
+   public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token);
+
    #region protected tokens
    /// <summary>
    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
@@ -107,19 +107,14 @@ public abstract class BaseSamplingPipeline
    /// <returns></returns>
    protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);

-   /// <summary>
-   /// Choose the final token from the candidates
-   /// </summary>
-   /// <param name="ctx"></param>
-   /// <param name="candidates"></param>
-   /// <returns></returns>
-   protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
-
    /// <inheritdoc/>
    public virtual void Reset()
    {
    }

+   /// <inheritdoc />
+   public abstract ISamplingPipeline Clone();
+
    /// <inheritdoc/>
    public virtual void Dispose()
    {
@@ -141,9 +141,31 @@ public sealed class DefaultSamplingPipeline
        return id;
    }

+   public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
+   {
+       Grammar?.AcceptToken(ctx, token);
+   }
+
    /// <inheritdoc />
-   protected override LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+   public override ISamplingPipeline Clone()
    {
-       return candidates.SampleToken(ctx);
+       var clone = new DefaultSamplingPipeline();
+
+       foreach (var (k, v) in LogitBias)
+           clone.LogitBias.Add(k, v);
+
+       clone.Grammar = Grammar?.Clone();
+       clone.RepeatPenalty = RepeatPenalty;
+       clone.AlphaFrequency = AlphaFrequency;
+       clone.AlphaPresence = AlphaPresence;
+       clone.Temperature = Temperature;
+       clone.TopK = TopK;
+       clone.TailFreeZ = TailFreeZ;
+       clone.TypicalP = TypicalP;
+       clone.TopP = TopP;
+       clone.MinP = MinP;
+       clone.PenalizeNewline = PenalizeNewline;
+
+       return clone;
    }
 }
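`Clone` makes it cheap to give each forked conversation its own sampler state while sharing one configuration; a hedged sketch (the property values are arbitrary):

// Illustrative sketch only (not part of the diff).
var template = new DefaultSamplingPipeline { Temperature = 0.6f, MinP = 0.25f };
var samplerA = template.Clone();   // independent copy for one conversation
var samplerB = template.Clone();   // and another for its fork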
@@ -21,10 +21,23 @@ public interface ISamplingPipeline
    /// <returns></returns>
    LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

+   /// <summary>
+   /// Update the pipeline, with knowledge that a particular token was just accepted
+   /// </summary>
+   /// <param name="ctx"></param>
+   /// <param name="token"></param>
+   void Accept(SafeLLamaContextHandle ctx, LLamaToken token);
+
    /// <summary>
    /// Reset all internal state of the sampling pipeline
    /// </summary>
    void Reset();
+
+   /// <summary>
+   /// Create a copy of this sampling pipeline
+   /// </summary>
+   /// <returns></returns>
+   ISamplingPipeline Clone();
 }

 /// <summary>