- Added a `Sample` method to `LLamaContext` which uses a custom sampling pipeline
- Modified all executors to use the custom pipeline if it exists

(tag: 0.9.1)
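For orientation, a hedged usage sketch (not part of this diff) of how calling code would attach a pipeline once these changes land. The `executor` instance, the prompt, and the parameterless `ConfigurableSamplingPipeline` constructor are assumptions:

```csharp
using System;
using LLama.Common;
using LLama.Sampling;

// Assumed setup: `executor` is any existing executor (interactive, instruct, or stateless).
var inferenceParams = new InferenceParams
{
    // When SamplingPipeline is non-null, the executors below skip
    // ApplyPenalty/Context.Sample entirely and defer to the pipeline.
    SamplingPipeline = new ConfigurableSamplingPipeline(), // constructor assumed
};

await foreach (var text in executor.InferAsync("prompt", inferenceParams))
    Console.Write(text);
```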
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
         /// <summary>
         /// A grammar to constrain possible tokens
         /// </summary>
-        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
         /// Grammar to constrain possible tokens
         /// </summary>
         SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <summary>
+        /// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
+        /// </summary>
+        ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
@@ -2,6 +2,7 @@
 using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ namespace LLama.Common
         /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }

     /// <summary>
@@ -10,6 +10,7 @@ using LLama.Common;
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;

 namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
             }
         }

+        /// <summary>
+        /// Sample a single token from this context, using the given sampling pipeline
+        /// </summary>
+        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+        /// <param name="lastTokens">The tokens recently returned from the model</param>
+        /// <returns>The selected token</returns>
+        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        {
+            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        }
+
         /// <summary>
         /// Perform the sampling. Please don't use it unless you fully know what it does.
         /// </summary>
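The new overload simply forwards the context's current logits to the pipeline, so a direct call is a one-liner. A minimal sketch, assuming `ctx` and `pipeline` already exist (`llama_token` is the library's `int` alias):

```csharp
using System;
using LLama;
using LLama.Sampling;

// Minimal sketch: draw one token straight from a context with a custom pipeline.
static int SampleNext(LLamaContext ctx, ISamplingPipeline pipeline, int[] recentTokens)
{
    // Equivalent to: pipeline.Sample(ctx.NativeHandle, ctx.NativeHandle.GetLogits(), recentTokens)
    return ctx.Sample(pipeline, recentTokens);
}
```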
@@ -210,16 +210,24 @@
                     SaveSessionFile(_pathSession);
                 }

-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }

                 _last_n_tokens.Enqueue(id);
@@ -189,16 +189,24 @@
                     SaveSessionFile(_pathSession);
                 }

-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }

                 _last_n_tokens.Enqueue(id);
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;

 namespace LLama
@@ -85,16 +86,24 @@
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }

                 // Decode this token into text
                 decoder.Add(id);
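Note that this executor passes `lastTokens` to the pipeline directly rather than via `.ToArray()`; since it is a `List<llama_token>`, the call presumably binds to the new `List<int>` extension overload introduced at the end of this diff.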
@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 using LLama.Sampling.Logits;
 using LLama.Sampling.Selection;
@@ -16,9 +18,9 @@ public interface ISamplingPipeline
     /// <summary>
     /// Sample a single token from the given logits
     /// </summary>
-    /// <param name="ctx"></param>
-    /// <param name="logits"></param>
-    /// <param name="lastTokens"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
     /// <returns></returns>
     int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
@@ -28,10 +30,43 @@ public interface ISamplingPipeline
     void Reset();
 }

+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="pipeline"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+    {
+#if NET5_0_OR_GREATER
+        // On .NET 5+ the list's backing array can be viewed as a span without copying
+        var span = CollectionsMarshal.AsSpan(lastTokens);
+        return pipeline.Sample(ctx, logits, span);
+#else
+        // On older frameworks, copy the list into a pooled array. Slice to
+        // lastTokens.Count: the rented array may be larger than requested.
+        var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+        try
+        {
+            lastTokens.CopyTo(copy);
+            return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+        }
+        finally
+        {
+            ArrayPool<int>.Shared.Return(copy);
+        }
+#endif
+    }
+}
+
 /// <summary>
 /// Simple implementation of `ISamplingPipeline`, applies processors in order every time
 /// </summary>
-public sealed class BasicSamplingPipeline
+public sealed class ConfigurableSamplingPipeline
     : ISamplingPipeline
 {
     /// <summary>
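To make the contract concrete, here is a hedged sketch of a custom pipeline, assuming `Sample` and `Reset` are the interface's only members (as the hunks above suggest). It ignores `lastTokens` and always picks the highest-logit token:

```csharp
using System;
using LLama.Native;
using LLama.Sampling;

// Hypothetical example, not part of this diff: a stateless greedy pipeline.
public sealed class GreedySamplingPipeline : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Argmax over the raw logits; lastTokens is deliberately unused.
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
        {
            if (logits[i] > logits[best])
                best = i;
        }
        return best;
    }

    public void Reset()
    {
        // Greedy sampling keeps no state between calls, so nothing to reset.
    }
}
```

With the extension method above, an executor's `List<int>` token history can be passed without manual conversion: `pipeline.Sample(ctx, logits, lastTokensList)`.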