Browse Source

Exposed YaRN scaling parameters in IContextParams

tags/v0.8.0
Martin Evans 2 years ago
parent
commit
04ee64a6be
6 changed files with 102 additions and 6 deletions
  1. +19
    -0
      LLama.Web/Common/ModelOptions.cs
  2. +31
    -0
      LLama/Abstractions/IContextParams.cs
  3. +24
    -4
      LLama/Common/ModelParams.cs
  4. +9
    -0
      LLama/Extensions/IContextParamsExtensions.cs
  5. +2
    -2
      LLama/Native/LLamaContextParams.cs
  6. +17
    -0
      LLama/Native/RopeScalingType.cs

+ 19
- 0
LLama.Web/Common/ModelOptions.cs View File

@@ -1,5 +1,6 @@
using System.Text;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Web.Common
{
@@ -118,6 +119,24 @@ namespace LLama.Web.Common
/// </summary>
public float? RopeFrequencyScale { get; set; }

/// <inheritdoc />
public float? YarnExtrapolationFactor { get; set; }

/// <inheritdoc />
public float? YarnAttentionFactor { get; set; }

/// <inheritdoc />
public float? YarnBetaFast { get; set; }

/// <inheritdoc />
public float? YarnBetaSlow { get; set; }

/// <inheritdoc />
public uint? YarnOriginalContext { get; set; }

/// <inheritdoc />
public RopeScalingType? YarnScalingType { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
/// </summary>


+ 31
- 0
LLama/Abstractions/IContextParams.cs View File

@@ -1,4 +1,5 @@
using System.Text;
using LLama.Native;

namespace LLama.Abstractions;

@@ -67,4 +68,34 @@ public interface IContextParams
/// Number of threads to use for batch processing (null = autodetect) (n_threads)
/// </summary>
uint? BatchThreads { get; set; }

/// <summary>
/// YaRN extrapolation mix factor
/// </summary>
float? YarnExtrapolationFactor { get; set; }

/// <summary>
/// YaRN magnitude scaling factor
/// </summary>
float? YarnAttentionFactor { get; set; }

/// <summary>
/// YaRN low correction dim
/// </summary>
float? YarnBetaFast { get; set; }

/// <summary>
/// YaRN high correction dim
/// </summary>
float? YarnBetaSlow { get; set; }

/// <summary>
/// YaRN original context length
/// </summary>
uint? YarnOriginalContext { get; set; }

/// <summary>
/// YaRN scaling method to use.
/// </summary>
RopeScalingType? YarnScalingType { get; set; }
}

+ 24
- 4
LLama/Common/ModelParams.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;

namespace LLama.Common
{
@@ -70,6 +71,7 @@ namespace LLama.Common
/// </summary>
public uint? BatchThreads { get; set; }


/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
@@ -98,10 +100,28 @@ namespace LLama.Common
/// </summary>
public float? RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
/// </summary>
public bool MulMatQ { get; set; }
/// <inheritdoc />
public float? YarnExtrapolationFactor { get; set; }

/// <inheritdoc />
public float? YarnAttentionFactor { get; set; }

/// <inheritdoc />
public float? YarnBetaFast { get; set; }

/// <inheritdoc />
public float? YarnBetaSlow { get; set; }

/// <inheritdoc />
public uint? YarnOriginalContext { get; set; }

/// <inheritdoc />
public RopeScalingType? YarnScalingType { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
/// </summary>
public bool MulMatQ { get; set; }

/// <summary>
/// Load vocab only (no weights)


+ 9
- 0
LLama/Extensions/IContextParamsExtensions.cs View File

@@ -29,6 +29,15 @@ namespace LLama.Extensions
result.embedding = @params.EmbeddingMode;
result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;

// Default YaRN values copied from here: https://github.com/ggerganov/llama.cpp/blob/381efbf480959bb6d1e247a8b0c2328f22e350f8/common/common.h#L67
result.yarn_ext_factor = @params.YarnExtrapolationFactor ?? -1f;
result.yarn_attn_factor = @params.YarnAttentionFactor ?? 1f;
result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;

result.mul_mat_q = @params.MulMatQ;

result.n_threads = Threads(@params.Threads);


+ 2
- 2
LLama/Native/LLamaContextParams.cs View File

@@ -44,13 +44,13 @@ namespace LLama.Native
/// <summary>
/// RoPE scaling type, from `enum llama_rope_scaling_type`
/// </summary>
public sbyte rope_scaling_type;
public RopeScalingType rope_scaling_type;

/// <summary>
/// RoPE base frequency, 0 = from model
/// </summary>
public float rope_freq_base;
public float rope_freq_base;
/// <summary>
/// RoPE frequency scaling factor, 0 = from model
/// </summary>


+ 17
- 0
LLama/Native/RopeScalingType.cs View File

@@ -0,0 +1,17 @@
namespace LLama.Native
{
    /// <summary>
    /// RoPE scaling type. C# equivalent of llama_rope_scaling_type
    /// </summary>
    /// <remarks>
    /// Backed by <see langword="sbyte"/> so it can be marshalled directly into the
    /// native <c>rope_scaling_type</c> field of the llama.cpp context params struct.
    /// Values must stay in sync with the native enum — do not renumber.
    /// </remarks>
    public enum RopeScalingType
        : sbyte
    {
        /// <summary>
        /// No scaling type specified; the native library chooses the behaviour
        /// (presumably falling back to the model's own metadata — confirm against llama.cpp).
        /// </summary>
        LLAMA_ROPE_SCALING_UNSPECIFIED = -1,

        /// <summary>
        /// Do not apply RoPE scaling.
        /// </summary>
        LLAMA_ROPE_SCALING_NONE = 0,

        /// <summary>
        /// Linear RoPE scaling (position interpolation).
        /// </summary>
        LLAMA_ROPE_SCALING_LINEAR = 1,

        /// <summary>
        /// YaRN RoPE scaling.
        /// </summary>
        LLAMA_ROPE_SCALING_YARN = 2,
    }
}

Loading…
Cancel
Save