
- Added `SamplingPipeline` to the inference params, which overrides all other sampling options with an entirely custom pipeline (see the usage sketch after the changed-file summary below).

- Added a `Sample` method to `LLamaContext` which uses a custom pipeline
- Modified all executors to use the custom pipeline if it exists
Martin Evans · 1 year ago · commit b34f72a883 · tags/0.9.1
8 changed files with 123 additions and 35 deletions
  1. +8  -2    LLama.Web/Common/InferenceOptions.cs
  2. +6  -0    LLama/Abstractions/IInferenceParams.cs
  3. +4  -0    LLama/Common/InferenceParams.cs
  4. +12 -0    LLama/LLamaContext.cs
  5. +17 -9    LLama/LLamaInstructExecutor.cs
  6. +18 -10   LLama/LLamaInteractExecutor.cs
  7. +19 -10   LLama/LLamaStatelessExecutor.cs
  8. +39 -4    LLama/Sampling/ISamplingPipeline.cs
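For context, here is a minimal sketch of how a caller could use this after the change. The GreedySamplingPipeline class and SamplingPipelineUsage helper are hypothetical illustrations (not part of this commit), and the sketch assumes ISamplingPipeline only requires the two members shown in the ISamplingPipeline.cs diff below:

using System;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// Hypothetical pipeline: always pick the highest-logit token (greedy decoding).
public sealed class GreedySamplingPipeline : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Scan the logits and return the index of the largest one.
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
        {
            if (logits[i] > logits[best])
                best = i;
        }
        return best;
    }

    public void Reset()
    {
        // Greedy sampling keeps no state, so there is nothing to reset.
    }
}

public static class SamplingPipelineUsage
{
    public static InferenceParams BuildParams()
    {
        // Once SamplingPipeline is set, the executors ignore Temperature, TopK,
        // Mirostat and every other built-in sampling parameter.
        return new InferenceParams
        {
            SamplingPipeline = new GreedySamplingPipeline(),
        };
    }
}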

+8 -2   LLama.Web/Common/InferenceOptions.cs

@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
     /// <summary>
     /// A grammar to constrain possible tokens
     /// </summary>
-    public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+    public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+    /// <inheritdoc />
+    public ISamplingPipeline? SamplingPipeline { get; set; }
 }
 }

+6 -0   LLama/Abstractions/IInferenceParams.cs

@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
     /// <summary>
     /// Grammar to constrain possible tokens
     /// </summary>
     SafeLLamaGrammarHandle? Grammar { get; set; }
+
+    /// <summary>
+    /// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
+    /// </summary>
+    ISamplingPipeline? SamplingPipeline { get; set; }
 }
 }

+4 -0   LLama/Common/InferenceParams.cs

@@ -2,6 +2,7 @@
 using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ namespace LLama.Common
 
     /// <inheritdoc />
     public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+    /// <inheritdoc />
+    public ISamplingPipeline? SamplingPipeline { get; set; }
 }
 
 /// <summary>


+12 -0   LLama/LLamaContext.cs

@@ -10,6 +10,7 @@ using LLama.Common;
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
     }
 }
 
+/// <summary>
+/// Sample a single token from this context, using the given sampling pipeline
+/// </summary>
+/// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+/// <param name="lastTokens">The tokens recently returned from the model</param>
+/// <returns>The selected token</returns>
+public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+{
+    return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+}
+
 /// <summary>
 /// Perform the sampling. Please don't use it unless you fully know what it does.
 /// </summary>
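As a usage note, here is a hedged sketch of driving the new overload directly in a hand-rolled generation step. The ContextSampleUsage class, its GenerateNextToken helper, and its arguments are hypothetical; the sketch assumes the context has already evaluated the prompt so logits are available, and treats llama_token as the int alias used elsewhere in the library:

using System.Collections.Generic;
using LLama;
using LLama.Sampling;

public static class ContextSampleUsage
{
    // Hypothetical helper: ask the pipeline to pick the next token from the
    // context's current logits, then record it in the history.
    public static int GenerateNextToken(LLamaContext context, ISamplingPipeline pipeline, List<int> history)
    {
        var token = context.Sample(pipeline, history.ToArray());
        history.Add(token);
        return token;
    }
}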


+17 -9   LLama/LLamaInstructExecutor.cs

@@ -210,16 +210,24 @@ namespace LLama
     SaveSessionFile(_pathSession);
 }
 
-var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+llama_token id;
+if (inferenceParams.SamplingPipeline is not null)
+{
+    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+}
+else
+{
+    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-var mu = MirostatMu;
-var id = Context.Sample(
-    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-    inferenceParams.MinP
-);
-MirostatMu = mu;
+    var mu = MirostatMu;
+    id = Context.Sample(
+        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+        inferenceParams.MinP
+    );
+    MirostatMu = mu;
+}
 
 _last_n_tokens.Enqueue(id);



+18 -10   LLama/LLamaInteractExecutor.cs

@@ -189,16 +189,24 @@ namespace LLama
     SaveSessionFile(_pathSession);
 }
 
-var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-var mu = MirostatMu;
-var id = Context.Sample(
-    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-    inferenceParams.MinP
-);
-MirostatMu = mu;
+llama_token id;
+if (inferenceParams.SamplingPipeline is not null)
+{
+    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+}
+else
+{
+    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+    var mu = MirostatMu;
+    id = Context.Sample(
+        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+        inferenceParams.MinP
+    );
+    MirostatMu = mu;
+}
 
 _last_n_tokens.Enqueue(id);



+19 -10   LLama/LLamaStatelessExecutor.cs

@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
 var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
 for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
 {
-    // Penalize the generated tokens by various penalties
-    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-    // Sample a single token
-    var id = Context.Sample(
-        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-        inferenceParams.MinP
-    );
+    llama_token id;
+    if (inferenceParams.SamplingPipeline is not null)
+    {
+        id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+    }
+    else
+    {
+        // Penalize the generated tokens by various penalties
+        var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+            inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+        // Sample a single token
+        id = Context.Sample(
+            tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+            inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+            inferenceParams.MinP
+        );
+    }
 
     // Decode this token into text
     decoder.Add(id);
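Taken together, the executor changes mean a caller only has to set the pipeline on the params object; nothing else about the inference call changes. A hedged sketch, assuming the library's existing InteractiveExecutor/InferAsync surface (not part of this diff) and reusing the hypothetical GreedySamplingPipeline from the earlier sketch:

using System;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class ExecutorUsage
{
    public static async Task RunAsync(LLamaContext context, string prompt)
    {
        var executor = new InteractiveExecutor(context);

        var inferenceParams = new InferenceParams
        {
            // The executor routes every sampling step through this pipeline
            // instead of the built-in penalty/Mirostat/top-k/top-p path.
            SamplingPipeline = new GreedySamplingPipeline(),
        };

        await foreach (var text in executor.InferAsync(prompt, inferenceParams))
            Console.Write(text);
    }
}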


+39 -4   LLama/Sampling/ISamplingPipeline.cs

@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 using LLama.Sampling.Logits;
 using LLama.Sampling.Selection;
@@ -16,9 +18,9 @@ public interface ISamplingPipeline
     /// <summary>
     /// Sample a single token from the given logits
    /// </summary>
-    /// <param name="ctx"></param>
-    /// <param name="logits"></param>
-    /// <param name="lastTokens"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
     /// <returns></returns>
     int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
 
@@ -28,10 +30,43 @@ public interface ISamplingPipeline
     void Reset();
 }
 
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="pipeline"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+    {
+#if NET5_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(lastTokens);
+        return pipeline.Sample(ctx, logits, span);
+#else
+        var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+        try
+        {
+            lastTokens.CopyTo(copy);
+            return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+        }
+        finally
+        {
+            ArrayPool<int>.Shared.Return(copy);
+        }
+#endif
+    }
+}
+
 /// <summary>
 /// Simple implementation of `ISamplingPipeline`, applies processors in order every time
 /// </summary>
-public sealed class BasicSamplingPipeline
+public sealed class ConfigurableSamplingPipeline
     : ISamplingPipeline
 {
     /// <summary>
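As a usage note on the new extension method: it lets callers hand a List<int> of recent tokens straight to a pipeline. On .NET 5+ it reads the list's backing array as a span via CollectionsMarshal with no copy; on older targets it rents a temporary buffer from ArrayPool. A hedged sketch (the ExtensionUsage class, its SampleFromHistory helper, and its arguments are hypothetical):

using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

public static class ExtensionUsage
{
    // Hypothetical helper: forwards a List<int> token history to the pipeline,
    // resolving to the new List-based extension overload rather than the
    // ReadOnlySpan-based interface method.
    public static int SampleFromHistory(ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> history)
    {
        return pipeline.Sample(ctx, logits, history);
    }
}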

