- Cleaned up comments in implementations of `IInferenceParams`
- Removed default values for all parameters in `LLamaContext.Sample` - they're never used and probably _shouldn't_ ever be used
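The second change is breaking for direct callers of `Sample`: every sampling argument must now be passed explicitly. As a rough sketch of a call site after this commit (the `context`, `tokenDataArray`, and `inferenceParams` variables are illustrative; the argument order is taken from the new signature in the hunks below):

```csharp
// Illustrative call site: no defaults remain on LLamaContext.Sample,
// so all sampling parameters are forwarded from an IInferenceParams instance.
float? mu = null; // mirostat state; null until the first sample
var id = context.Sample(
    tokenDataArray, ref mu,
    inferenceParams.Temperature, inferenceParams.Mirostat,
    inferenceParams.MirostatTau, inferenceParams.MirostatEta,
    inferenceParams.TopK, inferenceParams.TopP,
    inferenceParams.TfsZ, inferenceParams.TypicalP,
    inferenceParams.Grammar, inferenceParams.MinP);
```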
@@ -4,93 +4,61 @@ using LLama.Native;
 namespace LLama.Web.Common
 {
-    public class InferenceOptions : IInferenceParams
+    public class InferenceOptions
+        : IInferenceParams
     {
-        /// <summary>
-        /// number of tokens to keep from initial prompt
-        /// </summary>
+        /// <inheritdoc />
         public int TokensKeep { get; set; } = 0;
-        /// <summary>
-        /// how many new tokens to predict (n_predict), set to -1 to infinitely generate a response
-        /// until it completes.
-        /// </summary>
+        /// <inheritdoc />
         public int MaxTokens { get; set; } = -1;
-        /// <summary>
-        /// logit bias for specific tokens
-        /// </summary>
+        /// <inheritdoc />
         public Dictionary<int, float>? LogitBias { get; set; } = null;
-        /// <summary>
-        /// Sequences where the model will stop generating further tokens.
-        /// </summary>
+        /// <inheritdoc />
         public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
         /// <summary>
         /// path to file for saving/loading model eval state
         /// </summary>
         public string PathSession { get; set; } = string.Empty;
         /// <summary>
         /// string to suffix user inputs with
         /// </summary>
         public string InputSuffix { get; set; } = string.Empty;
         /// <summary>
         /// string to prefix user inputs with
         /// </summary>
         public string InputPrefix { get; set; } = string.Empty;
-        /// <summary>
-        /// 0 or lower to use vocab size
-        /// </summary>
+        /// <inheritdoc />
         public int TopK { get; set; } = 40;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float TopP { get; set; } = 0.95f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
+        public float MinP { get; set; } = 0.05f;
+        /// <inheritdoc />
         public float TfsZ { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float TypicalP { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float Temperature { get; set; } = 0.8f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float RepeatPenalty { get; set; } = 1.1f;
-        /// <summary>
-        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
-        /// </summary>
+        /// <inheritdoc />
         public int RepeatLastTokensCount { get; set; } = 64;
-        /// <summary>
-        /// frequency penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float FrequencyPenalty { get; set; } = .0f;
-        /// <summary>
-        /// presence penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float PresencePenalty { get; set; } = .0f;
-        /// <summary>
-        /// Mirostat uses tokens instead of words.
-        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
-        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        /// </summary>
+        /// <inheritdoc />
         public MirostatType Mirostat { get; set; } = MirostatType.Disable;
-        /// <summary>
-        /// target entropy
-        /// </summary>
+        /// <inheritdoc />
         public float MirostatTau { get; set; } = 5.0f;
-        /// <summary>
-        /// learning rate
-        /// </summary>
+        /// <inheritdoc />
         public float MirostatEta { get; set; } = 0.1f;
-        /// <summary>
-        /// consider newlines as a repeatable token (penalize_nl)
-        /// </summary>
+        /// <inheritdoc />
         public bool PenalizeNL { get; set; } = true;
         /// <summary>
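The repeated edit in this file is mechanical: each duplicated comment body is replaced with `/// <inheritdoc />`, so documentation tooling pulls the text from the corresponding member on `IInferenceParams` instead of maintaining two copies. A minimal sketch of the pattern (abbreviated, not the actual interface):

```csharp
public interface IInferenceParams
{
    /// <summary>0 or lower to use vocab size</summary>
    int TopK { get; set; }
}

public class InferenceOptions : IInferenceParams
{
    // The doc text is inherited from IInferenceParams.TopK,
    // so it is written once, on the interface.
    /// <inheritdoc />
    public int TopK { get; set; } = 40;
}
```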
@@ -25,7 +25,6 @@ namespace LLama.Abstractions
         /// </summary>
         public Dictionary<int, float>? LogitBias { get; set; }
         /// <summary>
         /// Sequences where the model will stop generating further tokens.
         /// </summary>
@@ -41,10 +40,15 @@ namespace LLama.Abstractions
         /// </summary>
         public float TopP { get; set; }
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
-        public float TfsZ { get; set; }
+        /// <summary>
+        /// 0.0 = disabled
+        /// </summary>
+        public float MinP { get; set; }
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; }
         /// <summary>
         /// 1.0 = disabled
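For background (my summary, not text from the commit): min-p sampling, per the llama.cpp PR referenced later in this diff (ggerganov/llama.cpp#3841), keeps only the tokens whose probability is at least `p` times the probability of the single most likely token, which is why `0.0` disables it. A rough managed-code sketch of the filter, under that assumption:

```csharp
using System.Collections.Generic;
using System.Linq;

// Hedged sketch of min-p (not LLamaSharp code): keep candidates whose
// probability is >= p * (probability of the best candidate), always
// retaining at least minKeep entries.
static List<(int Token, float Prob)> MinPFilter(
    IEnumerable<(int Token, float Prob)> candidates, float p, int minKeep = 1)
{
    var sorted = candidates.OrderByDescending(c => c.Prob).ToList();
    var threshold = p * sorted[0].Prob;
    var kept = sorted.Where(c => c.Prob >= threshold).ToList();
    return kept.Count >= minKeep ? kept : sorted.Take(minKeep).ToList();
}
```

With `p = 0.05` and a best-token probability of 0.6, everything below 0.03 is cut; with `p = 0`, the threshold is 0 and nothing is removed.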
@@ -6,10 +6,12 @@ using LLama.Native;
 namespace LLama.Common
 {
     using llama_token = Int32;
     /// <summary>
     /// The parameters used for inference.
     /// </summary>
-    public record InferenceParams : IInferenceParams
+    public record InferenceParams
+        : IInferenceParams
     {
         /// <summary>
         /// number of tokens to keep from initial prompt
@@ -30,66 +32,49 @@ namespace LLama.Common
         /// </summary>
         public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
-        /// <summary>
-        /// 0 or lower to use vocab size
-        /// </summary>
+        /// <inheritdoc />
         public int TopK { get; set; } = 40;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float TopP { get; set; } = 0.95f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
+        public float MinP { get; set; } = 0.05f;
+        /// <inheritdoc />
         public float TfsZ { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float TypicalP { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float Temperature { get; set; } = 0.8f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float RepeatPenalty { get; set; } = 1.1f;
-        /// <summary>
-        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
-        /// </summary>
+        /// <inheritdoc />
         public int RepeatLastTokensCount { get; set; } = 64;
-        /// <summary>
-        /// frequency penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float FrequencyPenalty { get; set; } = .0f;
-        /// <summary>
-        /// presence penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+        /// <inheritdoc />
         public float PresencePenalty { get; set; } = .0f;
-        /// <summary>
-        /// Mirostat uses tokens instead of words.
-        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
-        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        /// </summary>
+        /// <inheritdoc />
         public MirostatType Mirostat { get; set; } = MirostatType.Disable;
-        /// <summary>
-        /// target entropy
-        /// </summary>
+        /// <inheritdoc />
         public float MirostatTau { get; set; } = 5.0f;
-        /// <summary>
-        /// learning rate
-        /// </summary>
+        /// <inheritdoc />
        public float MirostatEta { get; set; } = 0.1f;
-        /// <summary>
-        /// consider newlines as a repeatable token (penalize_nl)
-        /// </summary>
+        /// <inheritdoc />
         public bool PenalizeNL { get; set; } = true;
-        /// <summary>
-        /// A grammar to constrain the possible tokens
-        /// </summary>
+        /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
     }
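With the new property in place, opting into min-p from user code looks like any other sampling parameter (the values here are just the defaults shown above):

```csharp
var inferenceParams = new InferenceParams
{
    Temperature = 0.8f,
    TopK = 40,
    TopP = 0.95f,
    MinP = 0.05f,  // new in this commit; 0.0 disables min-p
};
```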
@@ -226,10 +226,11 @@ namespace LLama
         /// <param name="tfsZ"></param>
         /// <param name="typicalP"></param>
         /// <param name="grammar"></param>
+        /// <param name="minP"></param>
         /// <returns></returns>
-        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
-                                  float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
-                                  SafeLLamaGrammarHandle? grammar = null)
+        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
+                                  float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
+                                  SafeLLamaGrammarHandle? grammar, float minP)
         {
             llama_token id;
@@ -264,6 +265,7 @@ namespace LLama
                 candidates.TailFree(NativeHandle, tfsZ);
                 candidates.LocallyTypical(NativeHandle, typicalP);
                 candidates.TopP(NativeHandle, topP);
+                candidates.MinP(NativeHandle, minP);
                 candidates.Temperature(NativeHandle, temperature);
                 id = candidates.SampleToken(NativeHandle);
             }
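Note the placement: the filters in this branch run in sequence, so min-p is applied after top-p has already truncated the tail, and before temperature scaling and the final draw. Presumably the full non-mirostat chain now reads as follows (the `TopK` line is an assumption about the surrounding, unchanged code):

```csharp
candidates.TopK(NativeHandle, topK);               // assumed unchanged context
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
candidates.MinP(NativeHandle, minP);               // new in this commit
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
```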
@@ -216,8 +216,8 @@ namespace LLama
             var mu = MirostatMu;
             var id = Context.Sample(
                 tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
-                inferenceParams.Grammar
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                inferenceParams.MinP
             );
             MirostatMu = mu;
@@ -194,9 +194,9 @@ namespace LLama
             var mu = MirostatMu;
             var id = Context.Sample(
-                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
-                inferenceParams.Grammar
+                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                inferenceParams.MinP
             );
             MirostatMu = mu;
@@ -90,8 +90,11 @@ namespace LLama
                 inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
             // Sample a single token
-            var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
+            var id = Context.Sample(
+                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                inferenceParams.MinP
+            );
             // Decode this token into text
             decoder.Add(id);
@@ -91,6 +91,21 @@ namespace LLama.Native
             }
         }
+        /// <summary>
+        /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="p">All tokens with probability greater than this will be kept</param>
+        /// <param name="minKeep"></param>
+        public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
+                sorted = st.sorted;
+            }
+        }
         /// <summary>
         /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
         /// </summary>
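A minimal usage sketch for the new stage, assuming an existing `SafeLLamaContextHandle ctx` and a candidate array already built from logits (the `Create`-from-logits helper is assumed from the surrounding API, and the variable names are illustrative):

```csharp
// Keep only candidates with probability >= 5% of the best candidate's,
// then sample as usual.
var candidates = LLamaTokenDataArray.Create(logits); // logits: ReadOnlySpan<float>
candidates.MinP(ctx, p: 0.05f);                      // minKeep defaults to 1
candidates.Temperature(ctx, 0.8f);
var token = candidates.SampleToken(ctx);
```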