diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 6a63ccc3..182ace00 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -1,5 +1,6 @@
using System.Text;
using LLama.Abstractions;
+using LLama.Native;
namespace LLama.Web.Common
{
@@ -118,6 +119,24 @@ namespace LLama.Web.Common
/// </summary>
public float? RopeFrequencyScale { get; set; }
+ /// <inheritdoc />
+ public float? YarnExtrapolationFactor { get; set; }
+
+ /// <inheritdoc />
+ public float? YarnAttentionFactor { get; set; }
+
+ /// <inheritdoc />
+ public float? YarnBetaFast { get; set; }
+
+ /// <inheritdoc />
+ public float? YarnBetaSlow { get; set; }
+
+ /// <inheritdoc />
+ public uint? YarnOriginalContext { get; set; }
+
+ /// <inheritdoc />
+ public RopeScalingType? YarnScalingType { get; set; }
+
/// <summary>
/// Use experimental mul_mat_q kernels
/// </summary>
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index 8ff6d7cc..0f129217 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -1,4 +1,5 @@
using System.Text;
+using LLama.Native;
namespace LLama.Abstractions;
@@ -67,4 +68,34 @@ public interface IContextParams
/// Number of threads to use for batch processing (null = autodetect) (n_threads)
/// </summary>
uint? BatchThreads { get; set; }
+
+ /// <summary>
+ /// YaRN extrapolation mix factor
+ /// </summary>
+ float? YarnExtrapolationFactor { get; set; }
+
+ /// <summary>
+ /// YaRN magnitude scaling factor
+ /// </summary>
+ float? YarnAttentionFactor { get; set; }
+
+ /// <summary>
+ /// YaRN low correction dim
+ /// </summary>
+ float? YarnBetaFast { get; set; }
+
+ /// <summary>
+ /// YaRN high correction dim
+ /// </summary>
+ float? YarnBetaSlow { get; set; }
+
+ /// <summary>
+ /// YaRN original context length
+ /// </summary>
+ uint? YarnOriginalContext { get; set; }
+
+ /// <summary>
+ /// YaRN scaling method to use.
+ /// </summary>
+ RopeScalingType? YarnScalingType { get; set; }
}
\ No newline at end of file
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index ee5bd3e4..dd4584e3 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -3,6 +3,7 @@ using System;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
+using LLama.Native;
namespace LLama.Common
{
@@ -70,6 +71,7 @@ namespace LLama.Common
/// </summary>
public uint? BatchThreads { get; set; }
+
/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
@@ -98,10 +100,28 @@ namespace LLama.Common
/// </summary>
public float? RopeFrequencyScale { get; set; }
- /// <summary>
- /// Use experimental mul_mat_q kernels
- /// </summary>
- public bool MulMatQ { get; set; }
+ /// <inheritdoc />
+ public float? YarnExtrapolationFactor { get; set; }
+
+ /// <inheritdoc />
+ public float? YarnAttentionFactor { get; set; }
+
+ /// <inheritdoc />
+ public float? YarnBetaFast { get; set; }
+
+ /// <inheritdoc />
+ public float? YarnBetaSlow { get; set; }
+
+ /// <inheritdoc />
+ public uint? YarnOriginalContext { get; set; }
+
+ /// <inheritdoc />
+ public RopeScalingType? YarnScalingType { get; set; }
+
+ /// <summary>
+ /// Use experimental mul_mat_q kernels
+ /// </summary>
+ public bool MulMatQ { get; set; }
/// <summary>
/// Load vocab only (no weights)
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index fcc9d372..16716b53 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -29,6 +29,15 @@ namespace LLama.Extensions
result.embedding = @params.EmbeddingMode;
result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;
+
+ // Default YaRN values copied from here: https://github.com/ggerganov/llama.cpp/blob/381efbf480959bb6d1e247a8b0c2328f22e350f8/common/common.h#L67
+ result.yarn_ext_factor = @params.YarnExtrapolationFactor ?? -1f;
+ result.yarn_attn_factor = @params.YarnAttentionFactor ?? 1f;
+ result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
+ result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
+ result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
+ result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
+
result.mul_mat_q = @params.MulMatQ;
result.n_threads = Threads(@params.Threads);
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 0a397a3d..f1ba569d 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -44,13 +44,13 @@ namespace LLama.Native
/// <summary>
/// RoPE scaling type, from `enum llama_rope_scaling_type`
/// </summary>
- public sbyte rope_scaling_type;
+ public RopeScalingType rope_scaling_type;
/// <summary>
/// RoPE base frequency, 0 = from model
/// </summary>
- public float rope_freq_base;
+ public float rope_freq_base;
/// <summary>
/// RoPE frequency scaling factor, 0 = from model
/// </summary>
diff --git a/LLama/Native/RopeScalingType.cs b/LLama/Native/RopeScalingType.cs
new file mode 100644
index 00000000..435932e8
--- /dev/null
+++ b/LLama/Native/RopeScalingType.cs
@@ -0,0 +1,17 @@
+namespace LLama.Native
+{
+ /// <summary>
+ /// RoPE scaling type. C# equivalent of llama_rope_scaling_type
+ /// </summary>
+ public enum RopeScalingType
+ : sbyte
+ {
+ LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
+
+ LLAMA_ROPE_SCALING_NONE = 0,
+
+ LLAMA_ROPE_SCALING_LINEAR = 1,
+
+ LLAMA_ROPE_SCALING_YARN = 2,
+ }
+}