Modified LLamaBatch to not share tokens with other sequences if logits is true. This ensures that the logit span at the end in used by exactly one sequence - therefore it's safe to mutate. This removes the need for copying _very_ large arrays (vocab size) and simplifies sampling pipelines.tags/0.11.0
| @@ -85,8 +85,8 @@ public class BatchedExecutorGuidance | |||||
| } | } | ||||
| }); | }); | ||||
| AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]"); | |||||
| AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]"); | |||||
| AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]"); | |||||
| AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]"); | |||||
| } | } | ||||
| private class GuidedSampler(Conversation? guidance, float weight) | private class GuidedSampler(Conversation? guidance, float weight) | ||||
| @@ -101,20 +101,16 @@ public class BatchedExecutorGuidance | |||||
| throw new NotSupportedException(); | throw new NotSupportedException(); | ||||
| } | } | ||||
| protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| { | { | ||||
| if (guidance == null) | if (guidance == null) | ||||
| return logits; | |||||
| var logitsCopy = logits.ToArray(); | |||||
| return; | |||||
| // Get the logits generated by the guidance sequences | // Get the logits generated by the guidance sequences | ||||
| var guidanceLogits = guidance.Sample(); | var guidanceLogits = guidance.Sample(); | ||||
| // Use those logits to guide this sequence | // Use those logits to guide this sequence | ||||
| NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight); | |||||
| return logitsCopy; | |||||
| NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight); | |||||
| } | } | ||||
| protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens) | protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens) | ||||
| @@ -122,7 +122,7 @@ public sealed class Conversation | |||||
| /// <exception cref="ObjectDisposedException"></exception> | /// <exception cref="ObjectDisposedException"></exception> | ||||
| /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception> | /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception> | ||||
| /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception> | /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception> | ||||
| public ReadOnlySpan<float> Sample() | |||||
| public Span<float> Sample() | |||||
| { | { | ||||
| AssertNotDisposed(); | AssertNotDisposed(); | ||||
| @@ -166,6 +166,10 @@ public sealed class Conversation | |||||
| { | { | ||||
| AssertCanBePrompted(); | AssertCanBePrompted(); | ||||
| // No point doing anything if there is no actual prompt! | |||||
| if (tokens.Count == 0) | |||||
| return; | |||||
| // Add the prompt to the batch | // Add the prompt to the batch | ||||
| for (var i = 0; i < tokens.Count; i++) | for (var i = 0; i < tokens.Count; i++) | ||||
| _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1); | _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1); | ||||
| @@ -138,8 +138,10 @@ public class LLamaBatch | |||||
| /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns> | /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns> | ||||
| public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits) | public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits) | ||||
| { | { | ||||
| // Try to find this (token, position) combo somewhere in the batch to re-use it | |||||
| if (_index.TryGetValue((token, pos), out var existingIndex)) | |||||
| // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this | |||||
| // sequence ID to the list. | |||||
| // Do **not** do this if this token wants logits, to prevent logits being shared between sequences. | |||||
| if (!logits && _index.TryGetValue((token, pos), out var existingIndex)) | |||||
| { | { | ||||
| if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity) | if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity) | ||||
| GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length); | GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length); | ||||
| @@ -153,16 +155,18 @@ public class LLamaBatch | |||||
| return existingIndex; | return existingIndex; | ||||
| } | } | ||||
| // Couldn't find this it in the batch, add a new item | |||||
| // Couldn't find this token/position combo anywhere in the batch. Add a new item. | |||||
| // Frow capacity as necessary | |||||
| // Grow capacity as necessary | |||||
| if (TokenCount == TokenCapacity) | if (TokenCount == TokenCapacity) | ||||
| GrowTokenCapacity(); | GrowTokenCapacity(); | ||||
| if (sequences.Length > SequenceCapacity) | if (sequences.Length > SequenceCapacity) | ||||
| GrowMaxSequences(sequences.Length); | GrowMaxSequences(sequences.Length); | ||||
| // Store the position in the index, so it can be found later | |||||
| _index.Add((token, pos), TokenCount); | |||||
| // Store the position in the index, so it can be found later. | |||||
| // We need to check that it's not already there in case we skipped the check above (because logits is true). | |||||
| if (!_index.ContainsKey((token, pos))) | |||||
| _index.Add((token, pos), TokenCount); | |||||
| // Add the items to the arrays | // Add the items to the arrays | ||||
| _tokens[TokenCount] = token; | _tokens[TokenCount] = token; | ||||
| @@ -15,10 +15,10 @@ public abstract class BaseSamplingPipeline | |||||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | public SafeLLamaGrammarHandle? Grammar { get; set; } | ||||
| /// <inheritdoc/> | /// <inheritdoc/> | ||||
| public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| { | { | ||||
| // Apply processing to raw logit values | // Apply processing to raw logit values | ||||
| logits = ProcessLogits(ctx, logits, lastTokens); | |||||
| ProcessLogits(ctx, logits, lastTokens); | |||||
| // Process token data array to select a final token | // Process token data array to select a final token | ||||
| var candidates = LLamaTokenDataArray.Create(logits); | var candidates = LLamaTokenDataArray.Create(logits); | ||||
| @@ -38,7 +38,7 @@ public abstract class BaseSamplingPipeline | |||||
| /// <param name="ctx">The context being sampled from</param> | /// <param name="ctx">The context being sampled from</param> | ||||
| /// <param name="logits">The logits produced by the model</param> | /// <param name="logits">The logits produced by the model</param> | ||||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | /// <param name="lastTokens">A list of tokens recently returned by the model</param> | ||||
| protected abstract ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||||
| protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||||
| /// <summary> | /// <summary> | ||||
| /// Process the LLamaTokenDataArray and select a single token | /// Process the LLamaTokenDataArray and select a single token | ||||
| @@ -94,28 +94,13 @@ public sealed class DefaultSamplingPipeline | |||||
| /// </summary> | /// </summary> | ||||
| public bool PenalizeNewline { get; set; } = false; | public bool PenalizeNewline { get; set; } = false; | ||||
| private float[]? _logits; | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| { | { | ||||
| // Skip work if possible | |||||
| if (LogitBias.Count == 0) | |||||
| return logits; | |||||
| // Create a temporary array to hold logits | |||||
| if (_logits == null || _logits.Length < logits.Length) | |||||
| _logits = new float[logits.Length]; | |||||
| // Copy logits | |||||
| logits.CopyTo(_logits); | |||||
| var mutable = _logits.AsSpan(0, logits.Length); | |||||
| // Apply logit bias | // Apply logit bias | ||||
| foreach (var (key, value) in LogitBias) | foreach (var (key, value) in LogitBias) | ||||
| mutable[key] += value; | |||||
| logits[key] += value; | |||||
| return mutable; | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -10,9 +10,8 @@ public class GreedySamplingPipeline | |||||
| : BaseSamplingPipeline | : BaseSamplingPipeline | ||||
| { | { | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| { | { | ||||
| return logits; | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -19,7 +19,7 @@ public interface ISamplingPipeline | |||||
| /// <param name="logits">The logits produced by the model</param> | /// <param name="logits">The logits produced by the model</param> | ||||
| /// <param name="lastTokens">A span of tokens recently returned by the model</param> | /// <param name="lastTokens">A span of tokens recently returned by the model</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||||
| LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens); | |||||
| /// <summary> | /// <summary> | ||||
| /// Update the pipeline, with knowledge that a particular token was just accepted | /// Update the pipeline, with knowledge that a particular token was just accepted | ||||
| @@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions | |||||
| /// <param name="logits">The logits produced by the model</param> | /// <param name="logits">The logits produced by the model</param> | ||||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | /// <param name="lastTokens">A list of tokens recently returned by the model</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, List<LLamaToken> lastTokens) | |||||
| public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens) | |||||
| { | { | ||||
| #if NET5_0_OR_GREATER | #if NET5_0_OR_GREATER | ||||
| var span = CollectionsMarshal.AsSpan(lastTokens); | var span = CollectionsMarshal.AsSpan(lastTokens); | ||||
| @@ -37,9 +37,8 @@ public class Mirostate2SamplingPipeline | |||||
| public float Eta { get; set; } = 0.1f; | public float Eta { get; set; } = 0.1f; | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| { | { | ||||
| return logits; | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -38,9 +38,8 @@ public class MirostateSamplingPipeline | |||||
| public float Eta { get; set; } = 0.1f; | public float Eta { get; set; } = 0.1f; | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) | |||||
| { | { | ||||
| return logits; | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||