diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 6a63ccc3..182ace00 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -1,5 +1,6 @@
 using System.Text;
 using LLama.Abstractions;
+using LLama.Native;
 
 namespace LLama.Web.Common
 {
@@ -118,6 +119,24 @@ namespace LLama.Web.Common
         /// </summary>
         public float? RopeFrequencyScale { get; set; }
 
+        /// <inheritdoc />
+        public float? YarnExtrapolationFactor { get; set; }
+
+        /// <inheritdoc />
+        public float? YarnAttentionFactor { get; set; }
+
+        /// <inheritdoc />
+        public float? YarnBetaFast { get; set; }
+
+        /// <inheritdoc />
+        public float? YarnBetaSlow { get; set; }
+
+        /// <inheritdoc />
+        public uint? YarnOriginalContext { get; set; }
+
+        /// <inheritdoc />
+        public RopeScalingType? YarnScalingType { get; set; }
+
         /// <summary>
         /// Use experimental mul_mat_q kernels
         /// </summary>
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index 8ff6d7cc..0f129217 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -1,4 +1,5 @@
 using System.Text;
+using LLama.Native;
 
 namespace LLama.Abstractions;
 
@@ -67,4 +68,34 @@ public interface IContextParams
     /// Number of threads to use for batch processing (null = autodetect) (n_threads)
     /// </summary>
     uint? BatchThreads { get; set; }
+
+    /// <summary>
+    /// YaRN extrapolation mix factor
+    /// </summary>
+    float? YarnExtrapolationFactor { get; set; }
+
+    /// <summary>
+    /// YaRN magnitude scaling factor
+    /// </summary>
+    float? YarnAttentionFactor { get; set; }
+
+    /// <summary>
+    /// YaRN low correction dim
+    /// </summary>
+    float? YarnBetaFast { get; set; }
+
+    /// <summary>
+    /// YaRN high correction dim
+    /// </summary>
+    float? YarnBetaSlow { get; set; }
+
+    /// <summary>
+    /// YaRN original context length
+    /// </summary>
+    uint? YarnOriginalContext { get; set; }
+
+    /// <summary>
+    /// YaRN scaling method to use.
+    /// </summary>
+    RopeScalingType? YarnScalingType { get; set; }
 }
\ No newline at end of file
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index ee5bd3e4..dd4584e3 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -3,6 +3,7 @@ using System;
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using LLama.Native;
 
 namespace LLama.Common
 {
@@ -70,6 +71,7 @@ namespace LLama.Common
         /// </summary>
         public uint? BatchThreads { get; set; }
 
+
         /// <summary>
         /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
         /// </summary>
@@ -98,10 +100,28 @@ namespace LLama.Common
         /// </summary>
         public float? RopeFrequencyScale { get; set; }
 
-        /// <summary>
-        /// Use experimental mul_mat_q kernels
-        /// </summary>
-        public bool MulMatQ { get; set; }
+        /// <inheritdoc />
+        public float? YarnExtrapolationFactor { get; set; }
+
+        /// <inheritdoc />
+        public float? YarnAttentionFactor { get; set; }
+
+        /// <inheritdoc />
+        public float? YarnBetaFast { get; set; }
+
+        /// <inheritdoc />
+        public float? YarnBetaSlow { get; set; }
+
+        /// <inheritdoc />
+        public uint? YarnOriginalContext { get; set; }
+
+        /// <inheritdoc />
+        public RopeScalingType? YarnScalingType { get; set; }
+
+        /// <summary>
+        /// Use experimental mul_mat_q kernels
+        /// </summary>
+        public bool MulMatQ { get; set; }
 
         /// <summary>
         /// Load vocab only (no weights)
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index fcc9d372..16716b53 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -29,6 +29,15 @@ namespace LLama.Extensions
             result.embedding = @params.EmbeddingMode;
             result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
             result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;
+
+            // Default YaRN values copied from here: https://github.com/ggerganov/llama.cpp/blob/381efbf480959bb6d1e247a8b0c2328f22e350f8/common/common.h#L67
+            result.yarn_ext_factor = @params.YarnExtrapolationFactor ?? -1f;
+            result.yarn_attn_factor = @params.YarnAttentionFactor ?? 1f;
+            result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
+            result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
+            result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
+            result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
+
             result.mul_mat_q = @params.MulMatQ;
             result.n_threads = Threads(@params.Threads);
             result.n_threads_batch = Threads(@params.BatchThreads);
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 0a397a3d..f1ba569d 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -44,13 +44,13 @@ namespace LLama.Native
         /// <summary>
         /// RoPE scaling type, from `enum llama_rope_scaling_type`
         /// </summary>
-        public sbyte rope_scaling_type;
+        public RopeScalingType rope_scaling_type;
 
         /// <summary>
         /// RoPE base frequency, 0 = from model
         /// </summary>
-        public float rope_freq_base; 
+        public float rope_freq_base;
 
         /// <summary>
         /// RoPE frequency scaling factor, 0 = from model
         /// </summary>
diff --git a/LLama/Native/RopeScalingType.cs b/LLama/Native/RopeScalingType.cs
new file mode 100644
index 00000000..435932e8
--- /dev/null
+++ b/LLama/Native/RopeScalingType.cs
@@ -0,0 +1,17 @@
+namespace LLama.Native
+{
+    /// <summary>
+    /// RoPE scaling type. C# equivalent of llama_rope_scaling_type
+    /// </summary>
+    public enum RopeScalingType
+        : sbyte
+    {
+        LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
+
+        LLAMA_ROPE_SCALING_NONE = 0,
+
+        LLAMA_ROPE_SCALING_LINEAR = 1,
+
+        LLAMA_ROPE_SCALING_YARN = 2,
+    }
+}