
- Added support for the MinP sampler (usage sketch below)
- Cleaned up comments in implementations of `IInferenceParams`
- Removed default values for all parameters in `LLamaContext.Sample` - they're never used and probably _shouldn't_ ever be used
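
For readers who want to try the new sampler, here is a minimal sketch of setting `MinP` from user code. This is a hedged illustration, not part of the commit: the model path and prompt are placeholders, and the surrounding setup simply follows the usual LLamaSharp executor pattern.

```csharp
using System;
using LLama;
using LLama.Common;

// Placeholder path - substitute a real GGUF model file.
var modelParams = new ModelParams("models/example-7b.gguf");
using var weights = LLamaWeights.LoadFromFile(modelParams);
using var context = weights.CreateContext(modelParams);
var executor = new InteractiveExecutor(context);

var inferenceParams = new InferenceParams
{
    Temperature = 0.8f,
    // Keep only tokens at least 5% as likely as the single most likely token.
    MinP = 0.05f,
    MaxTokens = 128,
};

await foreach (var token in executor.InferAsync("Question: what is min-p sampling?\nAnswer:", inferenceParams))
    Console.Write(token);
```

Note that `0.05f` matches the default this commit gives `MinP`, so leaving the property unset behaves the same.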
Martin Evans · 2 years ago · commit d743516070 · tags/v0.8.1
8 changed files with 108 additions and 131 deletions:

1. LLama.Web/Common/InferenceOptions.cs (+37 -69)
2. LLama/Abstractions/IInferenceParams.cs (+9 -5)
3. LLama/Common/InferenceParams.cs (+32 -47)
4. LLama/LLamaContext.cs (+5 -3)
5. LLama/LLamaInstructExecutor.cs (+2 -2)
6. LLama/LLamaInteractExecutor.cs (+3 -3)
7. LLama/LLamaStatelessExecutor.cs (+5 -2)
8. LLama/Native/LLamaTokenDataArray.cs (+15 -0)

LLama.Web/Common/InferenceOptions.cs (+37 -69)

@@ -4,93 +4,61 @@ using LLama.Native;
 
 namespace LLama.Web.Common
 {
-    public class InferenceOptions : IInferenceParams
+    public class InferenceOptions
+        : IInferenceParams
     {
-        /// <summary>
-        /// number of tokens to keep from initial prompt
-        /// </summary>
+        /// <inheritdoc />
         public int TokensKeep { get; set; } = 0;
-        /// <summary>
-        /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
-        /// until it complete.
-        /// </summary>
+
+        /// <inheritdoc />
         public int MaxTokens { get; set; } = -1;
-        /// <summary>
-        /// logit bias for specific tokens
-        /// </summary>
+
+        /// <inheritdoc />
         public Dictionary<int, float>? LogitBias { get; set; } = null;
 
-        /// <summary>
-        /// Sequences where the model will stop generating further tokens.
-        /// </summary>
+        /// <inheritdoc />
         public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
-        /// <summary>
-        /// path to file for saving/loading model eval state
-        /// </summary>
-        public string PathSession { get; set; } = string.Empty;
-        /// <summary>
-        /// string to suffix user inputs with
-        /// </summary>
-        public string InputSuffix { get; set; } = string.Empty;
-        /// <summary>
-        /// string to prefix user inputs with
-        /// </summary>
-        public string InputPrefix { get; set; } = string.Empty;
-        /// <summary>
-        /// 0 or lower to use vocab size
-        /// </summary>
+
+        /// <inheritdoc />
         public int TopK { get; set; } = 40;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TopP { get; set; } = 0.95f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
+        public float MinP { get; set; } = 0.05f;
+
+        /// <inheritdoc />
         public float TfsZ { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TypicalP { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float Temperature { get; set; } = 0.8f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float RepeatPenalty { get; set; } = 1.1f;
-        /// <summary>
-        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
-        /// </summary>
+
+        /// <inheritdoc />
         public int RepeatLastTokensCount { get; set; } = 64;
-        /// <summary>
-        /// frequency penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float FrequencyPenalty { get; set; } = .0f;
-        /// <summary>
-        /// presence penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float PresencePenalty { get; set; } = .0f;
-        /// <summary>
-        /// Mirostat uses tokens instead of words.
-        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
-        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        /// </summary>
+
+        /// <inheritdoc />
         public MirostatType Mirostat { get; set; } = MirostatType.Disable;
-        /// <summary>
-        /// target entropy
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatTau { get; set; } = 5.0f;
-        /// <summary>
-        /// learning rate
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatEta { get; set; } = 0.1f;
-        /// <summary>
-        /// consider newlines as a repeatable token (penalize_nl)
-        /// </summary>
+
+        /// <inheritdoc />
         public bool PenalizeNL { get; set; } = true;
 
         /// <summary>

LLama/Abstractions/IInferenceParams.cs (+9 -5)

@@ -25,7 +25,6 @@ namespace LLama.Abstractions
         /// </summary>
         public Dictionary<int, float>? LogitBias { get; set; }
 
-
         /// <summary>
         /// Sequences where the model will stop generating further tokens.
         /// </summary>
@@ -41,10 +40,15 @@ namespace LLama.Abstractions
         /// </summary>
         public float TopP { get; set; }
 
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
-        public float TfsZ { get; set; }
+        /// <summary>
+        /// 0.0 = disabled
+        /// </summary>
+        public float MinP { get; set; }
+
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; }
 
         /// <summary>
         /// 1.0 = disabled

LLama/Common/InferenceParams.cs (+32 -47)

@@ -6,10 +6,12 @@ using LLama.Native;
 namespace LLama.Common
 {
     using llama_token = Int32;
+
     /// <summary>
     /// The paramters used for inference.
     /// </summary>
-    public record InferenceParams : IInferenceParams
+    public record InferenceParams
+        : IInferenceParams
     {
         /// <summary>
         /// number of tokens to keep from initial prompt
@@ -30,66 +32,49 @@ namespace LLama.Common
         /// </summary>
         public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
 
-        /// <summary>
-        /// 0 or lower to use vocab size
-        /// </summary>
+        /// <inheritdoc />
         public int TopK { get; set; } = 40;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TopP { get; set; } = 0.95f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
+        public float MinP { get; set; } = 0.05f;
+
+        /// <inheritdoc />
         public float TfsZ { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TypicalP { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float Temperature { get; set; } = 0.8f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float RepeatPenalty { get; set; } = 1.1f;
-        /// <summary>
-        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
-        /// </summary>
+
+        /// <inheritdoc />
         public int RepeatLastTokensCount { get; set; } = 64;
-        /// <summary>
-        /// frequency penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float FrequencyPenalty { get; set; } = .0f;
-        /// <summary>
-        /// presence penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float PresencePenalty { get; set; } = .0f;
-        /// <summary>
-        /// Mirostat uses tokens instead of words.
-        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
-        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        /// </summary>
+
+        /// <inheritdoc />
         public MirostatType Mirostat { get; set; } = MirostatType.Disable;
-        /// <summary>
-        /// target entropy
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatTau { get; set; } = 5.0f;
-        /// <summary>
-        /// learning rate
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatEta { get; set; } = 0.1f;
-        /// <summary>
-        /// consider newlines as a repeatable token (penalize_nl)
-        /// </summary>
+
+        /// <inheritdoc />
         public bool PenalizeNL { get; set; } = true;
 
-        /// <summary>
-        /// A grammar to constrain the possible tokens
-        /// </summary>
+        /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
     }



LLama/LLamaContext.cs (+5 -3)

@@ -226,10 +226,11 @@ namespace LLama
         /// <param name="tfsZ"></param>
         /// <param name="typicalP"></param>
         /// <param name="grammar"></param>
+        /// <param name="minP"></param>
         /// <returns></returns>
-        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
-                                  float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
-                                  SafeLLamaGrammarHandle? grammar = null)
+        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
+                                  float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
+                                  SafeLLamaGrammarHandle? grammar, float minP)
         {
             llama_token id;
 
@@ -264,6 +265,7 @@ namespace LLama
                 candidates.TailFree(NativeHandle, tfsZ);
                 candidates.LocallyTypical(NativeHandle, typicalP);
                 candidates.TopP(NativeHandle, topP);
+                candidates.MinP(NativeHandle, minP);
                 candidates.Temperature(NativeHandle, temperature);
                 id = candidates.SampleToken(NativeHandle);
             }
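
With the defaults gone, every caller of `Sample` now spells out all twelve arguments. Here is a hedged sketch of a direct call against the new signature: `context` and `candidates` are assumed to already exist (a `LLamaContext` and a `LLamaTokenDataArray` built from the current logits), and the values simply mirror the old defaults plus the new `minP`.

```csharp
using LLama.Common;

// Assumed: `context` is a LLamaContext, `candidates` a LLamaTokenDataArray.
float? mu = null;  // mirostat state; stays null while mirostat is disabled
var token = context.Sample(
    candidates, ref mu,
    temperature: 0.8f, mirostat: MirostatType.Disable,
    mirostatTau: 5.0f, mirostatEta: 0.1f,
    topK: 40, topP: 0.95f, tfsZ: 1.0f, typicalP: 1.0f,
    grammar: null, minP: 0.05f);
```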


LLama/LLamaInstructExecutor.cs (+2 -2)

@@ -216,8 +216,8 @@ namespace LLama
             var mu = MirostatMu;
             var id = Context.Sample(
                 tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
-                inferenceParams.Grammar
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                inferenceParams.MinP
             );
             MirostatMu = mu;



LLama/LLamaInteractExecutor.cs (+3 -3)

@@ -194,9 +194,9 @@ namespace LLama
 
             var mu = MirostatMu;
             var id = Context.Sample(
-                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
-                inferenceParams.Grammar
+                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                inferenceParams.MinP
             );
             MirostatMu = mu;



LLama/LLamaStatelessExecutor.cs (+5 -2)

@@ -90,8 +90,11 @@ namespace LLama
                 inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
             // Sample a single token
-            var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
+            var id = Context.Sample(
+                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                inferenceParams.MinP
+            );
 
             // Decode this token into text
             decoder.Add(id);


LLama/Native/LLamaTokenDataArray.cs (+15 -0)

@@ -91,6 +91,21 @@ namespace LLama.Native
             }
         }
 
+        /// <summary>
+        /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="p">All tokens with probability greater than this will be kept</param>
+        /// <param name="minKeep"></param>
+        public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
+                sorted = st.sorted;
+            }
+        }
+
         /// <summary>
         /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
         /// </summary>
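
The heavy lifting happens in native code (`llama_sample_min_p`), but the rule it implements is simple: keep every candidate whose probability is at least `p` times the probability of the most likely candidate, and always keep at least `minKeep` tokens. Below is a standalone managed sketch of that logic for illustration only; it is not the library's implementation, and `MinPFilter` is a hypothetical helper name.

```csharp
using System;
using System.Linq;

// Illustrative min-p filter over raw logits (softmax included for completeness).
static int[] MinPFilter(float[] logits, float p, int minKeep = 1)
{
    // Numerically stable softmax to turn logits into probabilities.
    float max = logits.Max();
    var exps = logits.Select(l => MathF.Exp(l - max)).ToArray();
    float sum = exps.Sum();
    var probs = exps.Select(e => e / sum).ToArray();

    // Keep tokens with probability >= p * (probability of the best token).
    float threshold = p * probs.Max();
    var kept = Enumerable.Range(0, probs.Length)
                         .Where(i => probs[i] >= threshold)
                         .ToArray();

    // Guarantee at least `minKeep` survivors, taking the most likely tokens.
    if (kept.Length < minKeep)
        kept = Enumerable.Range(0, probs.Length)
                         .OrderByDescending(i => probs[i])
                         .Take(minKeep)
                         .ToArray();

    return kept;
}

// With p = 0.5 only candidates at least half as likely as the best survive.
Console.WriteLine(string.Join(", ", MinPFilter(new[] { 2.0f, 1.9f, 0.1f }, p: 0.5f)));
```

Unlike top-p, which trims a cumulative tail, this threshold scales with the model's confidence: a sharply peaked distribution keeps few tokens, a flat one keeps many.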

