- Added a `Sample` method to `LLamaContext` which uses a custom pipeline - Modified all executors to use the custom pipeline if it existstags/0.9.1
| @@ -1,6 +1,9 @@ | |||
| using LLama.Common; | |||
| #nullable enable | |||
| using LLama.Common; | |||
| using LLama.Abstractions; | |||
| using LLama.Native; | |||
| using LLama.Sampling; | |||
| namespace LLama.Web.Common | |||
| { | |||
| @@ -64,6 +67,9 @@ namespace LLama.Web.Common | |||
| /// <summary> | |||
| /// A grammar to constrain possible tokens | |||
| /// </summary> | |||
| public SafeLLamaGrammarHandle Grammar { get; set; } = null; | |||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| /// <inheritdoc /> | |||
| public ISamplingPipeline? SamplingPipeline { get; set; } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System.Collections.Generic; | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| using LLama.Sampling; | |||
| namespace LLama.Abstractions | |||
| { | |||
| @@ -108,5 +109,10 @@ namespace LLama.Abstractions | |||
| /// Grammar to constrain possible tokens | |||
| /// </summary> | |||
| SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| /// <summary> | |||
| /// Set a custom sampling pipeline to use. <b>If this is set All other sampling parameters are ignored!</b> | |||
| /// </summary> | |||
| ISamplingPipeline? SamplingPipeline { get; set; } | |||
| } | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using LLama.Native; | |||
| using LLama.Sampling; | |||
| namespace LLama.Common | |||
| { | |||
| @@ -76,6 +77,9 @@ namespace LLama.Common | |||
| /// <inheritdoc /> | |||
| public SafeLLamaGrammarHandle? Grammar { get; set; } | |||
| /// <inheritdoc /> | |||
| public ISamplingPipeline? SamplingPipeline { get; set; } | |||
| } | |||
| /// <summary> | |||
| @@ -10,6 +10,7 @@ using LLama.Common; | |||
| using System.Runtime.InteropServices; | |||
| using LLama.Extensions; | |||
| using LLama.Abstractions; | |||
| using LLama.Sampling; | |||
| using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| @@ -212,6 +213,17 @@ namespace LLama | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Sample a single token from this context, using the given sampling pipeline | |||
| /// </summary> | |||
| /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param> | |||
| /// <param name="lastTokens">The tokens recently returned from the model</param> | |||
| /// <returns>The selected token</returns> | |||
| public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens) | |||
| { | |||
| return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); | |||
| } | |||
| /// <summary> | |||
| /// Perform the sampling. Please don't use it unless you fully know what it does. | |||
| /// </summary> | |||
| @@ -210,16 +210,24 @@ namespace LLama | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| llama_token id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); | |||
| } | |||
| else | |||
| { | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostatMu; | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| var mu = MirostatMu; | |||
| id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| } | |||
| _last_n_tokens.Enqueue(id); | |||
| @@ -189,16 +189,24 @@ namespace LLama | |||
| SaveSessionFile(_pathSession); | |||
| } | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostatMu; | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| llama_token id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); | |||
| } | |||
| else | |||
| { | |||
| var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostatMu; | |||
| id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| MirostatMu = mu; | |||
| } | |||
| _last_n_tokens.Enqueue(id); | |||
| @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| using LLama.Native; | |||
| using LLama.Sampling; | |||
| using Microsoft.Extensions.Logging; | |||
| namespace LLama | |||
| @@ -85,16 +86,24 @@ namespace LLama | |||
| var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | |||
| for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) | |||
| { | |||
| // Penalize the generated tokens by various penalties | |||
| var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| // Sample a single token | |||
| var id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| llama_token id; | |||
| if (inferenceParams.SamplingPipeline is not null) | |||
| { | |||
| id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens); | |||
| } | |||
| else | |||
| { | |||
| // Penalize the generated tokens by various penalties | |||
| var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| // Sample a single token | |||
| id = Context.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, | |||
| inferenceParams.MinP | |||
| ); | |||
| } | |||
| // Decode this token into text | |||
| decoder.Add(id); | |||
| @@ -1,5 +1,7 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using LLama.Native; | |||
| using LLama.Sampling.Logits; | |||
| using LLama.Sampling.Selection; | |||
| @@ -16,9 +18,9 @@ public interface ISamplingPipeline | |||
| /// <summary> | |||
| /// Sample a single token from the given logits | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="logits"></param> | |||
| /// <param name="lastTokens"></param> | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A span of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens); | |||
| @@ -28,10 +30,43 @@ public interface ISamplingPipeline | |||
| void Reset(); | |||
| } | |||
| /// <summary> | |||
| /// Extensions methods for ISamplingPipeline | |||
| /// </summary> | |||
| public static class ISamplingPipelineExtensions | |||
| { | |||
| /// <summary> | |||
| /// Sample a single token from the given logits | |||
| /// </summary> | |||
| /// <param name="pipeline"></param> | |||
| /// <param name="ctx">The context being sampled from</param> | |||
| /// <param name="logits">The logits produced by the model</param> | |||
| /// <param name="lastTokens">A list of tokens recently returned by the model</param> | |||
| /// <returns></returns> | |||
| public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens) | |||
| { | |||
| #if NET5_0_OR_GREATER | |||
| var span = CollectionsMarshal.AsSpan(lastTokens); | |||
| return pipeline.Sample(ctx, logits, span); | |||
| #else | |||
| var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count); | |||
| try | |||
| { | |||
| lastTokens.CopyTo(copy); | |||
| return pipeline.Sample(ctx, logits, copy.AsSpan(0, copy.Length)); | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<int>.Shared.Return(copy); | |||
| } | |||
| #endif | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Simple implementation of `ISamplingPipeline`, applies processors in order every time | |||
| /// </summary> | |||
| public sealed class BasicSamplingPipeline | |||
| public sealed class ConfigurableSamplingPipeline | |||
| : ISamplingPipeline | |||
| { | |||
| /// <summary> | |||