diff --git a/.gitignore b/.gitignore
index e7c87968..f7b8be30 100644
--- a/.gitignore
+++ b/.gitignore
@@ -344,4 +344,5 @@ test/TensorFlowNET.Examples/mnist
site/
/LLama.Unittest/Models/*.bin
+/LLama.Unittest/Models/*.gguf
diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index c589a270..832f3fdd 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -10,7 +10,7 @@ namespace LLama.Unittest
public BasicTest()
{
- _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ _params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048
};
diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs
new file mode 100644
index 00000000..21328b41
--- /dev/null
+++ b/LLama.Unittest/Constants.cs
@@ -0,0 +1,7 @@
+namespace LLama.Unittest
+{
+ internal static class Constants
+ {
+ public static string ModelPath = "Models/llama-2-7b.q4_0.gguf";
+ }
+}
diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs
index dc2d3e3a..152ede93 100644
--- a/LLama.Unittest/GrammarTest.cs
+++ b/LLama.Unittest/GrammarTest.cs
@@ -12,7 +12,7 @@ namespace LLama.Unittest
public GrammarTest()
{
- _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ _params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048,
};
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 03b865aa..ea0e100a 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -24,7 +24,7 @@
-
+
@@ -37,7 +37,7 @@
-
+
PreserveNewest
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index c6d2dc21..e9c84eac 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -10,7 +10,7 @@ namespace LLama.Unittest
public LLamaContextTests()
{
- var @params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ var @params = new ModelParams(Constants.ModelPath)
{
ContextSize = 768,
};
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index 03487353..f94c90ba 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -5,7 +5,7 @@ namespace LLama.Unittest;
public class LLamaEmbedderTests
: IDisposable
{
- private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));
+ private readonly LLamaEmbedder _embedder = new(new ModelParams(Constants.ModelPath));
public void Dispose()
{
@@ -36,18 +36,19 @@ public class LLamaEmbedderTests
Assert.Equal(expected[i], actual[i], epsilon);
}
- [Fact]
- public void EmbedBasic()
- {
- var cat = _embedder.GetEmbeddings("cat");
+ // todo: enable this one llama2 7B gguf is available
+ //[Fact]
+ //public void EmbedBasic()
+ //{
+ // var cat = _embedder.GetEmbeddings("cat");
- Assert.NotNull(cat);
- Assert.NotEmpty(cat);
+ // Assert.NotNull(cat);
+ // Assert.NotEmpty(cat);
- // Expected value generate with llama.cpp embedding.exe
- var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
- AssertApproxStartsWith(expected, cat);
- }
+ // // Expected value generate with llama.cpp embedding.exe
+ // var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
+ // AssertApproxStartsWith(expected, cat);
+ //}
[Fact]
public void EmbedCompare()
diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs
index 317b1b85..413bda83 100644
--- a/LLama.Unittest/ModelsParamsTests.cs
+++ b/LLama.Unittest/ModelsParamsTests.cs
@@ -14,7 +14,6 @@ namespace LLama.Unittest
BatchSize = 17,
ContextSize = 42,
LoraAdapter = "adapter",
- GroupedQueryAttention = 7,
Seed = 42,
GpuLayerCount = 111
};
@@ -33,7 +32,6 @@ namespace LLama.Unittest
BatchSize = 17,
ContextSize = 42,
LoraAdapter = "adapter",
- GroupedQueryAttention = 7,
Seed = 42,
GpuLayerCount = 111
};
@@ -47,21 +45,26 @@ namespace LLama.Unittest
Assert.Equal(expected, actual);
}
- private class NewtsonsoftEncodingConverter
- : Newtonsoft.Json.JsonConverter
+
+
+ public class NewtsonsoftEncodingConverter : JsonConverter
{
- public override void WriteJson(JsonWriter writer, Encoding? value, JsonSerializer serializer)
+ public override bool CanConvert(Type objectType)
{
- writer.WriteValue((string?)value?.WebName);
+ return typeof(Encoding).IsAssignableFrom(objectType);
}
- public override Encoding? ReadJson(JsonReader reader, Type objectType, Encoding? existingValue, bool hasExistingValue, JsonSerializer serializer)
+ public override void WriteJson(JsonWriter writer, object value, JsonSerializer serializer)
{
- var name = (string?)reader.Value;
- if (name == null)
- return null;
- return Encoding.GetEncoding(name);
+ writer.WriteValue(((Encoding)value).WebName);
+ }
+
+ public override object ReadJson(JsonReader reader, Type objectType, object existingValue, JsonSerializer serializer)
+ {
+ return Encoding.GetEncoding((string)reader.Value);
}
}
+
+
}
}
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 37031da3..1748e02d 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -13,7 +13,7 @@ namespace LLama.Unittest
public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
- _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ _params = new ModelParams(Constants.ModelPath)
{
ContextSize = 60,
Seed = 1754
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 9a432858..f06757e3 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -88,16 +88,6 @@ namespace LLama.Web.Common
///
public float[] TensorSplits { get; set; }
- ///
- /// Grouped-Query Attention
- ///
- public int GroupedQueryAttention { get; set; } = 1;
-
- ///
- /// RMS Norm Epsilon
- ///
- public float RmsNormEpsilon { get; set; } = 5e-6f;
-
///
/// RoPE base frequency
///
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 4e4a2b4e..700d98e2 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -98,16 +98,6 @@ namespace LLama.Abstractions
///
float[]? TensorSplits { get; set; }
- ///
- /// Grouped-Query Attention
- ///
- int GroupedQueryAttention { get; set; }
-
- ///
- /// RMS Norm Epsilon
- ///
- float RmsNormEpsilon { get; set; }
-
///
/// RoPE base frequency
///
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index a9b573d4..e0b0c264 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -89,16 +89,6 @@ namespace LLama.Common
///
public float[]? TensorSplits { get; set; }
- ///
- /// Grouped-Query Attention
- ///
- public int GroupedQueryAttention { get; set; } = 1;
-
- ///
- /// RMS Norm Epsilon
- ///
- public float RmsNormEpsilon { get; set; } = 5e-6f;
-
///
/// RoPE base frequency
///
@@ -153,8 +143,6 @@ namespace LLama.Common
/// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// Whether to convert eos to newline during the inference.
/// Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore.
- /// Grouped-Query Attention
- /// RMS Norm Epsilon
/// RoPE base frequency.
/// RoPE frequency scaling factor
/// Use experimental mul_mat_q kernels
@@ -165,7 +153,7 @@ namespace LLama.Common
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
bool convertEosToNewLine = false, bool embeddingMode = false,
- int groupedQueryAttention = 1, float rmsNormEpsilon = 5e-6f, float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
+ float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
string encoding = "UTF-8")
{
ContextSize = contextSize;
@@ -182,8 +170,6 @@ namespace LLama.Common
BatchSize = batchSize;
ConvertEosToNewLine = convertEosToNewLine;
EmbeddingMode = embeddingMode;
- GroupedQueryAttention = groupedQueryAttention;
- RmsNormEpsilon = rmsNormEpsilon;
RopeFrequencyBase = ropeFrequencyBase;
RopeFrequencyScale = ropeFrequencyScale;
MulMatQ = mulMatQ;
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index 93b0f86e..c4cb1c62 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -39,8 +39,6 @@ namespace LLama.Extensions
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
result.low_vram = @params.LowVram;
- result.n_gqa = @params.GroupedQueryAttention;
- result.rms_norm_eps = @params.RmsNormEpsilon;
result.rope_freq_base = @params.RopeFrequencyBase;
result.rope_freq_scale = @params.RopeFrequencyScale;
result.mul_mat_q = @params.MulMatQ;
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index fbb2107c..0cb77f60 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -132,9 +132,10 @@ namespace LLama
///
public string DeTokenize(IEnumerable tokens)
{
- StringBuilder sb = new();
- foreach(var token in tokens)
- sb.Append(_ctx.TokenToString(token, _encoding));
+ var sb = new StringBuilder();
+ foreach (var token in tokens)
+ _ctx.TokenToString(token, _encoding, sb);
+
return sb.ToString();
}
@@ -365,7 +366,7 @@ namespace LLama
}
// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl();
+ var nl_token = NativeApi.llama_token_nl(_ctx);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 0bb3f669..64c17539 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -70,10 +70,6 @@ namespace LLama
///
public float[] GetEmbeddings(string text, bool addBos)
{
- if (addBos)
- {
- text = text.Insert(0, " ");
- }
var embed_inp_array = _ctx.Tokenize(text, addBos);
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 1a84ad2f..a7d53cc8 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -113,7 +114,6 @@ namespace LLama
if (_is_prompt_run)
{
// When running the first input (prompt) in inteactive mode, we should specially process it.
- text = " " + text;
_embed_inps = Context.Tokenize(text, true).ToList();
}
else
@@ -141,9 +141,10 @@ namespace LLama
{
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
{
- string last_output = "";
- foreach (var id in _last_n_tokens)
- last_output += Context.NativeHandle.TokenToString(id, Context.Encoding);
+ var last_output_builder = new StringBuilder();
+ foreach (var token in _last_n_tokens)
+ Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
+ var last_output = last_output_builder.ToString();
foreach (var antiprompt in args.Antiprompts)
{
@@ -162,7 +163,7 @@ namespace LLama
}
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
args.WaitForInput = true;
}
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 595ddb3b..38d6b443 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -7,6 +7,7 @@ using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
+using System.Text;
namespace LLama
{
@@ -25,7 +26,7 @@ namespace LLama
///
public InteractiveExecutor(LLamaContext context) : base(context)
{
- _llama_token_newline = Context.NativeHandle.Tokenize("\n", false, Context.Encoding);
+ _llama_token_newline = new [] { NativeApi.llama_token_nl(Context.NativeHandle) };
}
///
@@ -103,7 +104,6 @@ namespace LLama
if (_is_prompt_run)
{
// When running the first input (prompt) in inteactive mode, we should specially process it.
- text = " " + text;
_embed_inps = Context.Tokenize(text, true).ToList();
}
else
@@ -132,11 +132,10 @@ namespace LLama
{
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
{
- string last_output = "";
- foreach (var id in _last_n_tokens)
- {
- last_output += Context.NativeHandle.TokenToString(id, Context.Encoding);
- }
+ var last_output_builder = new StringBuilder();
+ foreach (var token in _last_n_tokens)
+ Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
+ var last_output = last_output_builder.ToString();
foreach (var antiprompt in args.Antiprompts)
{
@@ -154,7 +153,7 @@ namespace LLama
}
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
extraOutputs = new[] { " [end of text]\n" };
return true;
@@ -215,7 +214,7 @@ namespace LLama
_last_n_tokens.Enqueue(id);
- if (id == NativeApi.llama_token_eos())
+ if (id == NativeApi.llama_token_eos(Context.NativeHandle))
{
id = _llama_token_newline.First();
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
diff --git a/LLama/LLamaSharp.Runtime.targets b/LLama/LLamaSharp.Runtime.targets
index e83b11ac..df079ba3 100644
--- a/LLama/LLamaSharp.Runtime.targets
+++ b/LLama/LLamaSharp.Runtime.targets
@@ -32,11 +32,11 @@
libllama.dylib
- PreserveNewest
+ None
libllama-metal.dylib
- PreserveNewest
+ None
ggml-metal.metal
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 74340c42..200301da 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -31,16 +31,6 @@ namespace LLama.Native
///
public int n_batch;
- ///
- /// grouped-query attention (TEMP - will be moved to model hparams)
- ///
- public int n_gqa;
-
- ///
- /// rms norm epsilon (TEMP - will be moved to model hparams)
- ///
- public float rms_norm_eps;
-
///
/// number of layers to store in VRAM
///
@@ -82,8 +72,8 @@ namespace LLama.Native
/// if true, reduce VRAM usage at the cost of performance
///
public bool low_vram
- {
- get => Convert.ToBoolean(_low_vram);
+ {
+ readonly get => Convert.ToBoolean(_low_vram);
set => _low_vram = Convert.ToSByte(value);
}
private sbyte _low_vram;
@@ -92,8 +82,8 @@ namespace LLama.Native
/// if true, use experimental mul_mat_q kernels
///
public bool mul_mat_q
- {
- get => Convert.ToBoolean(_mul_mat_q);
+ {
+ readonly get => Convert.ToBoolean(_mul_mat_q);
set => _mul_mat_q = Convert.ToSByte(value);
}
private sbyte _mul_mat_q;
@@ -102,8 +92,8 @@ namespace LLama.Native
/// use fp16 for KV cache
///
public bool f16_kv
- {
- get => Convert.ToBoolean(_f16_kv);
+ {
+ readonly get => Convert.ToBoolean(_f16_kv);
set => _f16_kv = Convert.ToSByte(value);
}
private sbyte _f16_kv;
@@ -112,8 +102,8 @@ namespace LLama.Native
/// the llama_eval() call computes all logits, not just the last one
///
public bool logits_all
- {
- get => Convert.ToBoolean(_logits_all);
+ {
+ readonly get => Convert.ToBoolean(_logits_all);
set => _logits_all = Convert.ToSByte(value);
}
private sbyte _logits_all;
@@ -122,8 +112,8 @@ namespace LLama.Native
/// only load the vocabulary, no weights
///
public bool vocab_only
- {
- get => Convert.ToBoolean(_vocab_only);
+ {
+ readonly get => Convert.ToBoolean(_vocab_only);
set => _vocab_only = Convert.ToSByte(value);
}
private sbyte _vocab_only;
@@ -132,8 +122,8 @@ namespace LLama.Native
/// use mmap if possible
///
public bool use_mmap
- {
- get => Convert.ToBoolean(_use_mmap);
+ {
+ readonly get => Convert.ToBoolean(_use_mmap);
set => _use_mmap = Convert.ToSByte(value);
}
private sbyte _use_mmap;
@@ -142,8 +132,8 @@ namespace LLama.Native
/// force system to keep model in RAM
///
public bool use_mlock
- {
- get => Convert.ToBoolean(_use_mlock);
+ {
+ readonly get => Convert.ToBoolean(_use_mlock);
set => _use_mlock = Convert.ToSByte(value);
}
private sbyte _use_mlock;
@@ -152,8 +142,8 @@ namespace LLama.Native
/// embedding mode only
///
public bool embedding
- {
- get => Convert.ToBoolean(_embedding);
+ {
+ readonly get => Convert.ToBoolean(_embedding);
set => _embedding = Convert.ToSByte(value);
}
private sbyte _embedding;
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 79fdf854..0fa0fbe9 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -105,5 +105,10 @@
///
/// Benchmark@7B: 5.15GB, +0.0044 ppl
LLAMA_FTYPE_MOSTLY_Q6_K = 18,
+
+ ///
+ /// File type was not specified
+ ///
+ LLAMA_FTYPE_GUESSED = 1024
}
}
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 0a1c7c1c..e9666ea8 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -11,12 +11,18 @@ namespace LLama.Native
{
using llama_token = Int32;
+ ///
+ /// Callback from llama.cpp with log messages
+ ///
+ ///
+ ///
public delegate void LLamaLogCallback(ILLamaLogger.LogLevel level, string message);
+ ///
+ /// Direct translation of the llama.cpp API
+ ///
public unsafe partial class NativeApi
{
- public static readonly int LLAMA_MAX_DEVICES = 1;
-
static NativeApi()
{
try
@@ -43,18 +49,43 @@ namespace LLama.Native
[DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_empty_call();
+ ///
+ /// Create a LLamaContextParams with default values
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaContextParams llama_context_default_params();
+ ///
+ /// Create a LLamaModelQuantizeParams with default values
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaModelQuantizeParams llama_model_quantize_default_params();
+ ///
+ /// Check if memory mapping is supported
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_mmap_supported();
+ ///
+ /// Check if memory lockingis supported
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_mlock_supported();
+ ///
+ /// Export a static computation graph for context of 511 and batch size of 1
+ /// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
+ /// parameters here to keep things simple
+ /// IMPORTANT: do not use for anything else other than debugging and testing!
+ ///
+ ///
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname);
@@ -69,6 +100,13 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params);
+ ///
+ /// Create a new llama_context with the given model.
+ /// Return value should always be wrapped in SafeLLamaContextHandle!
+ ///
+ ///
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
@@ -81,7 +119,7 @@ namespace LLama.Native
public static extern void llama_backend_init(bool numa);
///
- /// Frees all allocated memory
+ /// Frees all allocated memory in the given llama_context
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -341,14 +379,26 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token);
+ ///
+ /// Get the "Beginning of sentence" token
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_bos();
+ public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
+ ///
+ /// Get the "End of sentence" token
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_eos();
+ public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
+ ///
+ /// Get the "new line" token
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_nl();
+ public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
///
/// Print out timing information for this context
@@ -377,7 +427,7 @@ namespace LLama.Native
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);
+ public static extern int llama_model_n_vocab(SafeLlamaModelHandle model);
///
/// Get the size of the context window for the model
@@ -385,7 +435,7 @@ namespace LLama.Native
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);
+ public static extern int llama_model_n_ctx(SafeLlamaModelHandle model);
///
/// Get the dimension of embedding vectors from this model
@@ -393,16 +443,18 @@ namespace LLama.Native
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);
+ public static extern int llama_model_n_embd(SafeLlamaModelHandle model);
///
/// Convert a single token into text
///
///
///
- ///
+ /// buffer to write string into
+ /// size of the buffer
+ /// The length writte, or if the buffer is too small a negative that indicates the length required
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
+ public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
///
/// Convert text into tokens
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 86c1c71c..aa9c9439 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -183,7 +183,7 @@ namespace LLama.Native
///
/// Convert a token into a string
///
- ///
+ /// Token to decode into a string
///
///
public string TokenToString(int token, Encoding encoding)
@@ -192,13 +192,25 @@ namespace LLama.Native
}
///
- /// Convert a token into a span of bytes that could be decoded into a string
+ /// Append a single llama token to a string builder
///
- ///
- ///
- public ReadOnlySpan TokenToSpan(int token)
+ /// Token to decode
+ ///
+ /// string builder to append the result to
+ public void TokenToString(int token, Encoding encoding, StringBuilder dest)
+ {
+ ThrowIfDisposed().TokenToString(token, encoding, dest);
+ }
+
+ ///
+ /// Convert a single llama token into bytes
+ ///
+ /// Token to decode
+ /// A span to attempt to write into. If this is too small nothing will be written
+ /// The size of this token. **nothing will be written** if this is larger than `dest`
+ public int TokenToSpan(int token, Span dest)
{
- return ThrowIfDisposed().TokenToSpan(token);
+ return ThrowIfDisposed().TokenToSpan(token, dest);
}
///
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 665ad59f..7074fddb 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -1,4 +1,5 @@
using System;
+using System.Diagnostics;
using System.Text;
using LLama.Exceptions;
@@ -28,9 +29,9 @@ namespace LLama.Native
internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
{
- VocabCount = NativeApi.llama_n_vocab_from_model(this);
- ContextSize = NativeApi.llama_n_ctx_from_model(this);
- EmbeddingSize = NativeApi.llama_n_embd_from_model(this);
+ VocabCount = NativeApi.llama_model_n_vocab(this);
+ ContextSize = NativeApi.llama_model_n_ctx(this);
+ EmbeddingSize = NativeApi.llama_model_n_embd(this);
}
///
@@ -82,17 +83,20 @@ namespace LLama.Native
#region tokenize
///
- /// Convert a single llama token into string bytes
+ /// Convert a single llama token into bytes
///
- ///
- ///
- public ReadOnlySpan TokenToSpan(int llama_token)
+ /// Token to decode
+ /// A span to attempt to write into. If this is too small nothing will be written
+ /// The size of this token. **nothing will be written** if this is larger than `dest`
+ public int TokenToSpan(int llama_token, Span dest)
{
unsafe
{
- var bytes = new ReadOnlySpan(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
- var terminator = bytes.IndexOf((byte)0);
- return bytes.Slice(0, terminator);
+ fixed (byte* destPtr = dest)
+ {
+ var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length);
+ return Math.Abs(length);
+ }
}
}
@@ -104,16 +108,54 @@ namespace LLama.Native
///
public string TokenToString(int llama_token, Encoding encoding)
{
- var span = TokenToSpan(llama_token);
+ unsafe
+ {
+ var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
+ if (length == 0)
+ return "";
+
+ Span bytes = stackalloc byte[-length];
- if (span.Length == 0)
- return "";
+ fixed (byte* bytePtr = bytes)
+ {
+ var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
+ Debug.Assert(written == bytes.Length);
+ return encoding.GetString(bytePtr, bytes.Length);
+ }
+ }
+ }
+
+ ///
+ /// Append a single llama token to a string builder
+ ///
+ /// Token to decode
+ ///
+ /// string builder to append the result to
+ public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest)
+ {
unsafe
{
- fixed (byte* ptr = &span[0])
+ var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
+ if (length == 0)
+ return;
+
+ Span bytes = stackalloc byte[-length];
+ fixed (byte* bytePtr = bytes)
{
- return encoding.GetString(ptr, span.Length);
+ // Decode into bytes
+ var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
+ Debug.Assert(written == bytes.Length);
+
+ // Decode into chars
+ var charCount = encoding.GetCharCount(bytePtr, bytes.Length);
+ Span chars = stackalloc char[charCount];
+ fixed (char* charPtr = chars)
+ encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length);
+
+ // Write it to the output
+ for (var i = 0; i < chars.Length; i++)
+ dest.Append(chars[i]);
}
}
}
diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs
index 56771579..e26bf971 100644
--- a/LLama/Native/SamplingApi.cs
+++ b/LLama/Native/SamplingApi.cs
@@ -1,8 +1,14 @@
using System;
+#pragma warning disable IDE1006 // Naming Styles
+
namespace LLama.Native
{
using llama_token = Int32;
+
+ ///
+ /// Direct translation of the llama.cpp sampling API
+ ///
public unsafe class SamplingApi
{
///
@@ -140,6 +146,13 @@ namespace LLama.Native
NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
}
+ ///
+ /// Sample with temperature.
+ /// As temperature increases, the prediction becomes diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
+ ///
+ ///
+ ///
+ ///
public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
diff --git a/LLama/OldVersion/LLamaModel.cs b/LLama/OldVersion/LLamaModel.cs
index ec528ec4..523b9553 100644
--- a/LLama/OldVersion/LLamaModel.cs
+++ b/LLama/OldVersion/LLamaModel.cs
@@ -634,7 +634,7 @@ namespace LLama.OldVersion
LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);
// Apply penalties
- float nl_logit = logits[NativeApi.llama_token_nl()];
+ float nl_logit = logits[NativeApi.llama_token_nl(_ctx)];
var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
_last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
@@ -644,7 +644,7 @@ namespace LLama.OldVersion
(ulong)last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
- logits[NativeApi.llama_token_nl()] = nl_logit;
+ logits[NativeApi.llama_token_nl(_ctx)] = nl_logit;
}
if (temp <= 0)
@@ -684,7 +684,7 @@ namespace LLama.OldVersion
}
// replace end of text token with newline token when in interactive mode
- if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct)
+ if (id == NativeApi.llama_token_eos(_ctx) && _params.interactive && !_params.instruct)
{
id = _llama_token_newline[0];
if (_params.antiprompt.Count != 0)
@@ -760,7 +760,7 @@ namespace LLama.OldVersion
break;
}
- if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos())
+ if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos(_ctx))
{
if (_params.instruct)
{
diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal
index 8d26b5ec..82e1a0c7 100644
--- a/LLama/runtimes/ggml-metal.metal
+++ b/LLama/runtimes/ggml-metal.metal
@@ -18,46 +18,11 @@ typedef struct {
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
-static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
- const int qk = QK4_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const half d = x[i].d;
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
- const int x1 = (x[i].qs[j] >> 4) - 8;
-
- y[i*qk + j + 0 ] = x0*d;
- y[i*qk + j + qk/2] = x1*d;
- }
- }
-}
-
-static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
- const int qk = QK4_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const half d = x[i].d;
- const half m = x[i].m;
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F);
- const int x1 = (x[i].qs[j] >> 4);
-
- y[i*qk + j + 0 ] = x0*d + m;
- y[i*qk + j + qk/2] = x1*d + m;
- }
- }
-}
+#define QK8_0 32
+typedef struct {
+ half d; // delta
+ int8_t qs[QK8_0]; // quants
+} block_q8_0;
kernel void kernel_add(
device const float * src0,
@@ -128,7 +93,12 @@ kernel void kernel_gelu(
device float * dst,
uint tpig[[thread_position_in_grid]]) {
float x = src0[tpig];
- dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_soft_max(
@@ -219,54 +189,6 @@ kernel void kernel_diag_mask_inf(
}
}
-kernel void kernel_get_rows_f16(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- for (int j = 0; j < ne00; j++) {
- dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
- }
-}
-
-kernel void kernel_get_rows_q4_0(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_0(
- (device const block_q4_0 *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q4_1(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_1(
- (device const block_q4_1 *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
kernel void kernel_norm(
device const void * src0,
device float * dst,
@@ -432,14 +354,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
// N_DST, so this is another explicit assumption of the implementation.
template
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
- int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
- uint2 tgpig, uint tiisg, uint sgitg) {
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
+ uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr;
- device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16]; // src1 vector cache
float sumf[nr]={0.f};
@@ -470,7 +394,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < ne01) {
- dst[r1*ne0 + first_row + row] = tot;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
}
}
}
@@ -480,13 +404,17 @@ kernel void kernel_mul_mat_q4_0_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mat_q4_1_f32(
@@ -494,13 +422,79 @@ kernel void kernel_mul_mat_q4_1_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mat_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int nr = N_DST;
+ const int nsg = N_SIMDGROUP;
+ const int nw = N_SIMDWIDTH;
+
+ const int nb = ne00/QK8_0;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float sumf[nr]={0.f};
+
+ const int ix = tiisg/2;
+ const int il = tiisg%2;
+
+ device const float * yb = y + ix * QK8_0 + 16*il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ for (int i = 0; i < 16; ++i) {
+ yl[i] = yb[i];
+ }
+
+ for (int row = 0; row < nr; row++) {
+ device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+ float sumq = 0.f;
+ for (int iq = 0; iq < 16; ++iq) {
+ sumq += qs[iq] * yl[iq];
+ }
+ sumf[row] += sumq*x[ib+row*nb].d;
+ }
+
+ yb += QK8_0 * 16;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ }
+ }
}
kernel void kernel_mul_mat_f16_f32(
@@ -554,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
}
}
-
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,
@@ -650,7 +643,25 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
- // TODO: implement
+ for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+ for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ const float cos_theta = cos(theta);
+ const float sin_theta = sin(theta);
+
+ theta *= theta_scale;
+
+ const int64_t i0 = ib*n_dims + ic/2;
+
+ device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+ }
}
}
@@ -869,354 +880,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
return r;
}
-//========================================== dequantization =============================
-
-static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
-
- const float d = x[i].d;
- const float min = x[i].dmin;
-
- device const uint8_t * q = x[i].qs;
-
-#if QK_K == 256
- int is = 0;
- float dl, ml;
- for (int n = 0; n < QK_K; n += 128) {
- int shift = 0;
- for (int j = 0; j < 4; ++j) {
-
- uint8_t sc = x[i].scales[is++];
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
-
- sc = x[i].scales[is++];
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
-
- shift += 2;
- }
- q += 32;
- }
-#else
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
- for (int l = 0; l < 16; ++l) {
- y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
- y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
- y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
- y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
- }
- y += QK_K;
-#endif
-
- }
-}
-
-static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
-#if QK_K == 256
-
- const uint16_t kmask1 = 0x0303;
- const uint16_t kmask2 = 0x0f0f;
-
- uint16_t aux[8];
- thread const int8_t * scales = (thread const int8_t*)aux;
-
- for (int i = 0; i < nb; i++) {
-
- const float d_all = (float)(x[i].d);
-
- device const uint8_t * q = x[i].qs;
- device const uint8_t * h = x[i].hmask;
- uint8_t m = 1;
-
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
- aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
- aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
- aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
- aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
- aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
- aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
- aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
- aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
-
- int is = 0;
- float dl;
- for (int n = 0; n < QK_K; n += 128) {
- int shift = 0;
- for (int j = 0; j < 4; ++j) {
-
- dl = d_all * (scales[is++] - 32);
- for (int l = 0; l < 16; ++l) {
- *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
- }
-
- dl = d_all * (scales[is++] - 32);
- for (int l = 0; l < 16; ++l) {
- *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
- }
-
- shift += 2;
- m <<= 1;
- }
- q += 32;
- }
- }
-#else
- for (int i = 0; i < nb; i++) {
-
- const float d_all = (float)(x[i].d);
-
- device const uint8_t * q = x[i].qs;
- device const uint8_t * hm = x[i].hmask;
-
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
-
- for (int l = 0; l < 8; ++l) {
- uint8_t h = hm[l];
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
- }
- y += QK_K;
- }
-#endif
-
-}
-
-static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
-
- device const uint8_t * q = x[i].qs;
-
-#if QK_K == 256
- const float d = x[i].d;
- const float min = x[i].dmin;
-
- device const uint8_t * scales = x[i].scales;
-
- int is = 0;
- for (int j = 0; j < QK_K; j += 64) {
- const uchar4 sc = get_scale_min_k4(is, scales);
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
- q += 32; is += 2;
- }
-#else
- device const uint8_t * s = x[i].scales;
- device const half2 * dh = (device const half2 *)x[i].d;
- const float2 d = (float2)dh[0];
- const float d1 = d[0] * (s[0] & 0xF);
- const float d2 = d[0] * (s[1] & 0xF);
- const float m1 = d[1] * (s[0] >> 4);
- const float m2 = d[1] * (s[1] >> 4);
- for (int l = 0; l < 32; ++l) {
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
- y[l+32] = d2 * (q[l] >> 4) - m2;
- }
- y += QK_K;
-#endif
-
- }
-}
-
-static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
-#if QK_K == 256
- for (int i = 0; i < nb; i++) {
-
- const float d = (float)(x[i].d);
- const float min = (float)(x[i].dmin);
-
- device const uint8_t * ql = x[i].qs;
- device const uint8_t * qh = x[i].qh;
-
- int is = 0;
- uint8_t u1 = 1, u2 = 2;
- for (int j = 0; j < QK_K; j += 64) {
- const uchar4 sc = get_scale_min_k4(is, x[i].scales);
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
- for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
- for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
- ql += 32; is += 2;
- u1 <<= 2; u2 <<= 2;
- }
- }
-#else
- for (int i = 0; i < nb; i++) {
-
- const float d = (float)x[i].d;
-
- device const uint8_t * ql = x[i].qs;
- device const uint8_t * qh = x[i].qh;
- device const int8_t * sc = x[i].scales;
-
- for (int l = 0; l < 8; ++l) {
- y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
- y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
- y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
- y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
- y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
- y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
- y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
- y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
- }
- y += QK_K;
- }
-#endif
-
-}
-
-static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
-
- device const uint8_t * ql = x[i].ql;
- device const uint8_t * qh = x[i].qh;
- device const int8_t * sc = x[i].scales;
-
- const float d = x[i].d;
-
-#if QK_K == 256
- for (int n = 0; n < QK_K; n += 128) {
- for (int l = 0; l < 32; ++l) {
- int is = l/16;
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
- y[l + 0] = d * sc[is + 0] * q1;
- y[l + 32] = d * sc[is + 2] * q2;
- y[l + 64] = d * sc[is + 4] * q3;
- y[l + 96] = d * sc[is + 6] * q4;
- }
- y += 128;
- ql += 64;
- qh += 32;
- sc += 8;
- }
-#else
- for (int l = 0; l < 16; ++l) {
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
- y[l+ 0] = d * sc[0] * q1;
- y[l+16] = d * sc[1] * q2;
- y[l+32] = d * sc[2] * q3;
- y[l+48] = d * sc[3] * q4;
- }
- y += 64;
-#endif
- }
-}
-
-kernel void kernel_get_rows_q2_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q2_K(
- (device const block_q2_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q3_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q3_K(
- (device const block_q3_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q4_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_K(
- (device const block_q4_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q5_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q5_K(
- (device const block_q5_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q6_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q6_K(
- (device const block_q6_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
//====================================== dot products =========================
kernel void kernel_mul_mat_q2_K_f32(
@@ -1224,21 +887,27 @@ kernel void kernel_mul_mat_q2_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@@ -1351,7 +1020,7 @@ kernel void kernel_mul_mat_q2_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = all_sum;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
@@ -1362,10 +1031,14 @@ kernel void kernel_mul_mat_q3_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- constant int64_t & ne1,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1373,11 +1046,12 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int64_t r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16];
@@ -1465,7 +1139,7 @@ kernel void kernel_mul_mat_q3_K_f32(
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
}
}
}
@@ -1475,10 +1149,14 @@ kernel void kernel_mul_mat_q3_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- constant int64_t & ne1,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1486,11 +1164,12 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int64_t r2 = tgpig.z;
const int row = 2 * r0 + sgitg;
-
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
const int im = il/8; // 0, 0, 1, 1
@@ -1529,7 +1208,7 @@ kernel void kernel_mul_mat_q3_K_f32(
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
}
}
@@ -1541,10 +1220,14 @@ kernel void kernel_mul_mat_q4_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1560,10 +1243,12 @@ kernel void kernel_mul_mat_q4_K_f32(
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16];
float yh[16];
float sumf[N_DST]={0.f}, all_sum;
@@ -1630,7 +1315,7 @@ kernel void kernel_mul_mat_q4_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = all_sum;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
@@ -1640,10 +1325,14 @@ kernel void kernel_mul_mat_q4_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1653,10 +1342,12 @@ kernel void kernel_mul_mat_q4_K_f32(
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[8];
float yh[8];
float sumf[N_DST]={0.f}, all_sum;
@@ -1712,7 +1403,7 @@ kernel void kernel_mul_mat_q4_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = all_sum;
+ dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
@@ -1723,9 +1414,14 @@ kernel void kernel_mul_mat_q5_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1733,11 +1429,12 @@ kernel void kernel_mul_mat_q5_K_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float sumf[2]={0.f};
@@ -1871,7 +1568,7 @@ kernel void kernel_mul_mat_q5_K_f32(
for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
}
}
@@ -1882,9 +1579,14 @@ kernel void kernel_mul_mat_q6_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1897,11 +1599,12 @@ kernel void kernel_mul_mat_q6_K_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int row = 2 * r0 + sgitg;
-
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float sumf = 0;
@@ -1967,6 +1670,380 @@ kernel void kernel_mul_mat_q6_K_f32(
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+ }
+}
+
+//============================= templates and their specializations =============================
+
+template
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ half4x4 temp = *(((device half4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const half d = il ? (xb->d / 16.h) : xb->d;
+ const half m = il ? ( -8.h * 16.h) : -8.h;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+ }
+}
+
+template
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const half d = il ? (xb->d / 16.h) : xb->d;
+ const half m = xb->m;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+ }
+}
+
+template
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const half d = xb->d;
+
+ for (int i=0;i<16;i++) {
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
+ }
+}
+
+template
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+ const half d = xb->d;
+ const half min = xb->dmin;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ half dl, ml;
+ uint8_t sc = xb->scales[il];
+
+#if QK_K == 256
+ q = q + 32*(il/8) + 16*(il&1);
+ il = (il/2)%4;
+#endif
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
+
+template
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+ const float d_all = (float)(xb->d);
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+ q = q + 32 * (il/8) + 16 * (il&1);
+ h = h + 16 * (il&1);
+ uint8_t m = 1 << (il/2);
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+ ((il/4)>0 ? 12 : 3);
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
+ (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+
+ il = (il/2)%4;
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+ }
+#else
+ float kcoef = il&1 ? 1.f/16.f : 1.f;
+ uint16_t kmask = il&1 ? 0xF0 : 0x0F;
+ float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ uint8_t m = 1<<(il*2);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
+ }
+#endif
+}
+
+template
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+
+#if QK_K == 256
+ const float d = (float)(xb->d);
+ const float min = (float)(xb->dmin);
+ short is = (il/4) * 2;
+ q = q + (il/4) * 32 + 16 * (il&1);
+ il = il%4;
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
+#else
+ q = q + 16 * (il&1);
+ device const uint8_t * s = xb->scales;
+ device const half2 * dh = (device const half2 *)xb->d;
+ const float2 d = (float2)dh[0];
+ const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+#endif
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+ device const uint8_t * qh = xb->qh;
+
+#if QK_K == 256
+ const float d = (float)(xb->d);
+ const float min = (float)(xb->dmin);
+ short is = (il/4) * 2;
+ q = q + 32 * (il/4) + 16 * (il&1);
+ qh = qh + 16 * (il&1);
+ uint8_t ul = 1 << (il/2);
+ il = il%4;
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
+
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+ }
+#else
+ q = q + 16 * (il&1);
+ device const int8_t * s = xb->scales;
+ const float dl = xb->d * s[il];
+ uint8_t m = 1<<(il*2);
+ const float coef = il<2 ? 1.f : 1.f/16.f;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
+ }
+#endif
+}
+
+template
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+ const float d_all = (float)(xb->d);
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+ qh = qh + 32*(il/8) + 16*(il&1);
+ float sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2)%4;
+#else
+ ql = ql + 16 * (il&1);
+ float sc = scales[il];
+#endif
+ for (int i = 0; i < 16; ++i) {
+ uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const float coef = il>1 ? 1.f/16.f : 1.f;
+ float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
+ ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
+ reg[i/4][i%4] = d_all * sc * q * coef;
+ }
+}
+
+template
+kernel void kernel_get_rows(
+ device const void * src0,
+ device const int * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb1,
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tptg[[threads_per_threadgroup]]) {
+ const int i = tgpig;
+ const int r = ((device int32_t *) src1)[i];
+
+ for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+ float4x4 temp;
+ dequantize_func(
+ ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+ }
+}
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template
+kernel void kernel_mul_mm(device const uchar * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = ((threadgroup half *)shared_memory);
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+ const uint r0 = tgpig.y;
+ const uint r1 = tgpig.x;
+ const uint im = tgpig.z;
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_half8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 c_res[8];
+ for (int i = 0; i < 8; i++){
+ c_res[i] = make_filled_simdgroup_matrix(0.f);
+ }
+
+ short il = (tiitg % THREAD_PER_ROW);
+ uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ //load data and store to threadgroup memory
+ half4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ #pragma unroll(16)
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
+ = *((device float2x4 *)y);
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ //load matrices from threadgroup memory and conduct outer products
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+ #pragma unroll(4)
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ #pragma unroll(4)
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ #pragma unroll(2)
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+ #pragma unroll(8)
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ }
+ }
+ }
+
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+ device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+ }
+ } else {
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+ if (sgitg==0) {
+ for (int i = 0; i < n_rows; i++) {
+ for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ }
+ }
+ }
+ }
+}
+
+#if QK_K == 256
+#define QK_NL 16
+#else
+#define QK_NL 4
+#endif
+
+typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
+ constant uint64_t &, constant uint64_t &, uint, uint, uint);
+
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows;
+
+typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
+ constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
+ constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll
index 81af173b..6ed31810 100644
Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ
diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so
index 75b884dd..81733cdd 100644
Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ
diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll
index e6ff0a30..f1a9fbdc 100644
Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ
diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so
index 6d20557b..482fe2f2 100644
Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ
diff --git a/LLama/runtimes/libllama-metal.dylib b/LLama/runtimes/libllama-metal.dylib
old mode 100644
new mode 100755
index 7cd1f4ab..e9c2ee28
Binary files a/LLama/runtimes/libllama-metal.dylib and b/LLama/runtimes/libllama-metal.dylib differ
diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll
index 8432f664..a5f774f8 100644
Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ
diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib
old mode 100644
new mode 100755
index e4d0f1c7..53318c38
Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ
diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so
index 1d7226a6..e52d6bda 100644
Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ
diff --git a/docs/ContributingGuide.md b/docs/ContributingGuide.md
index c7f28b7c..1f3b3d47 100644
--- a/docs/ContributingGuide.md
+++ b/docs/ContributingGuide.md
@@ -33,11 +33,11 @@ When adding the feature, please take care of the namespace and the naming conven
## Find the problem and fix the BUG
-If the issue is related to the LLM internal behaviors, such as endless generating the response, the best way to find the problem is to do comparison test between llama.cpp and LLamaSharp.
+If the issue is related to the LLM internal behaviour, such as endless generating the response, the best way to find the problem is to do comparison test between llama.cpp and LLamaSharp.
You could use exactly the same prompt, the same model and the same parameters to run the inference in llama.cpp and LLamaSharp respectively to see if it's really a problem caused by the implementation in LLamaSharp.
-If the experiment showed that it worked well in llama.cpp but didn't in LLamaSharp, a the search for the problem could be started. While the reason of the problem could be various, the best way I think is to add log-print in the code of llama.cpp and use it in LLamaSharp after compilation. Thus, when running LLamaSharp, you could see what happened in the native library.
+If the experiment showed that it worked well in llama.cpp but didn't in LLamaSharp, a search for the problem could be started. While the reason of the problem could be various, the best way I think is to add log-print in the code of llama.cpp and use it in LLamaSharp after compilation. Thus, when running LLamaSharp, you could see what happened in the native library.
After finding out the reason, a painful but happy process comes. When working on the BUG fix, there's only one rule to follow, that is keeping the examples working well. If the modification fixed the BUG but impact on other functions, it would not be a good fix.
diff --git a/docs/GetStarted.md b/docs/GetStarted.md
index db46f2a3..b9248b57 100644
--- a/docs/GetStarted.md
+++ b/docs/GetStarted.md
@@ -54,8 +54,16 @@ using LLama;
string modelPath = "" // change it to your own model path
var prompt = "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\r\n\r\nUser: Hello, Bob.\r\nBob: Hello. How may I help you today?\r\nUser: Please tell me the largest city in Europe.\r\nBob: Sure. The largest city in Europe is Moscow, the capital of Russia.\r\nUser:"; // use the "chat-with-bob" prompt here.
+// Load model
+var parameters = new ModelParams(modelPath)
+{
+ ContextSize = 1024
+};
+using var model = LLamaWeights.LoadFromFile(parameters);
+
// Initialize a chat session
-var ex = new InteractiveExecutor(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
+using var context = model.CreateContext(parameters);
+var ex = new InteractiveExecutor(context);
ChatSession session = new ChatSession(ex);
// show the prompt
diff --git a/docs/Tricks.md b/docs/Tricks.md
index a75d6c21..4b72f440 100644
--- a/docs/Tricks.md
+++ b/docs/Tricks.md
@@ -1,11 +1,11 @@
# Tricks for FAQ
-Sometimes, your application with LLM and LLamaSharp may have strange behaviors. Before opening an issue to report the BUG, the following tricks may worth a try.
+Sometimes, your application with LLM and LLamaSharp may have strange behaviours. Before opening an issue to report the BUG, the following tricks may worth a try.
## Carefully set the anti-prompts
-Anti-prompt can also be called as "Stop-keyword", which decides when to stop the response generation. Under interactive mode, the maximum tokens count is always not set, which makes the LLM generates responses infinitively. Therefore, setting anti-prompt correctly helps a lot to avoid the strange behaviors. For example, the prompt file `chat-with-bob.txt` has the following content:
+Anti-prompt can also be called as "Stop-keyword", which decides when to stop the response generation. Under interactive mode, the maximum tokens count is always not set, which makes the LLM generates responses infinitively. Therefore, setting anti-prompt correctly helps a lot to avoid the strange behaviours. For example, the prompt file `chat-with-bob.txt` has the following content:
```
Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
@@ -19,7 +19,7 @@ User:
Therefore, the anti-prompt should be set as "User:". If the last line of the prompt is removed, LLM will automatically generate a question (user) and a response (bob) for one time when running the chat session. Therefore, the antiprompt is suggested to be appended to the prompt when starting a chat session.
-What if an extra line is appended? The string "User:" in the prompt will be followed with a char "\n". Thus when running the model, the automatic generation of a pair of question and response may appear because the anti-prompt is "User:" but the last token is "User:\n". As for whether it will appear, it's an undefined behavior, which depends on the implementation inside the `LLamaExecutor`. Anyway, since it may leads to unexpected behaviors, it's recommended to trim your prompt or carefully keep consistent with your anti-prompt.
+What if an extra line is appended? The string "User:" in the prompt will be followed with a char "\n". Thus when running the model, the automatic generation of a pair of question and response may appear because the anti-prompt is "User:" but the last token is "User:\n". As for whether it will appear, it's an undefined behaviour, which depends on the implementation inside the `LLamaExecutor`. Anyway, since it may leads to unexpected behaviors, it's recommended to trim your prompt or carefully keep consistent with your anti-prompt.
## Pay attention to the length of prompt
@@ -37,7 +37,7 @@ If your chat bot has bad performance, trying different executor will possibly ma
## Choose models weight depending on you task
-The differences between modes may lead to much different behaviors under the same task. For example, if you're building a chat bot with non-English, a fine-tuned model specially for the language you want to use will have huge effect on the performance.
+The differences between modes may lead to much different behaviours under the same task. For example, if you're building a chat bot with non-English, a fine-tuned model specially for the language you want to use will have huge effect on the performance.
## Set the layer count you want to offload to GPU