
Mutable Logits (#586)

Modified LLamaBatch so that a token which requests logits is never shared with other sequences. This ensures that the logit span at the end is used by exactly one sequence, and is therefore safe to mutate. This removes the need for copying _very_ large (vocab-sized) arrays and simplifies sampling pipelines.
tags/0.11.0
Martin Evans (via GitHub) · 1 year ago
commit f0b0bbcbb7
9 changed files with 30 additions and 44 deletions:

  1. LLama.Examples/Examples/BatchedExecutorGuidance.cs (+5, -9)
  2. LLama/Batched/Conversation.cs (+5, -1)
  3. LLama/Native/LLamaBatch.cs (+10, -6)
  4. LLama/Sampling/BaseSamplingPipeline.cs (+3, -3)
  5. LLama/Sampling/DefaultSamplingPipeline.cs (+2, -17)
  6. LLama/Sampling/GreedySamplingPipeline.cs (+1, -2)
  7. LLama/Sampling/ISamplingPipeline.cs (+2, -2)
  8. LLama/Sampling/Mirostat2SamplingPipeline.cs (+1, -2)
  9. LLama/Sampling/MirostatSamplingPipeline.cs (+1, -2)
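
The core of the change is visible in the pipeline signatures below: ProcessLogits used to take a ReadOnlySpan<float> and return a (possibly copied) span, and now takes a writable Span<float> and returns nothing. A minimal, self-contained sketch of why that matters; the vocab size, token id, and bias value here are illustrative, not from the commit:

    using System;

    class MutableLogitsSketch
    {
        const int VocabSize = 32_000; // illustrative; the real size comes from the model

        // Old contract: the input was read-only, so any edit forced a vocab-sized copy.
        static ReadOnlySpan<float> BiasByCopy(ReadOnlySpan<float> logits, int token, float bias)
        {
            var copy = logits.ToArray(); // allocates VocabSize floats on every sample
            copy[token] += bias;
            return copy;
        }

        // New contract: the batch guarantees the span backs exactly one sequence,
        // so it can be edited directly with no allocation.
        static void BiasInPlace(Span<float> logits, int token, float bias)
        {
            logits[token] += bias;
        }

        static void Main()
        {
            var logits = new float[VocabSize];
            BiasInPlace(logits, token: 42, bias: 1.5f);
            Console.WriteLine(logits[42]); // 1.5

            var biased = BiasByCopy(logits, token: 42, bias: 1.5f);
            Console.WriteLine(biased[42]); // 3.0, but only after copying the whole array
        }
    }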

LLama.Examples/Examples/BatchedExecutorGuidance.cs (+5, -9)

@@ -85,8 +85,8 @@ public class BatchedExecutorGuidance
             }
         });
 
-        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]");
-        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
+        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
+        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
     }
 
     private class GuidedSampler(Conversation? guidance, float weight)
@@ -101,20 +101,16 @@ public class BatchedExecutorGuidance
             throw new NotSupportedException();
         }
 
-        protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
         {
             if (guidance == null)
-                return logits;
-
-            var logitsCopy = logits.ToArray();
+                return;
 
             // Get the logits generated by the guidance sequences
             var guidanceLogits = guidance.Sample();
 
             // Use those logits to guide this sequence
-            NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight);
-
-            return logitsCopy;
+            NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
         }
 
         protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
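
With a writable span, the guidance sampler no longer copies the logits before handing them to the native guidance routine. As a rough illustration only: classifier-free guidance mixes the two logit vectors, in the spirit of the sketch below. The real formula, including any log-softmax normalisation, is whatever llama_sample_apply_guidance implements in llama.cpp; this managed version is an assumption, not the library's code:

    // Hedged sketch of the guidance mix, applied in place. A weight above 1
    // pushes the result further from the guidance distribution.
    static void ApplyGuidanceSketch(Span<float> logits, ReadOnlySpan<float> guidanceLogits, float weight)
    {
        for (var i = 0; i < logits.Length; i++)
            logits[i] = guidanceLogits[i] + weight * (logits[i] - guidanceLogits[i]);
    }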


LLama/Batched/Conversation.cs (+5, -1)

@@ -122,7 +122,7 @@ public sealed class Conversation
     /// <exception cref="ObjectDisposedException"></exception>
     /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
     /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
-    public ReadOnlySpan<float> Sample()
+    public Span<float> Sample()
     {
         AssertNotDisposed();
 
@@ -166,6 +166,10 @@ public sealed class Conversation
     {
         AssertCanBePrompted();
 
+        // No point doing anything if there is no actual prompt!
+        if (tokens.Count == 0)
+            return;
+
         // Add the prompt to the batch
         for (var i = 0; i < tokens.Count; i++)
             _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
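
Since Sample() now hands back a writable Span<float> over the conversation's own logits, a caller can adjust them with no intermediate buffer. A hypothetical usage fragment, where 'conversation' is a Conversation that has been prompted and inferred, and 'bannedToken' is an assumed token id:

    // Hypothetical caller-side sketch: mutate the conversation's logits in place.
    Span<float> logits = conversation.Sample();
    logits[bannedToken] = float.NegativeInfinity; // e.g. make one token unselectable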


LLama/Native/LLamaBatch.cs (+10, -6)

@@ -138,8 +138,10 @@ public class LLamaBatch
     /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
     public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
-        // Try to find this (token, position) combo somewhere in the batch to re-use it
-        if (_index.TryGetValue((token, pos), out var existingIndex))
+        // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
+        // sequence ID to the list.
+        // Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
+        if (!logits && _index.TryGetValue((token, pos), out var existingIndex))
         {
             if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
                 GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);
@@ -153,16 +155,18 @@ public class LLamaBatch
             return existingIndex;
         }
 
-        // Couldn't find this it in the batch, add a new item
+        // Couldn't find this token/position combo anywhere in the batch. Add a new item.
 
-        // Frow capacity as necessary
+        // Grow capacity as necessary
         if (TokenCount == TokenCapacity)
             GrowTokenCapacity();
         if (sequences.Length > SequenceCapacity)
             GrowMaxSequences(sequences.Length);
 
-        // Store the position in the index, so it can be found later
-        _index.Add((token, pos), TokenCount);
+        // Store the position in the index, so it can be found later.
+        // We need to check that it's not already there in case we skipped the check above (because logits is true).
+        if (!_index.ContainsKey((token, pos)))
+            _index.Add((token, pos), TokenCount);
 
         // Add the items to the arrays
         _tokens[TokenCount] = token;
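
The effect of the new !logits guard in Add: identical (token, position) entries still collapse into one shared batch slot when no logits are requested, but a token that wants logits always gets a slot of its own. A hypothetical sketch, where 'batch', 'token', 'pos', 'seqA' and 'seqB' are assumed values:

    // Hypothetical: seqA and seqB are different LLamaSeqId values.
    // No logits requested: the second Add re-uses the first slot.
    var shared1 = batch.Add(token, pos, seqA, logits: false);
    var shared2 = batch.Add(token, pos, seqB, logits: false); // same index as shared1

    // Logits requested: each sequence gets a private slot, so the logit span
    // produced for that index belongs to exactly one sequence and is safe to mutate.
    var own1 = batch.Add(token, pos, seqA, logits: true);
    var own2 = batch.Add(token, pos, seqB, logits: true);     // a different index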


LLama/Sampling/BaseSamplingPipeline.cs (+3, -3)

@@ -15,10 +15,10 @@ public abstract class BaseSamplingPipeline
     public SafeLLamaGrammarHandle? Grammar { get; set; }
 
     /// <inheritdoc/>
-    public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
         // Apply processing to raw logit values
-        logits = ProcessLogits(ctx, logits, lastTokens);
+        ProcessLogits(ctx, logits, lastTokens);
 
         // Process token data array to select a final token
         var candidates = LLamaTokenDataArray.Create(logits);
@@ -38,7 +38,7 @@ public abstract class BaseSamplingPipeline
     /// <param name="ctx">The context being sampled from</param>
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A list of tokens recently returned by the model</param>
-    protected abstract ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
+    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
 
     /// <summary>
     /// Process the LLamaTokenDataArray and select a single token
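
Under the new contract a subclass simply edits the span and returns nothing. An illustrative subclass, not part of this commit: the temperature scaling is an example of an in-place transform, the greedy pick mirrors GreedySamplingPipeline, and unrelated abstract members are omitted:

    // Illustrative only: scale logits by an inverse temperature, in place.
    public class TemperatureSamplingPipeline : BaseSamplingPipeline
    {
        public float Temperature { get; set; } = 0.8f;

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            // Dividing every logit sharpens (T < 1) or flattens (T > 1) the distribution.
            for (var i = 0; i < logits.Length; i++)
                logits[i] /= Temperature;
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            // Greedy pick, as in GreedySamplingPipeline.
            return candidates.SampleTokenGreedy(ctx);
        }

        // (Other members such as Accept/Reset/Clone are omitted from this sketch.)
    }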


LLama/Sampling/DefaultSamplingPipeline.cs (+2, -17)

@@ -94,28 +94,13 @@ public sealed class DefaultSamplingPipeline
     /// </summary>
     public bool PenalizeNewline { get; set; } = false;
 
-    private float[]? _logits;
-
     /// <inheritdoc />
-    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        // Skip work if possible
-        if (LogitBias.Count == 0)
-            return logits;
-
-        // Create a temporary array to hold logits
-        if (_logits == null || _logits.Length < logits.Length)
-            _logits = new float[logits.Length];
-
-        // Copy logits
-        logits.CopyTo(_logits);
-        var mutable = _logits.AsSpan(0, logits.Length);
-
         // Apply logit bias
         foreach (var (key, value) in LogitBias)
-            mutable[key] += value;
-
-        return mutable;
+            logits[key] += value;
     }
 
     /// <inheritdoc />
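
After this simplification the bias loop writes straight into the model's logit buffer on every sample, with no cached temporary array. The old early-out for an empty LogitBias is gone too, but iterating an empty dictionary is effectively free. Hypothetical usage, where 'tokenId' is an assumed token id:

    // Hypothetical: boost one token before every sample; the bias is now
    // added in place to logits[tokenId] with no vocab-sized copy.
    var pipeline = new DefaultSamplingPipeline();
    pipeline.LogitBias[tokenId] = 5f;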


LLama/Sampling/GreedySamplingPipeline.cs (+1, -2)

@@ -10,9 +10,8 @@ public class GreedySamplingPipeline
     : BaseSamplingPipeline
 {
     /// <inheritdoc />
-    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        return logits;
     }
 
     /// <inheritdoc />


LLama/Sampling/ISamplingPipeline.cs (+2, -2)

@@ -19,7 +19,7 @@ public interface ISamplingPipeline
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A span of tokens recently returned by the model</param>
     /// <returns></returns>
-    LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
+    LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
 
     /// <summary>
     /// Update the pipeline, with knowledge that a particular token was just accepted
@@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A list of tokens recently returned by the model</param>
     /// <returns></returns>
-    public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, List<LLamaToken> lastTokens)
+    public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens)
     {
 #if NET5_0_OR_GREATER
         var span = CollectionsMarshal.AsSpan(lastTokens);


LLama/Sampling/Mirostat2SamplingPipeline.cs (+1, -2)

@@ -37,9 +37,8 @@ public class Mirostate2SamplingPipeline
     public float Eta { get; set; } = 0.1f;
 
     /// <inheritdoc />
-    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        return logits;
     }
 
     /// <inheritdoc />


LLama/Sampling/MirostatSamplingPipeline.cs (+1, -2)

@@ -38,9 +38,8 @@ public class MirostateSamplingPipeline
     public float Eta { get; set; } = 0.1f;
 
     /// <inheritdoc />
-    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        return logits;
     }
 
     /// <inheritdoc />

