diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 4f8f97e6..3d88d6b3 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -1,15 +1,115 @@
-using LLama.Common;
+using LLama.Abstractions;
namespace LLama.Web.Common
{
- public class ModelOptions : ModelParams
+ public class ModelOptions : IModelParams
{
- public ModelOptions() : base("", 512, 20, 1337, true, true, false, false, "", "", -1, 512, false, false)
- {
- }
-
+
public string Name { get; set; }
public int MaxInstances { get; set; }
- }
+
+ /// <summary>
+ /// Model context size (n_ctx)
+ /// </summary>
+ public int ContextSize { get; set; } = 512;
+ /// <summary>
+ /// The GPU that is used for scratch and small tensors
+ /// </summary>
+ public int MainGpu { get; set; } = 0;
+ /// <summary>
+ /// If true, reduce VRAM usage at the cost of performance
+ /// </summary>
+ public bool LowVram { get; set; } = false;
+ /// <summary>
+ /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
+ /// </summary>
+ public int GpuLayerCount { get; set; } = 20;
+ /// <summary>
+ /// Seed for the random number generator (seed)
+ /// </summary>
+ public int Seed { get; set; } = 1686349486;
+ /// <summary>
+ /// Use f16 instead of f32 for the KV cache (memory_f16)
+ /// </summary>
+ public bool UseFp16Memory { get; set; } = true;
+ /// <summary>
+ /// Use mmap for faster loads (use_mmap)
+ /// </summary>
+ public bool UseMemorymap { get; set; } = true;
+ /// <summary>
+ /// Use mlock to keep model in memory (use_mlock)
+ /// </summary>
+ public bool UseMemoryLock { get; set; } = false;
+ /// <summary>
+ /// Compute perplexity over the prompt (perplexity)
+ /// </summary>
+ public bool Perplexity { get; set; } = false;
+ /// <summary>
+ /// Model path (model)
+ /// </summary>
+ public string ModelPath { get; set; }
+ /// <summary>
+ /// Model alias
+ /// </summary>
+ public string ModelAlias { get; set; } = "unknown";
+ /// <summary>
+ /// LoRA adapter path (lora_adapter)
+ /// </summary>
+ public string LoraAdapter { get; set; } = string.Empty;
+ /// <summary>
+ /// Base model path for the LoRA adapter (lora_base)
+ /// </summary>
+ public string LoraBase { get; set; } = string.Empty;
+ /// <summary>
+ /// Number of threads (-1 = autodetect) (n_threads)
+ /// </summary>
+ public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
+ /// <summary>
+ /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+ /// </summary>
+ public int BatchSize { get; set; } = 512;
+
+ /// <summary>
+ /// Whether to convert EOS to newline during inference.
+ /// </summary>
+ public bool ConvertEosToNewLine { get; set; } = false;
+
+ /// <summary>
+ /// Whether to use embedding mode. (embedding) Note that if this is set to true,
+ /// the LLamaModel won't produce text responses anymore.
+ /// </summary>
+ public bool EmbeddingMode { get; set; } = false;
+
+ /// <summary>
+ /// How split tensors should be distributed across GPUs
+ /// </summary>
+ public nint TensorSplits { get; set; }
+
+ /// <summary>
+ /// Grouped-Query Attention (n_gqa)
+ /// </summary>
+ public int GroupedQueryAttention { get; set; } = 1;
+
+ /// <summary>
+ /// RMS Norm Epsilon (rms_norm_eps)
+ /// </summary>
+ public float RmsNormEpsilon { get; set; } = 5e-6f;
+
+ /// <summary>
+ /// RoPE base frequency (rope_freq_base)
+ /// </summary>
+ public float RopeFrequencyBase { get; set; } = 10000.0f;
+
+ /// <summary>
+ /// RoPE frequency scaling factor (rope_freq_scale)
+ /// </summary>
+ public float RopeFrequencyScale { get; set; } = 1.0f;
+
+ /// <summary>
+ /// Use experimental mul_mat_q kernels (mul_mat_q)
+ /// </summary>
+ public bool MulMatQ { get; set; }
+
+ }
}
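
Note: with ModelOptions implementing IModelParams instead of inheriting ModelParams, the web layer no longer has to feed dummy arguments through a base constructor, and an options object bound from configuration can be handed straight to the model. A minimal sketch of the intended usage (the configuration wiring and the "LLamaSharp" section name are illustrative, not part of this change):

    using LLama;
    using LLama.Web.Common;
    using Microsoft.Extensions.Configuration;

    var config = new ConfigurationBuilder()
        .AddJsonFile("appsettings.json") // hypothetical settings file
        .Build();

    // Bind the web-layer options; no base-constructor arguments needed anymore.
    var options = config.GetSection("LLamaSharp:Models:0").Get<ModelOptions>();

    // ModelOptions is an IModelParams, so it can be passed to LLamaModel directly.
    using var model = new LLamaModel(options);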
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
new file mode 100644
index 00000000..40c5432b
--- /dev/null
+++ b/LLama/Abstractions/IModelParams.cs
@@ -0,0 +1,123 @@
+using System;
+
+namespace LLama.Abstractions
+{
+ public interface IModelParams
+ {
+ /// <summary>
+ /// Model context size (n_ctx)
+ /// </summary>
+ int ContextSize { get; set; }
+
+ /// <summary>
+ /// The GPU that is used for scratch and small tensors
+ /// </summary>
+ int MainGpu { get; set; }
+
+ /// <summary>
+ /// If true, reduce VRAM usage at the cost of performance
+ /// </summary>
+ bool LowVram { get; set; }
+
+ /// <summary>
+ /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
+ /// </summary>
+ int GpuLayerCount { get; set; }
+
+ /// <summary>
+ /// Seed for the random number generator (seed)
+ /// </summary>
+ int Seed { get; set; }
+
+ /// <summary>
+ /// Use f16 instead of f32 for the KV cache (memory_f16)
+ /// </summary>
+ bool UseFp16Memory { get; set; }
+
+ /// <summary>
+ /// Use mmap for faster loads (use_mmap)
+ /// </summary>
+ bool UseMemorymap { get; set; }
+
+ /// <summary>
+ /// Use mlock to keep model in memory (use_mlock)
+ /// </summary>
+ bool UseMemoryLock { get; set; }
+
+ /// <summary>
+ /// Compute perplexity over the prompt (perplexity)
+ /// </summary>
+ bool Perplexity { get; set; }
+
+ /// <summary>
+ /// Model path (model)
+ /// </summary>
+ string ModelPath { get; set; }
+
+ /// <summary>
+ /// Model alias
+ /// </summary>
+ string ModelAlias { get; set; }
+
+ /// <summary>
+ /// LoRA adapter path (lora_adapter)
+ /// </summary>
+ string LoraAdapter { get; set; }
+
+ /// <summary>
+ /// Base model path for the LoRA adapter (lora_base)
+ /// </summary>
+ string LoraBase { get; set; }
+
+ /// <summary>
+ /// Number of threads (-1 = autodetect) (n_threads)
+ /// </summary>
+ int Threads { get; set; }
+
+ /// <summary>
+ /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+ /// </summary>
+ int BatchSize { get; set; }
+
+ /// <summary>
+ /// Whether to convert EOS to newline during inference.
+ /// </summary>
+ bool ConvertEosToNewLine { get; set; }
+
+ /// <summary>
+ /// Whether to use embedding mode. (embedding) Note that if this is set to true,
+ /// the LLamaModel won't produce text responses anymore.
+ /// </summary>
+ bool EmbeddingMode { get; set; }
+
+ /// <summary>
+ /// How split tensors should be distributed across GPUs
+ /// </summary>
+ nint TensorSplits { get; set; }
+
+ /// <summary>
+ /// Grouped-Query Attention (n_gqa)
+ /// </summary>
+ int GroupedQueryAttention { get; set; }
+
+ /// <summary>
+ /// RMS Norm Epsilon (rms_norm_eps)
+ /// </summary>
+ float RmsNormEpsilon { get; set; }
+
+ /// <summary>
+ /// RoPE base frequency (rope_freq_base)
+ /// </summary>
+ float RopeFrequencyBase { get; set; }
+
+ /// <summary>
+ /// RoPE frequency scaling factor (rope_freq_scale)
+ /// </summary>
+ float RopeFrequencyScale { get; set; }
+
+ /// <summary>
+ /// Use experimental mul_mat_q kernels (mul_mat_q)
+ /// </summary>
+ bool MulMatQ { get; set; }
+ }
+}
\ No newline at end of file
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 4f72eff3..72c77937 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -1,4 +1,5 @@
-using System;
+using LLama.Abstractions;
+using System;
using System.Collections.Generic;
using System.Text;
@@ -7,7 +8,7 @@ namespace LLama.Common
/// <summary>
/// The parameters for initializing a LLama model.
/// </summary>
- public class ModelParams
+ public class ModelParams : IModelParams
{
///
/// Model context size (n_ctx)
@@ -86,28 +87,59 @@ namespace LLama.Common
///
public nint TensorSplits { get; set; }
- /// <summary>
- ///
- /// </summary>
- /// <param name="modelPath">The model path.</param>
- /// <param name="contextSize">Model context size (n_ctx)</param>
- /// <param name="gpuLayerCount">Number of layers to run in VRAM / GPU memory (n_gpu_layers)</param>
- /// <param name="seed">Seed for the random number generator (seed)</param>
- /// <param name="useFp16Memory">Whether to use f16 instead of f32 for memory kv (memory_f16)</param>
- /// <param name="useMemorymap">Whether to use mmap for faster loads (use_mmap)</param>
- /// <param name="useMemoryLock">Whether to use mlock to keep model in memory (use_mlock)</param>
- /// <param name="perplexity">Thether to compute perplexity over the prompt (perplexity)</param>
- /// <param name="loraAdapter">Lora adapter path (lora_adapter)</param>
- /// <param name="loraBase">Base model path for the lora adapter (lora_base)</param>
- /// <param name="threads">Number of threads (-1 = autodetect) (n_threads)</param>
- /// <param name="batchSize">Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)</param>
- /// <param name="convertEosToNewLine">Whether to convert eos to newline during the inference.</param>
- /// <param name="embeddingMode">Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore.</param>
- public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
+ /// <summary>
+ /// Grouped-Query Attention (n_gqa)
+ /// </summary>
+ public int GroupedQueryAttention { get; set; } = 1;
+
+ /// <summary>
+ /// RMS Norm Epsilon (rms_norm_eps)
+ /// </summary>
+ public float RmsNormEpsilon { get; set; } = 5e-6f;
+
+ /// <summary>
+ /// RoPE base frequency (rope_freq_base)
+ /// </summary>
+ public float RopeFrequencyBase { get; set; } = 10000.0f;
+
+ /// <summary>
+ /// RoPE frequency scaling factor (rope_freq_scale)
+ /// </summary>
+ public float RopeFrequencyScale { get; set; } = 1.0f;
+
+ /// <summary>
+ /// Use experimental mul_mat_q kernels (mul_mat_q)
+ /// </summary>
+ public bool MulMatQ { get; set; }
+
+ /// <summary>
+ /// Initialize a new set of model parameters.
+ /// </summary>
+ /// <param name="modelPath">The model path.</param>
+ /// <param name="contextSize">Model context size (n_ctx)</param>
+ /// <param name="gpuLayerCount">Number of layers to run in VRAM / GPU memory (n_gpu_layers)</param>
+ /// <param name="seed">Seed for the random number generator (seed)</param>
+ /// <param name="useFp16Memory">Whether to use f16 instead of f32 for the KV cache (memory_f16)</param>
+ /// <param name="useMemorymap">Whether to use mmap for faster loads (use_mmap)</param>
+ /// <param name="useMemoryLock">Whether to use mlock to keep model in memory (use_mlock)</param>
+ /// <param name="perplexity">Whether to compute perplexity over the prompt (perplexity)</param>
+ /// <param name="loraAdapter">LoRA adapter path (lora_adapter)</param>
+ /// <param name="loraBase">Base model path for the LoRA adapter (lora_base)</param>
+ /// <param name="threads">Number of threads (-1 = autodetect) (n_threads)</param>
+ /// <param name="batchSize">Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)</param>
+ /// <param name="convertEosToNewLine">Whether to convert EOS to newline during inference.</param>
+ /// <param name="embeddingMode">Whether to use embedding mode. (embedding) Note that if this is set to true, the LLamaModel won't produce text responses anymore.</param>
+ /// <param name="gqa">Grouped-Query Attention (n_gqa)</param>
+ /// <param name="rmsNormEps">RMS Norm Epsilon (rms_norm_eps)</param>
+ /// <param name="rope_freq_base">RoPE base frequency (rope_freq_base)</param>
+ /// <param name="rope_freq_scale">RoPE frequency scaling factor (rope_freq_scale)</param>
+ /// <param name="mulMatQ">Whether to use experimental mul_mat_q kernels (mul_mat_q)</param>
+ public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
int seed = 1337, bool useFp16Memory = true,
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
- bool convertEosToNewLine = false, bool embeddingMode = false)
+ bool convertEosToNewLine = false, bool embeddingMode = false,
+ int gqa = 1, float rmsNormEps = 5e-6f, float rope_freq_base = 10000.0f, float rope_freq_scale = 1f, bool mulMatQ = false)
{
ContextSize = contextSize;
GpuLayerCount = gpuLayerCount;
@@ -123,6 +155,11 @@ namespace LLama.Common
BatchSize = batchSize;
ConvertEosToNewLine = convertEosToNewLine;
EmbeddingMode = embeddingMode;
- }
+ GroupedQueryAttention = gqa;
+ RmsNormEpsilon = rmsNormEps;
+ RopeFrequencyBase = rope_freq_base;
+ RopeFrequencyScale = rope_freq_scale;
+ MulMatQ = mulMatQ;
+ }
}
}
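
Note: the extended constructor appends the new llama.cpp parameters after the existing ones with matching defaults, so current call sites compile unchanged and the new knobs are opt-in. A sketch (model path illustrative):

    using LLama.Common;

    // Earlier parameters keep their old defaults (contextSize: 512, seed: 1337, ...);
    // only the new llama.cpp knobs are set here, via named arguments.
    var @params = new ModelParams("models/llama-2-70b.ggml",
        gqa: 8,                   // n_gqa; llama.cpp expects 8 for the 70B LLaMA-2 models
        rmsNormEps: 5e-6f,        // rms_norm_eps
        rope_freq_base: 10000.0f, // rope_freq_base
        rope_freq_scale: 1.0f,    // rope_freq_scale
        mulMatQ: false);          // experimental mul_mat_q kernels off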
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 4bbb61d2..24b6ee80 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -4,7 +4,7 @@ using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using System.Linq;
-using LLama.Common;
+using LLama.Abstractions;
namespace LLama
{
@@ -28,7 +28,7 @@ namespace LLama
///
/// </summary>
/// <param name="params"></param>
- public LLamaEmbedder(ModelParams @params)
+ public LLamaEmbedder(IModelParams @params)
{
@params.EmbeddingMode = true;
_ctx = Utils.InitLLamaContextFromModelParams(@params);
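
Note: because the embedder now takes any IModelParams and sets EmbeddingMode itself, a plain ModelParams (or a custom IModelParams implementation) works as-is. A rough usage sketch, assuming the embedder's existing GetEmbeddings API (model path illustrative):

    using LLama;
    using LLama.Common;

    // No need to set EmbeddingMode: the constructor flips it to true
    // before creating the native context.
    using var embedder = new LLamaEmbedder(new ModelParams("models/llama-2-7b.ggml"));
    float[] embedding = embedder.GetEmbeddings("Hello, world.");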
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index d82e2f43..2bd31199 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -10,6 +10,7 @@ using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using Microsoft.Win32.SafeHandles;
+using LLama.Abstractions;
namespace LLama
{
@@ -30,7 +31,7 @@ namespace LLama
/// <summary>
/// The model params set for this model.
/// </summary>
- public ModelParams Params { get; set; }
+ public IModelParams Params { get; set; }
/// <summary>
/// The native handle, which is used to be passed to the native APIs. Please avoid using it
/// unless you know what is the usage of the Native API.
@@ -47,7 +48,7 @@ namespace LLama
/// <param name="Params">Model params.</param>
/// <param name="encoding">Encoding to deal with text input.</param>
/// <param name="logger">The logger.</param>
- public LLamaModel(ModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
+ public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
{
_logger = logger;
this.Params = Params;
diff --git a/LLama/ResettableLLamaModel.cs b/LLama/ResettableLLamaModel.cs
index f2862dc7..d9b4e822 100644
--- a/LLama/ResettableLLamaModel.cs
+++ b/LLama/ResettableLLamaModel.cs
@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;
@@ -19,7 +19,7 @@ namespace LLama
/// </summary>
/// <param name="Params"></param>
/// <param name="encoding"></param>
- public ResettableLLamaModel(ModelParams Params, string encoding = "UTF-8") : base(Params, encoding)
+ public ResettableLLamaModel(IModelParams Params, string encoding = "UTF-8") : base(Params, encoding)
{
OriginalState = GetState();
}
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index e99e6b29..0371718a 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using System;
@@ -13,7 +13,7 @@ namespace LLama
using llama_token = Int32;
internal static class Utils
{
- public static SafeLLamaContextHandle InitLLamaContextFromModelParams(ModelParams @params)
+ public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
{
var lparams = NativeApi.llama_context_default_params();
@@ -28,6 +28,11 @@ namespace LLama
lparams.logits_all = @params.Perplexity;
lparams.embedding = @params.EmbeddingMode;
lparams.low_vram = @params.LowVram;
+ lparams.n_gqa = @params.GroupedQueryAttention;
+ lparams.rms_norm_eps = @params.RmsNormEpsilon;
+ lparams.rope_freq_base = @params.RopeFrequencyBase;
+ lparams.rope_freq_scale = @params.RopeFrequencyScale;
+ lparams.mul_mat_q = @params.MulMatQ;
/*
if (@params.TensorSplits.Length != 1)