From f0b0bbcbb718c5732f15042271235aab33de8c79 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sun, 10 Mar 2024 13:56:11 +0000 Subject: [PATCH] Mutable Logits (#586) Modified LLamaBatch to not share tokens with other sequences if logits is true. This ensures that the logit span at the end is used by exactly one sequence - therefore it's safe to mutate. This removes the need for copying _very_ large arrays (vocab size) and simplifies sampling pipelines. --- .../Examples/BatchedExecutorGuidance.cs | 14 +++++--------- LLama/Batched/Conversation.cs | 6 +++++- LLama/Native/LLamaBatch.cs | 16 ++++++++++------ LLama/Sampling/BaseSamplingPipeline.cs | 6 +++--- LLama/Sampling/DefaultSamplingPipeline.cs | 19 ++----------------- LLama/Sampling/GreedySamplingPipeline.cs | 3 +-- LLama/Sampling/ISamplingPipeline.cs | 4 ++-- LLama/Sampling/Mirostat2SamplingPipeline.cs | 3 +-- LLama/Sampling/MirostatSamplingPipeline.cs | 3 +-- 9 files changed, 30 insertions(+), 44 deletions(-) diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs index 1c2c4b49..de78085a 100644 --- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs +++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs @@ -85,8 +85,8 @@ public class BatchedExecutorGuidance } }); - AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]"); - AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]"); + AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]"); + AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]"); } private class GuidedSampler(Conversation?
guidance, float weight) @@ -101,20 +101,16 @@ public class BatchedExecutorGuidance throw new NotSupportedException(); } - protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { if (guidance == null) - return logits; - - var logitsCopy = logits.ToArray(); + return; // Get the logits generated by the guidance sequences var guidanceLogits = guidance.Sample(); // Use those logits to guide this sequence - NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight); - - return logitsCopy; + NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight); } protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index 985f2c5a..29759cf9 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -122,7 +122,7 @@ public sealed class Conversation /// /// Thrown if this conversation was not prompted before the previous call to infer /// Thrown if Infer() must be called on the executor - public ReadOnlySpan Sample() + public Span Sample() { AssertNotDisposed(); @@ -166,6 +166,10 @@ public sealed class Conversation { AssertCanBePrompted(); + // No point doing anything if there is no actual prompt! + if (tokens.Count == 0) + return; + // Add the prompt to the batch for (var i = 0; i < tokens.Count; i++) _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1); diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs index 50b9c8f0..81d37c49 100644 --- a/LLama/Native/LLamaBatch.cs +++ b/LLama/Native/LLamaBatch.cs @@ -138,8 +138,10 @@ public class LLamaBatch /// The index that the token was added at. 
Use this for GetLogitsIth public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits) { - // Try to find this (token, position) combo somewhere in the batch to re-use it - if (_index.TryGetValue((token, pos), out var existingIndex)) + // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this + // sequence ID to the list. + // Do **not** do this if this token wants logits, to prevent logits being shared between sequences. + if (!logits && _index.TryGetValue((token, pos), out var existingIndex)) { if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity) GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length); @@ -153,16 +155,18 @@ public class LLamaBatch return existingIndex; } - // Couldn't find this it in the batch, add a new item + // Couldn't find this token/position combo anywhere in the batch. Add a new item. - // Frow capacity as necessary + // Grow capacity as necessary if (TokenCount == TokenCapacity) GrowTokenCapacity(); if (sequences.Length > SequenceCapacity) GrowMaxSequences(sequences.Length); - // Store the position in the index, so it can be found later - _index.Add((token, pos), TokenCount); + // Store the position in the index, so it can be found later. + // We need to check that it's not already there in case we skipped the check above (because logits is true). + if (!_index.ContainsKey((token, pos))) + _index.Add((token, pos), TokenCount); // Add the items to the arrays _tokens[TokenCount] = token; diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index aafb5932..4dcd8127 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -15,10 +15,10 @@ public abstract class BaseSamplingPipeline public SafeLLamaGrammarHandle? 
Grammar { get; set; } /// - public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { // Apply processing to raw logit values - logits = ProcessLogits(ctx, logits, lastTokens); + ProcessLogits(ctx, logits, lastTokens); // Process token data array to select a final token var candidates = LLamaTokenDataArray.Create(logits); @@ -38,7 +38,7 @@ public abstract class BaseSamplingPipeline /// The context being sampled from /// The logits produced by the model /// A list of tokens recently returned by the model - protected abstract ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens); + protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); /// /// Process the LLamaTokenDataArray and select a single token diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 071b5c19..5a9ef16c 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -94,28 +94,13 @@ public sealed class DefaultSamplingPipeline /// public bool PenalizeNewline { get; set; } = false; - private float[]? 
_logits; - /// - protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { - // Skip work if possible - if (LogitBias.Count == 0) - return logits; - - // Create a temporary array to hold logits - if (_logits == null || _logits.Length < logits.Length) - _logits = new float[logits.Length]; - - // Copy logits - logits.CopyTo(_logits); - var mutable = _logits.AsSpan(0, logits.Length); - // Apply logit bias foreach (var (key, value) in LogitBias) - mutable[key] += value; + logits[key] += value; - return mutable; } /// diff --git a/LLama/Sampling/GreedySamplingPipeline.cs b/LLama/Sampling/GreedySamplingPipeline.cs index 81b2d3cd..df4ef34c 100644 --- a/LLama/Sampling/GreedySamplingPipeline.cs +++ b/LLama/Sampling/GreedySamplingPipeline.cs @@ -10,9 +10,8 @@ public class GreedySamplingPipeline : BaseSamplingPipeline { /// - protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { - return logits; } /// diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index 53c8c7c6..b538d1fe 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -19,7 +19,7 @@ public interface ISamplingPipeline /// The logits produced by the model /// A span of tokens recently returned by the model /// - LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens); + LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); /// /// Update the pipeline, with knowledge that a particular token was just accepted @@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions /// The logits produced by the model /// A list of tokens recently 
returned by the model /// - public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan logits, List lastTokens) + public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span logits, List lastTokens) { #if NET5_0_OR_GREATER var span = CollectionsMarshal.AsSpan(lastTokens); diff --git a/LLama/Sampling/Mirostat2SamplingPipeline.cs b/LLama/Sampling/Mirostat2SamplingPipeline.cs index dcdc4197..bdf1a461 100644 --- a/LLama/Sampling/Mirostat2SamplingPipeline.cs +++ b/LLama/Sampling/Mirostat2SamplingPipeline.cs @@ -37,9 +37,8 @@ public class Mirostate2SamplingPipeline public float Eta { get; set; } = 0.1f; /// - protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { - return logits; } /// diff --git a/LLama/Sampling/MirostatSamplingPipeline.cs b/LLama/Sampling/MirostatSamplingPipeline.cs index 65d36007..1bb11138 100644 --- a/LLama/Sampling/MirostatSamplingPipeline.cs +++ b/LLama/Sampling/MirostatSamplingPipeline.cs @@ -38,9 +38,8 @@ public class MirostateSamplingPipeline public float Eta { get; set; } = 0.1f; /// - protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { - return logits; } ///