
Merge pull request #348 from martindevans/new_object_based_sampling_pipeline

Custom Sampling Pipelines
tags/0.9.1
Martin Evans · GitHub · 1 year ago
commit d87d654a34
13 changed files with 462 additions and 39 deletions
  1. +1 -2    LLama.Unittest/GrammarParserTest.cs
  2. +5 -1    LLama.Unittest/StatelessExecutorTest.cs
  3. +8 -2    LLama.Web/Common/InferenceOptions.cs
  4. +6 -0    LLama/Abstractions/IInferenceParams.cs
  5. +4 -0    LLama/Common/InferenceParams.cs
  6. +12 -0   LLama/LLamaContext.cs
  7. +17 -9   LLama/LLamaInstructExecutor.cs
  8. +18 -10  LLama/LLamaInteractExecutor.cs
  9. +19 -10  LLama/LLamaStatelessExecutor.cs
  10. +34 -5  LLama/Native/LLamaTokenDataArray.cs
  11. +128 -0 LLama/Sampling/BaseSamplingPipeline.cs
  12. +149 -0 LLama/Sampling/DefaultSamplingPipeline.cs
  13. +61 -0  LLama/Sampling/ISamplingPipeline.cs

+1 -2  LLama.Unittest/GrammarParserTest.cs

@@ -1,5 +1,4 @@
using System.Text;
using LLama.Exceptions;
using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;



+5 -1  LLama.Unittest/StatelessExecutorTest.cs

@@ -1,5 +1,6 @@
using System.Diagnostics;
using LLama.Common;
using LLama.Sampling;
using Xunit.Abstractions;

namespace LLama.Unittest
@@ -30,10 +31,13 @@ namespace LLama.Unittest
[Fact]
public async Task Stateless()
{
// Create a custom pipeline that mimics the default pipeline
var pipeline = new DefaultSamplingPipeline();

var executor = new StatelessExecutor(_weights, _params);

const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

var timer = new Stopwatch();
timer.Start();
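
As the updated test above shows, a sampling pipeline is attached through the new SamplingPipeline property on InferenceParams. A minimal end-to-end sketch of the same pattern outside the test harness (the model path is hypothetical; any local GGUF file works):

using System;
using LLama;
using LLama.Common;
using LLama.Sampling;

// Hypothetical model path - substitute a real GGUF file.
var parameters = new ModelParams("models/llama-2-7b.Q4_K_M.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);

var executor = new StatelessExecutor(weights, parameters);

var inferenceParams = new InferenceParams
{
    MaxTokens = 32,
    AntiPrompts = new[] { "." },
    // When this is set, the classic sampling properties on InferenceParams are ignored.
    SamplingPipeline = new DefaultSamplingPipeline(),
};

await foreach (var text in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
    Console.Write(text);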


+8 -2  LLama.Web/Common/InferenceOptions.cs

@@ -1,6 +1,9 @@
using LLama.Common;
#nullable enable

using LLama.Common;
using LLama.Abstractions;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Web.Common
{
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
/// <summary>
/// A grammar to constrain possible tokens
/// </summary>
public SafeLLamaGrammarHandle Grammar { get; set; } = null;
public SafeLLamaGrammarHandle? Grammar { get; set; }

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
}
}

+6 -0  LLama/Abstractions/IInferenceParams.cs

@@ -1,6 +1,7 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Abstractions
{
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
/// Grammar to constrain possible tokens
/// </summary>
SafeLLamaGrammarHandle? Grammar { get; set; }

/// <summary>
/// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }
}
}

+4 -0  LLama/Common/InferenceParams.cs

@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Common
{
@@ -76,6 +77,9 @@ namespace LLama.Common

/// <inheritdoc />
public SafeLLamaGrammarHandle? Grammar { get; set; }

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
}

/// <summary>


+12 -0  LLama/LLamaContext.cs

@@ -10,6 +10,7 @@ using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using LLama.Abstractions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;

namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
}
}

/// <summary>
/// Sample a single token from this context, using the given sampling pipeline
/// </summary>
/// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
/// <param name="lastTokens">The tokens recently returned from the model</param>
/// <returns>The selected token</returns>
public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
{
return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
}

/// <summary>
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
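
The new overload simply hands the context's freshest logits to the pipeline. A rough sketch of calling it directly, assuming a LLamaContext (`context`) that has already evaluated a prompt and an int[] (`lastTokens`) of recently generated token ids:

// Both `context` and `lastTokens` are assumed to exist already.
using var pipeline = new DefaultSamplingPipeline();
var id = context.Sample(pipeline, lastTokens);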


+17 -9  LLama/LLamaInstructExecutor.cs

@@ -210,16 +210,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}

var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
var mu = MirostatMu;
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
}

_last_n_tokens.Enqueue(id);



+18 -10  LLama/LLamaInteractExecutor.cs

@@ -189,16 +189,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}

var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
}

_last_n_tokens.Enqueue(id);



+19 -10  LLama/LLamaStatelessExecutor.cs

@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;

namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
}

// Decode this token into text
decoder.Add(id);


+34 -5  LLama/Native/LLamaTokenDataArray.cs

@@ -46,14 +46,41 @@ namespace LLama.Native
return new LLamaTokenDataArray(candidates);
}

/// <summary>
/// Overwrite the logit values for all given tokens
/// </summary>
/// <param name="values">tuples of token and logit value to overwrite</param>
public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
{
if (values.Length == 0)
return;

var dataSpan = data.Span;
foreach (var (token, value) in values)
{
for (var i = 0; i < data.Length; i++)
{
if (dataSpan[i].id == token)
{
dataSpan[i].logit = value;
break;
}
}
}
sorted = false;
}

#region sampling
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx"></param>
/// <param name="grammar"></param>
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
{
if (grammar == null)
return;

using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@ namespace LLama.Native
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
using (var last_tokens_handle = last_tokens.Pin())
{
NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
fixed (int* last_tokens_handle = last_tokens)
{
NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
}
}
}
}


+128 -0  LLama/Sampling/BaseSamplingPipeline.cs

@@ -0,0 +1,128 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
/// </summary>
public abstract class BaseSamplingPipeline
: ISamplingPipeline
{
private int _savedLogitsCount;
private (int index, float logit)[]? _savedLogits;

/// <inheritdoc/>
public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
var protectedLogits = GetProtectedTokens(ctx);
_savedLogitsCount = protectedLogits.Count;
_savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
try
{
// Save the values of protected logits
for (var i = 0; i < protectedLogits.Count; i++)
{
var index = protectedLogits[i];
var value = logits[index];
_savedLogits[i] = (index, value);
}

// Process raw logits
ProcessLogits(ctx, logits, lastTokens);

// Automatically restore saved logit values after processing
RestoreProtectedTokens(logits);

// Convert logits into token candidates
var candidates = LLamaTokenDataArray.Create(logits);

// Process token data array
ProcessTokenDataArray(ctx, candidates, lastTokens);

// Choose the final value
return ChooseToken(ctx, candidates);
}
finally
{
ArrayPool<(int, float)>.Shared.Return(_savedLogits);
_savedLogits = null;
_savedLogitsCount = 0;
}
}

#region protected tokens
/// <summary>
/// Get all of the "protected" tokens that cannot be changed by ProcessLogits
/// </summary>
/// <returns></returns>
protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);

/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="logits"></param>
protected void RestoreProtectedTokens(Span<float> logits)
{
if (_savedLogits == null)
return;

// The array may be bigger than necessary, get a span of just the valid part
var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

// Restore the values of protected logits
for (var i = 0; i < saved.Length; i++)
logits[saved[i].index] = saved[i].logit;
}

/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="candidates"></param>
protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
{
if (_savedLogits == null || _savedLogits.Length == 0)
return;

candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
}
#endregion

/// <summary>
/// Process the raw logit values
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

/// <summary>
/// Process the LLamaTokenDataArray and select a single token
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

/// <summary>
/// Choose the final token from the candidates
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <returns></returns>
protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

/// <inheritdoc/>
public virtual void Reset()
{
}

/// <inheritdoc/>
public virtual void Dispose()
{
GC.SuppressFinalize(this);
}
}
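
A custom pipeline only has to fill in the four abstract members above; saving and restoring the "protected" logits is handled by the base class. A minimal, hypothetical sketch that applies just top-k and temperature:

using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

public sealed class TopKOnlyPipeline
    : BaseSamplingPipeline
{
    public int TopK { get; set; } = 40;
    public float Temperature { get; set; } = 0.8f;

    // Nothing is protected from ProcessLogits in this sketch.
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        => Array.Empty<int>();

    // Leave the raw logits untouched.
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
    }

    // Apply top-k and temperature; as in DefaultSamplingPipeline below, the base
    // class takes the final pick from ChooseToken.
    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        candidates.TopK(ctx, TopK);
        candidates.Temperature(ctx, Temperature);
        return candidates.SampleToken(ctx);
    }

    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        => candidates.SampleToken(ctx);
}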

+149 -0  LLama/Sampling/DefaultSamplingPipeline.cs

@@ -0,0 +1,149 @@
using System;
using System.Collections.Generic;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
/// </summary>
public sealed class DefaultSamplingPipeline
: BaseSamplingPipeline
{
/// <summary>
/// Bias values to add to certain logits
/// </summary>
public Dictionary<int, float> LogitBias { get; } = new();

/// <summary>
/// Grammar to constrain valid tokens
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }

/// <summary>
/// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
/// </summary>
public float RepeatPenalty { get; set; } = 1.1f;

/// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
public float AlphaFrequency
{
get => _alphaFreq;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaFreq = value;
}
}
private float _alphaFreq = 0.1f;

/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
public float AlphaPresence
{
get => _alphaPresence;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
_alphaPresence = value;
}
}
private float _alphaPresence = 0.1f;

/// <summary>
/// Temperature to apply (higher temperature is more "creative")
/// </summary>
public float Temperature { get; set; } = 0.75f;

/// <summary>
/// Number of tokens to keep in TopK sampling
/// </summary>
public int TopK { get; set; }

/// <summary>
/// Z value for tail free sampling
/// </summary>
public float TailFreeZ { get; set; }

/// <summary>
/// P value for locally typical sampling
/// </summary>
public float TypicalP { get; set; }

/// <summary>
/// P value for TopP sampling
/// </summary>
public float TopP { get; set; } = 1f;

/// <summary>
/// P value for MinP sampling
/// </summary>
public float MinP { get; set; }

/// <summary>
/// Whether the newline value should be protected from being modified by logit bias and repeat penalty
/// </summary>
public bool PenalizeNewline { get; set; } = false;

private readonly int[] _newlineToken = new int[1];

/// <inheritdoc />
protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
{
if (PenalizeNewline)
return Array.Empty<int>();

_newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
return _newlineToken;
}

/// <inheritdoc />
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
foreach (var (key, value) in LogitBias)
logits[key] += value;
}

/// <inheritdoc />
protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
// Apply penalties to candidates
candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);

// Restore protected tokens, so they are not affected by repetition penalties
RestoreProtectedTokens(candidates);

// Apply the normal llama.cpp pipeline
candidates.ApplyGrammar(ctx, Grammar);
candidates.TopK(ctx, TopK);
candidates.TailFree(ctx, TailFreeZ);
candidates.LocallyTypical(ctx, TypicalP);
candidates.TopP(ctx, TopP);
candidates.MinP(ctx, MinP);
candidates.Temperature(ctx, Temperature);
var id = candidates.SampleToken(ctx);

Grammar?.AcceptToken(ctx, id);
return id;
}

/// <inheritdoc />
protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
return candidates.SampleToken(ctx);
}
}
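
For reference, a sketch of configuring this pipeline and handing it to an executor via InferenceParams (all values are illustrative, not recommendations):

using LLama.Common;
using LLama.Sampling;

var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.6f,
    TopK = 40,
    TopP = 0.95f,
    RepeatPenalty = 1.1f,
};

// Illustrative bias: add +5 to the logit of token id 123 before sampling.
pipeline.LogitBias[123] = 5f;

var inferenceParams = new InferenceParams
{
    MaxTokens = 128,
    SamplingPipeline = pipeline,
};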

+61 -0  LLama/Sampling/ISamplingPipeline.cs

@@ -0,0 +1,61 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
/// </summary>
public interface ISamplingPipeline
: IDisposable
{
/// <summary>
/// Sample a single token from the given logits
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A span of tokens recently returned by the model</param>
/// <returns></returns>
int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

/// <summary>
/// Reset all internal state of the sampling pipeline
/// </summary>
void Reset();
}

/// <summary>
/// Extensions methods for ISamplingPipeline
/// </summary>
public static class ISamplingPipelineExtensions
{
/// <summary>
/// Sample a single token from the given logits
/// </summary>
/// <param name="pipeline"></param>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(lastTokens);
return pipeline.Sample(ctx, logits, span);
#else
var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
try
{
lastTokens.CopyTo(copy);
return pipeline.Sample(ctx, logits, copy.AsSpan(0, copy.Length));
}
finally
{
ArrayPool<int>.Shared.Return(copy);
}
#endif
}
}
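
BaseSamplingPipeline is optional: ISamplingPipeline can also be implemented directly when the logit-protection machinery isn't needed. A hypothetical greedy pipeline that always returns the highest-logit token:

using System;
using LLama.Native;
using LLama.Sampling;

public sealed class GreedySamplingPipeline
    : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Logits are indexed by token id, so the argmax index is the chosen token.
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
        {
            if (logits[i] > logits[best])
                best = i;
        }
        return best;
    }

    public void Reset()
    {
    }

    public void Dispose()
    {
    }
}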
