New Binaries, Improved Sampling API, Batch Decoding Prototype (tags/v0.7.0)
| @@ -8,6 +8,7 @@ | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| <!-- Set IncludeBuiltInRuntimes to false to include your own runtime libraries and not link the defaults --> | |||
| <IncludeBuiltInRuntimes>true</IncludeBuiltInRuntimes> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| @@ -0,0 +1,177 @@ | |||
| using System.Diagnostics; | |||
| using System.Security.Cryptography; | |||
| using System.Text; | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| namespace LLama.Examples.NewVersion; | |||
| /// <summary> | |||
| /// This demonstrates generating multiple replies to the same prompt, with a shared cache | |||
| /// </summary> | |||
| /// <remarks>Note that this currently uses the low-level API directly; future work will provide a safer C# wrapper over this!</remarks> | |||
| public class BatchedDecoding | |||
| { | |||
| private const int n_parallel = 8; | |||
| private const int n_len = 32; | |||
| private const int top_k = 80; | |||
| private const float top_p = 0.8f; | |||
| private const float temp = 0.5f; | |||
| public static async Task Run() | |||
| { | |||
| Console.Write("Please input your model path: "); | |||
| var modelPath = Console.ReadLine(); | |||
| Console.WriteLine("Prompt (leave blank to select automatically):"); | |||
| var prompt = Console.ReadLine(); | |||
| if (string.IsNullOrWhiteSpace(prompt)) | |||
| prompt = "Not many people know that"; | |||
| // Load model | |||
| var parameters = new ModelParams(modelPath); | |||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||
| // Tokenize prompt | |||
| var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8); | |||
| var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel; | |||
| // Create a context | |||
| parameters.ContextSize = (uint)model.ContextSize; | |||
| parameters.BatchSize = (uint)Math.Max(n_len, n_parallel); | |||
| using var context = model.CreateContext(parameters); | |||
| var n_ctx = context.ContextSize; | |||
| // make sure the KV cache is big enough to hold all the prompt and generated tokens | |||
| if (n_kv_req > n_ctx) | |||
| { | |||
| await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n"); | |||
| await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n"); | |||
| return; | |||
| } | |||
| using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1); | |||
| // evaluate the initial prompt | |||
| for (var i = 0; i < prompt_tokens.Length; i++) | |||
| batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false); | |||
| Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length); | |||
| // llama_decode will output logits only for the last token of the prompt | |||
| unsafe | |||
| { | |||
| batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1; | |||
| } | |||
| if (context.NativeHandle.Decode(batch) != 0) | |||
| { | |||
| await Console.Error.WriteLineAsync("llama_decode failed"); | |||
| return; | |||
| } | |||
| // assign the system KV cache to all parallel sequences | |||
| // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them | |||
| for (var i = 1; i < n_parallel; ++i) | |||
| { | |||
| NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens); | |||
| } | |||
| if (n_parallel > 1) | |||
| { | |||
| Console.WriteLine(); | |||
| Console.WriteLine($"generating {n_parallel} sequences..."); | |||
| } | |||
| // remember the batch index of the last token for each parallel sequence | |||
| // we need this to determine which logits to sample from | |||
| List<int> i_batch = new(); | |||
| for (var i = 0; i < n_parallel; i++) | |||
| i_batch.Add(batch.NativeBatch.n_tokens - 1); | |||
| var n_cur = batch.NativeBatch.n_tokens; | |||
| var n_decode = 0; | |||
| var streams = new List<int>[n_parallel]; | |||
| for (var i = 0; i < n_parallel; i++) | |||
| streams[i] = new(); | |||
| var eos = model.EndOfSentenceToken; | |||
| var nl = model.NewlineToken; | |||
| var timer = new Stopwatch(); | |||
| timer.Start(); | |||
| while (n_cur <= n_len) | |||
| { | |||
| batch.LLamaBatchClear(); | |||
| for (var i = 0; i < n_parallel; i++) | |||
| { | |||
| // Skip completed streams | |||
| if (i_batch[i] < 0) | |||
| continue; | |||
| var n_vocab = model.VocabCount; | |||
| LLamaTokenDataArray candidates; | |||
| unsafe | |||
| { | |||
| candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab)); | |||
| } | |||
| candidates.TopK(context.NativeHandle, top_k); | |||
| candidates.TopP(context.NativeHandle, top_p); | |||
| candidates.Temperature(context.NativeHandle, temp); | |||
| var new_token_id = candidates.SampleToken(context.NativeHandle); | |||
| if (new_token_id == eos || new_token_id == nl) | |||
| { | |||
| i_batch[i] = -1; | |||
| Console.WriteLine($"Completed Stream {i} early"); | |||
| continue; | |||
| } | |||
| streams[i].Add(new_token_id); | |||
| i_batch[i] = batch.NativeBatch.n_tokens; | |||
| // push this new token for next evaluation | |||
| batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true); | |||
| n_decode++; | |||
| } | |||
| // all streams are finished | |||
| if (batch.NativeBatch.n_tokens == 0) | |||
| { | |||
| break; | |||
| } | |||
| n_cur++; | |||
| // evaluate the current batch with the transformer model | |||
| if (context.NativeHandle.Decode(batch) != 0) | |||
| { | |||
| await Console.Error.WriteLineAsync("failed to eval"); | |||
| return; | |||
| } | |||
| } | |||
| timer.Stop(); | |||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||
| Console.WriteLine(); | |||
| Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms"); | |||
| Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second"); | |||
| var index = 0; | |||
| foreach (var stream in streams) | |||
| { | |||
| var text = context.DeTokenize(stream); | |||
| Console.ForegroundColor = ConsoleColor.Green; | |||
| Console.Write($"{index++}. {prompt}"); | |||
| Console.ForegroundColor = ConsoleColor.Red; | |||
| Console.WriteLine(text); | |||
| } | |||
| } | |||
| } | |||
| @@ -14,10 +14,7 @@ namespace LLama.Examples.NewVersion | |||
| var modelPath = Console.ReadLine(); | |||
| // Load weights into memory | |||
| var parameters = new ModelParams(modelPath) | |||
| { | |||
| Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)), | |||
| }; | |||
| var parameters = new ModelParams(modelPath); | |||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||
| using var context = model.CreateContext(parameters); | |||
| var ex = new InteractiveExecutor(context); | |||
| @@ -16,10 +16,7 @@ namespace LLama.Examples.NewVersion | |||
| var modelPath = Console.ReadLine(); | |||
| // Load weights into memory | |||
| var parameters = new ModelParams(modelPath) | |||
| { | |||
| Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)) | |||
| }; | |||
| var parameters = new ModelParams(modelPath); | |||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||
| var ex = new StatelessExecutor(model, parameters); | |||
| @@ -13,10 +13,7 @@ namespace LLama.Examples.NewVersion | |||
| var modelPath = Console.ReadLine(); | |||
| // Load weights into memory | |||
| var @params = new ModelParams(modelPath) | |||
| { | |||
| Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)) | |||
| }; | |||
| var @params = new ModelParams(modelPath); | |||
| using var weights = LLamaWeights.LoadFromFile(@params); | |||
| // Create 2 contexts sharing the same weights | |||
| @@ -22,6 +22,7 @@ | |||
| Console.WriteLine("12: Semantic Kernel Chat."); | |||
| Console.WriteLine("13: Semantic Kernel Memory."); | |||
| Console.WriteLine("14: Coding Assistant."); | |||
| Console.WriteLine("15: Batch Decoding."); | |||
| while (true) | |||
| { | |||
| @@ -88,6 +89,10 @@ | |||
| { | |||
| await CodingAssistant.Run(); | |||
| } | |||
| else if (choice == 15) | |||
| { | |||
| await BatchedDecoding.Run(); | |||
| } | |||
| else | |||
| { | |||
| Console.WriteLine("Cannot parse your choice. Please select again."); | |||
| @@ -1,3 +1,4 @@ | |||
| using System.Diagnostics; | |||
| using LLama.Common; | |||
| using Xunit.Abstractions; | |||
| @@ -34,10 +35,17 @@ namespace LLama.Unittest | |||
| const string question = "Question. what is a cat?\nAnswer: "; | |||
| var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; | |||
| var timer = new Stopwatch(); | |||
| timer.Start(); | |||
| var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync()); | |||
| var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync()); | |||
| timer.Stop(); | |||
| _testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms"); | |||
| _testOutputHelper.WriteLine(result1); | |||
| _testOutputHelper.WriteLine(result2); | |||
| // Check that it produced the exact same result both times | |||
| Assert.Equal(result1, result2); | |||
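For reference, a minimal usage sketch (not part of the diff) showing the stateless executor consumed as a stream rather than collected into a list, reusing the names from the example and test code above:

    // `model` and `parameters` are assumed to be loaded as in the examples above.
    var executor = new StatelessExecutor(model, parameters);
    var inferenceParams = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

    // InferAsync yields decoded text pieces as they are generated.
    await foreach (var piece in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
        Console.Write(piece);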
| @@ -23,7 +23,7 @@ namespace LLama.Web.Common | |||
| /// <summary> | |||
| /// Sequences where the model will stop generating further tokens. | |||
| /// </summary> | |||
| public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>(); | |||
| public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>(); | |||
| /// <summary> | |||
| /// path to file for saving/loading model eval state | |||
| /// </summary> | |||
| @@ -111,12 +111,12 @@ namespace LLama.Web.Common | |||
| /// <summary> | |||
| /// RoPE base frequency | |||
| /// </summary> | |||
| public float RopeFrequencyBase { get; set; } = 10000.0f; | |||
| public float? RopeFrequencyBase { get; set; } | |||
| /// <summary> | |||
| /// RoPE frequency scaling factor | |||
| /// </summary> | |||
| public float RopeFrequencyScale { get; set; } = 1.0f; | |||
| public float? RopeFrequencyScale { get; set; } | |||
| /// <summary> | |||
| /// Use experimental mul_mat_q kernels | |||
| @@ -39,14 +39,14 @@ public interface IContextParams | |||
| bool EmbeddingMode { get; set; } | |||
| /// <summary> | |||
| /// RoPE base frequency | |||
| /// RoPE base frequency (null to fetch from the model) | |||
| /// </summary> | |||
| float RopeFrequencyBase { get; set; } | |||
| float? RopeFrequencyBase { get; set; } | |||
| /// <summary> | |||
| /// RoPE frequency scaling factor | |||
| /// RoPE frequency scaling factor (null to fetch from the model) | |||
| /// </summary> | |||
| float RopeFrequencyScale { get; set; } | |||
| float? RopeFrequencyScale { get; set; } | |||
| /// <summary> | |||
| /// Use experimental mul_mat_q kernels | |||
| @@ -29,7 +29,7 @@ namespace LLama.Abstractions | |||
| /// <summary> | |||
| /// Sequences where the model will stop generating further tokens. | |||
| /// </summary> | |||
| public IEnumerable<string> AntiPrompts { get; set; } | |||
| public IReadOnlyList<string> AntiPrompts { get; set; } | |||
| /// <summary> | |||
| /// 0 or lower to use vocab size | |||
| @@ -28,7 +28,7 @@ namespace LLama.Common | |||
| /// <summary> | |||
| /// Sequences where the model will stop generating further tokens. | |||
| /// </summary> | |||
| public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>(); | |||
| public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>(); | |||
| /// <summary> | |||
| /// 0 or lower to use vocab size | |||
| @@ -91,12 +91,12 @@ namespace LLama.Common | |||
| /// <summary> | |||
| /// RoPE base frequency | |||
| /// </summary> | |||
| public float RopeFrequencyBase { get; set; } = 10000.0f; | |||
| public float? RopeFrequencyBase { get; set; } | |||
| /// <summary> | |||
| /// RoPE frequency scaling factor | |||
| /// </summary> | |||
| public float RopeFrequencyScale { get; set; } = 1.0f; | |||
| public float? RopeFrequencyScale { get; set; } | |||
| /// <summary> | |||
| /// Use experimental mul_mat_q kernels | |||
| @@ -156,7 +156,7 @@ namespace LLama.Common | |||
| bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, | |||
| string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512, | |||
| bool embeddingMode = false, | |||
| float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false, | |||
| float? ropeFrequencyBase = null, float? ropeFrequencyScale = null, bool mulMatQ = false, | |||
| string encoding = "UTF-8") | |||
| { | |||
| ContextSize = contextSize; | |||
| @@ -27,8 +27,8 @@ namespace LLama.Extensions | |||
| result.f16_kv = @params.UseFp16Memory; | |||
| result.logits_all = @params.Perplexity; | |||
| result.embedding = @params.EmbeddingMode; | |||
| result.rope_freq_base = @params.RopeFrequencyBase; | |||
| result.rope_freq_scale = @params.RopeFrequencyScale; | |||
| result.rope_freq_base = @params.RopeFrequencyBase ?? 0; | |||
| result.rope_freq_scale = @params.RopeFrequencyScale ?? 0; | |||
| result.mul_mat_q = @params.MulMatQ; | |||
| result.n_threads = Threads(@params.Threads); | |||
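The nullable RoPE settings change how defaults are resolved: leaving them as null forwards 0 to the native layer, which makes llama.cpp fetch the values from the model file, while setting them explicitly overrides the model. A small sketch (the model path is a placeholder):

    // RoPE values left null: fetched from the model (forwarded to the native API as 0).
    var fromModel = new ModelParams("<model path>");

    // Explicit overrides still behave as before.
    var overridden = new ModelParams("<model path>")
    {
        RopeFrequencyBase = 10000.0f,
        RopeFrequencyScale = 0.5f,
    };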
| @@ -235,13 +235,13 @@ namespace LLama | |||
| if (grammar != null) | |||
| { | |||
| SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar); | |||
| candidates.ApplyGrammar(NativeHandle, grammar); | |||
| } | |||
| if (temperature <= 0) | |||
| { | |||
| // Greedy sampling | |||
| id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates); | |||
| id = candidates.SampleTokenGreedy(NativeHandle); | |||
| } | |||
| else | |||
| { | |||
| @@ -250,32 +250,28 @@ namespace LLama | |||
| if (mirostat == MirostatType.Mirostat) | |||
| { | |||
| const int mirostat_m = 100; | |||
| SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu); | |||
| candidates.Temperature(NativeHandle, temperature); | |||
| id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu); | |||
| } | |||
| else if (mirostat == MirostatType.Mirostat2) | |||
| { | |||
| SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu); | |||
| candidates.Temperature(NativeHandle, temperature); | |||
| id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu); | |||
| } | |||
| else | |||
| { | |||
| // Temperature sampling | |||
| SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1); | |||
| SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1); | |||
| SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1); | |||
| SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1); | |||
| SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token(NativeHandle, candidates); | |||
| candidates.TopK(NativeHandle, topK); | |||
| candidates.TailFree(NativeHandle, tfsZ); | |||
| candidates.LocallyTypical(NativeHandle, typicalP); | |||
| candidates.TopP(NativeHandle, topP); | |||
| candidates.Temperature(NativeHandle, temperature); | |||
| id = candidates.SampleToken(NativeHandle); | |||
| } | |||
| } | |||
| mirostat_mu = mu; | |||
| } | |||
| if (grammar != null) | |||
| { | |||
| NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id); | |||
| } | |||
| grammar?.AcceptToken(NativeHandle, id); | |||
| return id; | |||
| } | |||
| @@ -305,7 +301,7 @@ namespace LLama | |||
| } | |||
| // Save the newline logit value | |||
| var nl_token = NativeApi.llama_token_nl(NativeHandle); | |||
| var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); | |||
| var nl_logit = logits[nl_token]; | |||
| // Convert logits into token candidates | |||
| @@ -316,8 +312,7 @@ namespace LLama | |||
| var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); | |||
| // Apply penalties to candidates | |||
| SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty); | |||
| SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence); | |||
| candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence); | |||
| // Restore newline token logit value if necessary | |||
| if (!penalizeNL) | |||
| @@ -369,7 +364,7 @@ namespace LLama | |||
| try | |||
| { | |||
| tokens.CopyTo(rented, 0); | |||
| return Eval(rented, pastTokensCount); | |||
| return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount); | |||
| } | |||
| finally | |||
| { | |||
| @@ -163,7 +163,7 @@ namespace LLama | |||
| } | |||
| } | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) | |||
| { | |||
| args.WaitForInput = true; | |||
| } | |||
| @@ -30,7 +30,7 @@ namespace LLama | |||
| public InteractiveExecutor(LLamaContext context, ILogger? logger = null) | |||
| : base(context, logger) | |||
| { | |||
| _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle); | |||
| _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle); | |||
| } | |||
| /// <inheritdoc /> | |||
| @@ -141,7 +141,7 @@ namespace LLama | |||
| return (true, Array.Empty<string>()); | |||
| } | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) | |||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) | |||
| { | |||
| return (true, new[] { " [end of text]\n" }); | |||
| } | |||
| @@ -202,7 +202,7 @@ namespace LLama | |||
| _last_n_tokens.Enqueue(id); | |||
| if (id == NativeApi.llama_token_eos(Context.NativeHandle)) | |||
| if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) | |||
| { | |||
| id = _llama_token_newline; | |||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | |||
| @@ -6,7 +6,6 @@ using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| using LLama.Extensions; | |||
| using LLama.Native; | |||
| using Microsoft.Extensions.Logging; | |||
| @@ -47,68 +46,66 @@ namespace LLama | |||
| } | |||
| /// <inheritdoc /> | |||
| public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | |||
| public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | |||
| { | |||
| using var context = _weights.CreateContext(_params, _logger); | |||
| Context = context; | |||
| // Ensure the context from last time is disposed (it always should be) | |||
| if (!Context.NativeHandle.IsClosed) | |||
| Context.Dispose(); | |||
| Context = _weights.CreateContext(Context.Params, _logger); | |||
| var decoder = new StreamingTokenDecoder(Context); | |||
| var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty<string>()); | |||
| if (inferenceParams != null) | |||
| { | |||
| if (inferenceParams.TokensKeep > Context.ContextSize) | |||
| throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); | |||
| } | |||
| cancellationToken.ThrowIfCancellationRequested(); | |||
| // Create an inference context which will be disposed when this method exits | |||
| using var context = _weights.CreateContext(_params, _logger); | |||
| Context = context; | |||
| // Sanity check inference params | |||
| inferenceParams ??= new InferenceParams(); | |||
| if (inferenceParams.TokensKeep > Context.ContextSize) | |||
| throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); | |||
| // Create decoders for the token stream | |||
| var decoder = new StreamingTokenDecoder(Context); | |||
| var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts); | |||
| var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount); | |||
| for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++) | |||
| // Keep track of the last N tokens emitted | |||
| var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount); | |||
| var lastTokens = new List<llama_token>(repeat_last_n); | |||
| for (var i = 0; i < repeat_last_n; i++) | |||
| lastTokens.Add(0); | |||
| var tokens = Context.Tokenize(text).ToList(); | |||
| // Tokenize the prompt | |||
| var tokens = Context.Tokenize(prompt).ToList(); | |||
| lastTokens.AddRange(tokens); | |||
| var n_past = 1 + tokens.Count; | |||
| // Evaluate the prompt | |||
| await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken) | |||
| .ConfigureAwait(false); | |||
| lastTokens.AddRange(tokens); | |||
| var n_past = 1 + tokens.Count; | |||
| // Begin loop, evaluating one token at a time | |||
| var mu = (float?)null; | |||
| var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | |||
| for(var i = 0; i < max_tokens; i++) | |||
| for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) | |||
| { | |||
| if (cancellationToken.IsCancellationRequested) | |||
| break; | |||
| var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; | |||
| // Penalize the generated tokens by various penalties | |||
| var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| // Sample a single token | |||
| var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); | |||
| lastTokens.Add(id); | |||
| // Decode this token into text | |||
| decoder.Add(id); | |||
| var decoded = decoder.Read(); | |||
| yield return decoded; | |||
| tokens.Clear(); | |||
| tokens.Add(id); | |||
| // Check if any of the antiprompts have been generated | |||
| if (antiprocessor.Add(decoded)) | |||
| break; | |||
| lastTokens.Add(id); | |||
| tokens.Clear(); | |||
| tokens.Add(id); | |||
| // when run out of context | |||
| // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497 | |||
| if (n_past + tokens.Count >= Context.ContextSize) | |||
| @@ -38,6 +38,21 @@ namespace LLama | |||
| /// </summary> | |||
| public ulong ParameterCount => NativeHandle.ParameterCount; | |||
| /// <summary> | |||
| /// Get the newline token for this model | |||
| /// </summary> | |||
| public int NewlineToken => NativeApi.llama_token_nl(NativeHandle); | |||
| /// <summary> | |||
| /// Get the "end of sentence" token for this model | |||
| /// </summary> | |||
| public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle); | |||
| /// <summary> | |||
| /// Get the "beginning of sentence" token for this model | |||
| /// </summary> | |||
| public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle); | |||
| /// <summary> | |||
| /// Dimension of embedding vectors | |||
| /// </summary> | |||
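These properties wrap the native token functions, which now take a model handle rather than a context handle (see the NativeApi changes further down). A minimal sketch of the before/after, reusing the `model` from the batched decoding example:

    // Previously (context based):
    // var eos = NativeApi.llama_token_eos(context.NativeHandle);

    // Now, straight from the loaded weights:
    var eos = model.EndOfSentenceToken;
    var bos = model.BeginningOfSentenceToken;
    var nl = model.NewlineToken;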
| @@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle | |||
| /// <summary> | |||
| /// Get the native llama_batch struct | |||
| /// </summary> | |||
| public LLamaNativeBatch NativeBatch { get; private set; } | |||
| public LLamaNativeBatch NativeBatch; | |||
| /// <summary> | |||
| /// the token ids of the input (used when embd is NULL) | |||
| @@ -113,10 +113,11 @@ public sealed class LLamaBatchSafeHandle | |||
| /// </summary> | |||
| /// <param name="n_tokens"></param> | |||
| /// <param name="embd"></param> | |||
| /// <param name="n_seq_max"></param> | |||
| /// <returns></returns> | |||
| public static LLamaBatchSafeHandle Create(int n_tokens, int embd) | |||
| public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max) | |||
| { | |||
| var batch = NativeApi.llama_batch_init(n_tokens, embd); | |||
| var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max); | |||
| return new LLamaBatchSafeHandle(batch, embd); | |||
| } | |||
| @@ -128,4 +129,32 @@ public sealed class LLamaBatchSafeHandle | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| /// <summary> | |||
| /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 | |||
| /// </summary> | |||
| public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits) | |||
| { | |||
| unsafe | |||
| { | |||
| NativeBatch.token[NativeBatch.n_tokens] = token; | |||
| NativeBatch.pos[NativeBatch.n_tokens] = pos; | |||
| NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length; | |||
| for (var i = 0; i < sequences.Length; i++) | |||
| NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i]; | |||
| NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits); | |||
| NativeBatch.n_tokens++; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825 | |||
| /// </summary> | |||
| public void LLamaBatchClear() | |||
| { | |||
| NativeBatch.n_tokens = 0; | |||
| } | |||
| } | |||
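A condensed sketch of the batch workflow these helpers enable, mirroring the BatchedDecoding example above (`tokens` is assumed to be a tokenized prompt and `context` an existing LLamaContext):

    // Token-based batch (embd = 0), one sequence id per token.
    using var batch = LLamaBatchSafeHandle.Create(tokens.Length, 0, 1);

    // Queue the prompt, requesting logits only for the final token.
    for (var i = 0; i < tokens.Length; i++)
        batch.LLamaBatchAdd(tokens[i], i, new[] { (LLamaSeqId)0 }, i == tokens.Length - 1);

    if (context.NativeHandle.Decode(batch) != 0)
        throw new Exception("llama_decode failed");

    // Reset the handle before queuing the next round of tokens.
    batch.LLamaBatchClear();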
| @@ -11,13 +11,13 @@ using llama_token = Int32; | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct LLamaBeamView | |||
| { | |||
| private readonly unsafe llama_token* tokens; | |||
| private readonly nint n_tokens; | |||
| private unsafe llama_token* tokens; | |||
| private nint n_tokens; | |||
| /// <summary> | |||
| /// Cumulative beam probability (renormalized relative to all beams) | |||
| /// </summary> | |||
| public readonly float CumulativeProbability; | |||
| public float CumulativeProbability; | |||
| /// <summary> | |||
| /// Callback should set this to true when a beam is at end-of-beam. | |||
| @@ -9,27 +9,27 @@ namespace LLama.Native; | |||
| /// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public readonly struct LLamaBeamsState | |||
| public struct LLamaBeamsState | |||
| { | |||
| /// <summary> | |||
| /// The state of each individual beam | |||
| /// </summary> | |||
| private readonly unsafe LLamaBeamView* beam_views; | |||
| private unsafe LLamaBeamView* beam_views; | |||
| /// <summary> | |||
| /// Number of elements in beam_views | |||
| /// </summary> | |||
| private readonly nint n_beams; | |||
| private nint n_beams; | |||
| /// <summary> | |||
| /// Current max length of prefix tokens shared by all beams. | |||
| /// </summary> | |||
| public readonly ulong CommonPrefixLength; | |||
| public ulong CommonPrefixLength; | |||
| /// <summary> | |||
| /// True iff this is the last callback invocation. | |||
| /// </summary> | |||
| public readonly bool LastCall; | |||
| public bool LastCall; | |||
| /// <summary> | |||
| /// The current state of each beam | |||
| @@ -52,18 +52,18 @@ namespace LLama.Native | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| [DebuggerDisplay("{Type} {Value}")] | |||
| public readonly struct LLamaGrammarElement | |||
| public struct LLamaGrammarElement | |||
| : IEquatable<LLamaGrammarElement> | |||
| { | |||
| /// <summary> | |||
| /// The type of this element | |||
| /// </summary> | |||
| public readonly LLamaGrammarElementType Type; | |||
| public LLamaGrammarElementType Type; | |||
| /// <summary> | |||
| /// Unicode code point or rule ID | |||
| /// </summary> | |||
| public readonly uint Value; | |||
| public uint Value; | |||
| /// <summary> | |||
| /// Construct a new LLamaGrammarElement | |||
| @@ -1,10 +1,12 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native | |||
| { | |||
| /// <summary> | |||
| /// Quantizer parameters used in the native API | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct LLamaModelQuantizeParams | |||
| { | |||
| /// <summary> | |||
| @@ -11,35 +11,47 @@ using llama_token = Int32; | |||
| /// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public readonly unsafe struct LLamaNativeBatch | |||
| public unsafe struct LLamaNativeBatch | |||
| { | |||
| /// <summary> | |||
| /// The number of items pointed at by pos, seq_id and logits. | |||
| /// </summary> | |||
| public readonly int n_tokens; | |||
| public int n_tokens; | |||
| /// <summary> | |||
| /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created | |||
| /// </summary> | |||
| public readonly llama_token* token; | |||
| public llama_token* token; | |||
| /// <summary> | |||
| /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created | |||
| /// </summary> | |||
| public readonly float* embd; | |||
| public float* embd; | |||
| /// <summary> | |||
| /// the positions of the respective token in the sequence | |||
| /// </summary> | |||
| public readonly LLamaPos* pos; | |||
| public LLamaPos* pos; | |||
| /// <summary> | |||
| /// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ??? | |||
| /// </summary> | |||
| public int* n_seq_id; | |||
| /// <summary> | |||
| /// the sequence to which the respective token belongs | |||
| /// </summary> | |||
| public readonly LLamaSeqId* seq_id; | |||
| public LLamaSeqId** seq_id; | |||
| /// <summary> | |||
| /// if zero, the logits for the respective token will not be output | |||
| /// </summary> | |||
| public readonly byte* logits; | |||
| public byte* logits; | |||
| // Note from llama.cpp: | |||
| // > helpers for smooth API transition - can be deprecated in the future | |||
| // > for future-proof code, use the above fields instead and ignore everything below | |||
| private LLamaPos _all_pos_0; | |||
| private LLamaPos _all_pos_1; | |||
| private LLamaSeqId _all_seq_id; | |||
| } | |||
| @@ -1,14 +1,26 @@ | |||
| namespace LLama.Native; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
| /// <summary> | |||
| /// Indicates position in a sequence | |||
| /// </summary> | |||
| public readonly record struct LLamaPos(int Value) | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct LLamaPos | |||
| { | |||
| /// <summary> | |||
| /// The raw value | |||
| /// </summary> | |||
| public readonly int Value = Value; | |||
| public int Value; | |||
| /// <summary> | |||
| /// Create a new LLamaPos | |||
| /// </summary> | |||
| /// <param name="value"></param> | |||
| public LLamaPos(int value) | |||
| { | |||
| Value = value; | |||
| } | |||
| /// <summary> | |||
| /// Convert a LLamaPos into an integer (extract the raw value) | |||
| @@ -1,15 +1,26 @@ | |||
| namespace LLama.Native; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
| /// <summary> | |||
| /// ID for a sequence in a batch | |||
| /// </summary> | |||
| /// <param name="Value"></param> | |||
| public record struct LLamaSeqId(int Value) | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct LLamaSeqId | |||
| { | |||
| /// <summary> | |||
| /// The raw value | |||
| /// </summary> | |||
| public int Value = Value; | |||
| public int Value; | |||
| /// <summary> | |||
| /// Create a new LLamaSeqId | |||
| /// </summary> | |||
| /// <param name="value"></param> | |||
| public LLamaSeqId(int value) | |||
| { | |||
| Value = value; | |||
| } | |||
| /// <summary> | |||
| /// Convert a LLamaSeqId into an integer (extract the raw value) | |||
| @@ -5,24 +5,34 @@ namespace LLama.Native; | |||
| /// <summary> | |||
| /// A single token along with probability of this token being selected | |||
| /// </summary> | |||
| /// <param name="id"></param> | |||
| /// <param name="logit"></param> | |||
| /// <param name="p"></param> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public record struct LLamaTokenData(int id, float logit, float p) | |||
| public struct LLamaTokenData | |||
| { | |||
| /// <summary> | |||
| /// token id | |||
| /// </summary> | |||
| public int id = id; | |||
| public int id; | |||
| /// <summary> | |||
| /// log-odds of the token | |||
| /// </summary> | |||
| public float logit = logit; | |||
| public float logit; | |||
| /// <summary> | |||
| /// probability of the token | |||
| /// </summary> | |||
| public float p = p; | |||
| public float p; | |||
| /// <summary> | |||
| /// Create a new LLamaTokenData | |||
| /// </summary> | |||
| /// <param name="id"></param> | |||
| /// <param name="logit"></param> | |||
| /// <param name="p"></param> | |||
| public LLamaTokenData(int id, float logit, float p) | |||
| { | |||
| this.id = id; | |||
| this.logit = logit; | |||
| this.p = p; | |||
| } | |||
| } | |||
| @@ -45,6 +45,199 @@ namespace LLama.Native | |||
| return new LLamaTokenDataArray(candidates); | |||
| } | |||
| #region sampling | |||
| /// <summary> | |||
| /// Apply grammar rules to candidate tokens | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="grammar"></param> | |||
| public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_grammar(ctx, ref st, grammar); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="k">Number of tokens to keep</param> | |||
| /// <param name="minKeep">Minimum number to keep</param> | |||
| public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_top_k(context, ref st, k, minKeep); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="p"></param> | |||
| /// <param name="minKeep"></param> | |||
| public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_top_p(context, ref st, p, minKeep); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="z"></param> | |||
| /// <param name="min_keep"></param> | |||
| public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_tail_free(context, ref st, z, min_keep); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="p"></param> | |||
| /// <param name="min_keep"></param> | |||
| public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_typical(context, ref st, p, min_keep); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | |||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="penalty_repeat"></param> | |||
| /// <param name="penalty_freq"></param> | |||
| /// <param name="penalty_present"></param> | |||
| public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) | |||
| { | |||
| unsafe | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| using (var last_tokens_handle = last_tokens.Pin()) | |||
| { | |||
| NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Sample with temperature. | |||
| /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="temp"></param> | |||
| public void Temperature(SafeLLamaContextHandle context, float temp) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_temperature(context, ref st, temp); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| public void Softmax(SafeLLamaContextHandle context) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| NativeApi.llama_sample_softmax(context, ref st); | |||
| sorted = st.sorted; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Randomly selects a token from the candidates based on their probabilities. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <returns></returns> | |||
| public int SampleToken(SafeLLamaContextHandle context) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| var token = NativeApi.llama_sample_token(context, ref st); | |||
| sorted = st.sorted; | |||
| return token; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Selects the token with the highest probability. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <returns></returns> | |||
| public int SampleTokenGreedy(SafeLLamaContextHandle context) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| var token = NativeApi.llama_sample_token_greedy(context, ref st); | |||
| sorted = st.sorted; | |||
| return token; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param> | |||
| /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param> | |||
| /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param> | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu); | |||
| sorted = st.sorted; | |||
| return token; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. | |||
| /// </summary> | |||
| /// <param name="context"></param> | |||
| /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param> | |||
| /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param> | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu) | |||
| { | |||
| using (LLamaTokenDataArrayNative.Create(this, out var st)) | |||
| { | |||
| var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu); | |||
| sorted = st.sorted; | |||
| return token; | |||
| } | |||
| } | |||
| #endregion | |||
| } | |||
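Taken together, these instance methods replace the static SamplingApi calls with a small pipeline on the candidate array itself. A hedged sketch (`logits` is assumed to be a Span<float> of vocabulary logits, e.g. from llama_get_logits_ith, and `lastTokens` a Memory<int> of recently emitted token ids):

    var candidates = LLamaTokenDataArray.Create(logits);

    // Penalties and filters modify the candidates in place.
    candidates.RepetitionPenalty(context.NativeHandle, lastTokens, 1.1f, 0.0f, 0.0f);
    candidates.TopK(context.NativeHandle, 40);
    candidates.TopP(context.NativeHandle, 0.95f);
    candidates.Temperature(context.NativeHandle, 0.8f);

    // Finally pick a token from whatever remains.
    var id = candidates.SampleToken(context.NativeHandle);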
| /// <summary> | |||
| @@ -9,26 +9,22 @@ namespace LLama.Native | |||
| { | |||
| /// <summary> | |||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="last_tokens_size"></param> | |||
| /// <param name="penalty"></param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty); | |||
| /// <summary> | |||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="last_tokens_size"></param> | |||
| /// <param name="alpha_frequency"></param> | |||
| /// <param name="alpha_presence"></param> | |||
| /// <param name="penalty_repeat">Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.</param> | |||
| /// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param> | |||
| /// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); | |||
| public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx, | |||
| ref LLamaTokenDataArrayNative candidates, | |||
| llama_token* last_tokens, ulong last_tokens_size, | |||
| float penalty_repeat, | |||
| float penalty_freq, | |||
| float penalty_present); | |||
| /// <summary> | |||
| /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 | |||
| @@ -348,9 +348,10 @@ namespace LLama.Native | |||
| /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="i"></param> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx); | |||
| public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); | |||
| /// <summary> | |||
| /// Get the embeddings for the input | |||
| @@ -366,21 +367,21 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx); | |||
| public static extern llama_token llama_token_bos(SafeLlamaModelHandle model); | |||
| /// <summary> | |||
| /// Get the "End of sentence" token | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx); | |||
| public static extern llama_token llama_token_eos(SafeLlamaModelHandle model); | |||
| /// <summary> | |||
| /// Get the "new line" token | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx); | |||
| public static extern llama_token llama_token_nl(SafeLlamaModelHandle model); | |||
| /// <summary> | |||
| /// Print out timing information for this context | |||
| @@ -530,6 +531,7 @@ namespace LLama.Native | |||
| /// <summary> | |||
| /// Allocates a batch of tokens on the heap | |||
| /// Each token can be assigned up to n_seq_max sequence ids | |||
| /// The batch has to be freed with llama_batch_free() | |||
| /// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) | |||
| /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token | |||
| @@ -538,8 +540,9 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <param name="n_tokens"></param> | |||
| /// <param name="embd"></param> | |||
| /// <param name="n_seq_max">Each token can be assigned up to n_seq_max sequence ids</param> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd); | |||
| public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd, int n_seq_max); | |||
| /// <summary> | |||
| /// Frees a batch of tokens allocated with llama_batch_init() | |||
| @@ -114,6 +114,22 @@ namespace LLama.Native | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab | |||
| /// </summary> | |||
| /// <param name="i"></param> | |||
| /// <returns></returns> | |||
| public Span<float> GetLogitsIth(int i) | |||
| { | |||
| var model = ThrowIfDisposed(); | |||
| unsafe | |||
| { | |||
| var logits = NativeApi.llama_get_logits_ith(this, i); | |||
| return new Span<float>(logits, model.VocabCount); | |||
| } | |||
| } | |||
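For illustration (not part of the diff), this wrapper lets the per-sequence sampling in the BatchedDecoding example above drop its unsafe block:

    // Instead of: new Span<float>(NativeApi.llama_get_logits_ith(handle, i_batch[i]), n_vocab)
    var logits = context.NativeHandle.GetLogitsIth(i_batch[i]);
    var candidates = LLamaTokenDataArray.Create(logits);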
| #region tokens | |||
| /// <summary> | |||
| /// Convert the given text into tokens | |||
| @@ -102,5 +102,15 @@ namespace LLama.Native | |||
| return new(grammar_ptr); | |||
| } | |||
| #endregion | |||
| /// <summary> | |||
| /// Accepts the sampled token into the grammar | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="token"></param> | |||
| public void AcceptToken(SafeLLamaContextHandle ctx, int token) | |||
| { | |||
| NativeApi.llama_grammar_accept_token(ctx, this, token); | |||
| } | |||
| } | |||
| } | |||
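A brief sketch of how the grammar helpers combine with the new sampling methods, mirroring what LLamaContext.Sample now does (`grammar` is assumed to be a SafeLLamaGrammarHandle and `candidates` a LLamaTokenDataArray):

    // Constrain the candidates to the grammar, sample, then feed the choice back into the grammar state.
    candidates.ApplyGrammar(context.NativeHandle, grammar);
    var id = candidates.SampleToken(context.NativeHandle);
    grammar.AcceptToken(context.NativeHandle, id);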
| @@ -9,7 +9,7 @@ namespace LLama.Native | |||
| /// <summary> | |||
| /// Direct translation of the llama.cpp sampling API | |||
| /// </summary> | |||
| public unsafe class SamplingApi | |||
| public class SamplingApi | |||
| { | |||
| /// <summary> | |||
| /// Apply grammar rules to candidate tokens | |||
| @@ -17,70 +17,10 @@ namespace LLama.Native | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates"></param> | |||
| /// <param name="grammar"></param> | |||
| [Obsolete("use LLamaTokenDataArray ApplyGrammar method")] | |||
| public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_grammar(ctx, ref st, grammar); | |||
| } | |||
| /// <summary> | |||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="last_tokens_size"></param> | |||
| /// <param name="penalty"></param> | |||
| [Obsolete("last_tokens_size parameter is no longer needed")] | |||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty) | |||
| { | |||
| llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty); | |||
| } | |||
| /// <summary> | |||
| /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="penalty"></param> | |||
| public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| using var last_tokens_handle = last_tokens.Pin(); | |||
| NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty); | |||
| } | |||
| /// <summary> | |||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="last_tokens_size"></param> | |||
| /// <param name="alpha_frequency"></param> | |||
| /// <param name="alpha_presence"></param> | |||
| [Obsolete("last_tokens_size parameter is no longer needed")] | |||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) | |||
| { | |||
| llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence); | |||
| } | |||
| /// <summary> | |||
| /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="last_tokens"></param> | |||
| /// <param name="alpha_frequency"></param> | |||
| /// <param name="alpha_presence"></param> | |||
| public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| using var last_tokens_handle = last_tokens.Pin(); | |||
| NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence); | |||
| candidates.ApplyGrammar(ctx, grammar); | |||
| } | |||
| /// <summary> | |||
| @@ -88,10 +28,10 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| [Obsolete("use LLamaTokenDataArray Softmax method")] | |||
| public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_softmax(ctx, ref st); | |||
| candidates.Softmax(ctx); | |||
| } | |||
| /// <summary> | |||
| @@ -101,10 +41,10 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="k"></param> | |||
| /// <param name="min_keep"></param> | |||
| [Obsolete("use LLamaTokenDataArray TopK method")] | |||
| public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep); | |||
| candidates.TopK(ctx, k, min_keep); | |||
| } | |||
| /// <summary> | |||
| @@ -114,10 +54,10 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="p"></param> | |||
| /// <param name="min_keep"></param> | |||
| [Obsolete("use LLamaTokenDataArray TopP method")] | |||
| public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep); | |||
| candidates.TopP(ctx, p, min_keep); | |||
| } | |||
| /// <summary> | |||
| @@ -127,10 +67,10 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="z"></param> | |||
| /// <param name="min_keep"></param> | |||
| [Obsolete("use LLamaTokenDataArray TailFree method")] | |||
| public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep); | |||
| candidates.TailFree(ctx, z, min_keep); | |||
| } | |||
| /// <summary> | |||
| @@ -140,10 +80,10 @@ namespace LLama.Native | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <param name="p"></param> | |||
| /// <param name="min_keep"></param> | |||
| [Obsolete("use LLamaTokenDataArray LocallyTypical method")] | |||
| public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); | |||
| candidates.LocallyTypical(ctx, p, min_keep); | |||
| } | |||
| /// <summary> | |||
| @@ -153,10 +93,10 @@ namespace LLama.Native | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates"></param> | |||
| /// <param name="temp"></param> | |||
| [Obsolete("use LLamaTokenDataArray Temperature() method")] | |||
| public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| NativeApi.llama_sample_temperature(ctx, ref st, temp); | |||
| candidates.Temperature(ctx, temp); | |||
| } | |||
| /// <summary> | |||
| @@ -169,10 +109,10 @@ namespace LLama.Native | |||
| /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param> | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")] | |||
| public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu); | |||
| return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu); | |||
| } | |||
| /// <summary> | |||
| @@ -184,10 +124,10 @@ namespace LLama.Native | |||
| /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param> | |||
| /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param> | |||
| /// <returns></returns> | |||
| [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")] | |||
| public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu); | |||
| return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu); | |||
| } | |||
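As the parameter docs above note, mu is seeded with twice the target cross-entropy and is then updated by the sampler on every call. A minimal sketch of driving the new instance methods, with illustrative constants and assuming GetLogits/LLamaTokenDataArray.Create as the way to build the candidate array:

const float tau = 5.0f;      // target surprise (cross-entropy)
const float eta = 0.1f;      // learning rate for mu updates
float mu = 2.0f * tau;       // initial maximum cross-entropy, updated by the sampler

Span<float> logits = context.NativeHandle.GetLogits();
var candidates = LLamaTokenDataArray.Create(logits);

// Mirostat v1 takes the extra estimation parameter m; v2 drops it.
var token = candidates.SampleTokenMirostat(context.NativeHandle, tau, eta, 100, ref mu);
// var token2 = candidates.SampleTokenMirostat2(context.NativeHandle, tau, eta, ref mu);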
| /// <summary> | |||
| @@ -196,10 +136,10 @@ namespace LLama.Native | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <returns></returns> | |||
| [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")] | |||
| public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| return NativeApi.llama_sample_token_greedy(ctx, ref st); | |||
| return candidates.SampleTokenGreedy(ctx); | |||
| } | |||
| /// <summary> | |||
| @@ -208,10 +148,10 @@ namespace LLama.Native | |||
| /// <param name="ctx"></param> | |||
| /// <param name="candidates">Pointer to LLamaTokenDataArray</param> | |||
| /// <returns></returns> | |||
| [Obsolete("use LLamaTokenDataArray SampleToken() method")] | |||
| public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) | |||
| { | |||
| using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); | |||
| return NativeApi.llama_sample_token(ctx, ref st); | |||
| return candidates.SampleToken(ctx); | |||
| } | |||
| } | |||
| } | |||
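Putting the pieces together, the [Obsolete] messages above all point at instance methods on LLamaTokenDataArray. A minimal sketch of the replacement flow (the cutoff values are illustrative, and GetLogits/Create are assumed to be the standard way to obtain the candidate array):

Span<float> logits = context.NativeHandle.GetLogits();
var candidates = LLamaTokenDataArray.Create(logits);

candidates.TopK(context.NativeHandle, 40, 1);        // keep the 40 most likely tokens
candidates.TopP(context.NativeHandle, 0.95f, 1);     // nucleus (top-p) cutoff
candidates.Temperature(context.NativeHandle, 0.8f);  // rescale the remaining logits
var next = candidates.SampleToken(context.NativeHandle);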
| @@ -18,6 +18,21 @@ typedef struct { | |||
| uint8_t qs[QK4_1 / 2]; // nibbles / quants | |||
| } block_q4_1; | |||
| #define QK5_0 32 | |||
| typedef struct { | |||
| half d; // delta | |||
| uint8_t qh[4]; // 5-th bit of quants | |||
| uint8_t qs[QK5_0 / 2]; // nibbles / quants | |||
| } block_q5_0; | |||
| #define QK5_1 32 | |||
| typedef struct { | |||
| half d; // delta | |||
| half m; // min | |||
| uint8_t qh[4]; // 5-th bit of quants | |||
| uint8_t qs[QK5_1 / 2]; // nibbles / quants | |||
| } block_q5_1; | |||
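Each q5_0/q5_1 weight is stored as a 4-bit nibble in qs plus one extra high bit in qh; q5_0 recentres by -16 while q5_1 carries an explicit minimum m. A hypothetical C# sketch of reconstructing a single q5_0 weight from that layout (the Metal dequantize kernels further down do the same work in-place; qh is assumed to be the four qh bytes read as one little-endian 32-bit value):

static float DequantizeQ5_0(float d, byte[] qs, uint qh, int i)
{
    // weights 0..15 use the low nibbles of qs, weights 16..31 the high nibbles
    int nibble = i < 16 ? (qs[i] & 0x0F) : (qs[i - 16] >> 4);
    int fifth  = (int)((qh >> i) & 1u) << 4;   // the matching 5th bit from qh
    int q      = nibble | fifth;               // 5-bit quant in [0, 31]
    return d * (q - 16);                       // scale by the block delta and recentre
}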
| #define QK8_0 32 | |||
| typedef struct { | |||
| half d; // delta | |||
| @@ -110,9 +125,17 @@ kernel void kernel_mul_row( | |||
| } | |||
| kernel void kernel_scale( | |||
| device const float * src0, | |||
| device float * dst, | |||
| constant float & scale, | |||
| uint tpig[[thread_position_in_grid]]) { | |||
| dst[tpig] = src0[tpig] * scale; | |||
| } | |||
| kernel void kernel_scale_4( | |||
| device const float4 * src0, | |||
| device float4 * dst, | |||
| constant float & scale, | |||
| uint tpig[[thread_position_in_grid]]) { | |||
| dst[tpig] = src0[tpig] * scale; | |||
| } | |||
| @@ -399,8 +422,11 @@ kernel void kernel_rms_norm( | |||
| // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | |||
| inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { | |||
| float d = qb_curr->d; | |||
| float2 acc = 0.f; | |||
| device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); | |||
| for (int i = 0; i < 8; i+=2) { | |||
| acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) | |||
| + yl[i + 1] * (qs[i / 2] & 0x0F00); | |||
| @@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre | |||
| inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { | |||
| float d = qb_curr->d; | |||
| float m = qb_curr->m; | |||
| float2 acc = 0.f; | |||
| device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); | |||
| for (int i = 0; i < 8; i+=2) { | |||
| acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) | |||
| + yl[i + 1] * (qs[i / 2] & 0x0F00); | |||
| @@ -428,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre | |||
| return d * (acc[0] + acc[1]) + sumy * m; | |||
| } | |||
| // function to calculate the inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) | |||
| // il indicates where the q5 quants begin (0 or QK5_0/4) | |||
| // we assume that the yl's have been multiplied with the appropriate scale factor | |||
| // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | |||
| inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { | |||
| float d = qb_curr->d; | |||
| float2 acc = 0.f; | |||
| device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); | |||
| const uint32_t qh = *((device const uint32_t *)qb_curr->qh); | |||
| for (int i = 0; i < 8; i+=2) { | |||
| acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) | |||
| + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); | |||
| acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) | |||
| + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); | |||
| } | |||
| return d * (sumy * -16.f + acc[0] + acc[1]); | |||
| } | |||
| // function to calculate the inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) | |||
| // il indicates where the q5 quants begin (0 or QK5_1/4) | |||
| // we assume that the yl's have been multiplied with the appropriate scale factor | |||
| // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | |||
| inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { | |||
| float d = qb_curr->d; | |||
| float m = qb_curr->m; | |||
| float2 acc = 0.f; | |||
| device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); | |||
| const uint32_t qh = *((device const uint32_t *)qb_curr->qh); | |||
| for (int i = 0; i < 8; i+=2) { | |||
| acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) | |||
| + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); | |||
| acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) | |||
| + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); | |||
| } | |||
| return d * (acc[0] + acc[1]) + sumy * m; | |||
| } | |||
| // putting them in the kernel causes a significant performance penalty | |||
| #define N_DST 4 // each SIMD group works on 4 rows | |||
| #define N_SIMDGROUP 2 // number of SIMD groups in a thread group | |||
| @@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32( | |||
| mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); | |||
| } | |||
| kernel void kernel_mul_mv_q5_0_f32( | |||
| device const void * src0, | |||
| device const float * src1, | |||
| device float * dst, | |||
| constant int64_t & ne00, | |||
| constant int64_t & ne01[[buffer(4)]], | |||
| constant int64_t & ne02[[buffer(5)]], | |||
| constant int64_t & ne10[[buffer(9)]], | |||
| constant int64_t & ne12[[buffer(11)]], | |||
| constant int64_t & ne0[[buffer(15)]], | |||
| constant int64_t & ne1[[buffer(16)]], | |||
| constant uint & gqa[[buffer(17)]], | |||
| uint3 tgpig[[threadgroup_position_in_grid]], | |||
| uint tiisg[[thread_index_in_simdgroup]], | |||
| uint sgitg[[simdgroup_index_in_threadgroup]]) { | |||
| mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); | |||
| } | |||
| kernel void kernel_mul_mv_q5_1_f32( | |||
| device const void * src0, | |||
| device const float * src1, | |||
| device float * dst, | |||
| constant int64_t & ne00, | |||
| constant int64_t & ne01[[buffer(4)]], | |||
| constant int64_t & ne02[[buffer(5)]], | |||
| constant int64_t & ne10[[buffer(9)]], | |||
| constant int64_t & ne12[[buffer(11)]], | |||
| constant int64_t & ne0[[buffer(15)]], | |||
| constant int64_t & ne1[[buffer(16)]], | |||
| constant uint & gqa[[buffer(17)]], | |||
| uint3 tgpig[[threadgroup_position_in_grid]], | |||
| uint tiisg[[thread_index_in_simdgroup]], | |||
| uint sgitg[[simdgroup_index_in_threadgroup]]) { | |||
| mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); | |||
| } | |||
| #define NB_Q8_0 8 | |||
| kernel void kernel_mul_mv_q8_0_f32( | |||
| @@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg | |||
| } | |||
| } | |||
| template <typename type4x4> | |||
| void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { | |||
| device const uint16_t * qs = ((device const uint16_t *)xb + 3); | |||
| const float d = xb->d; | |||
| const float md = -16.h * xb->d; | |||
| const ushort mask = il ? 0x00F0 : 0x000F; | |||
| const uint32_t qh = *((device const uint32_t *)xb->qh); | |||
| const int x_mv = il ? 4 : 0; | |||
| const int gh_mv = il ? 12 : 0; | |||
| const int gh_bk = il ? 0 : 4; | |||
| for (int i = 0; i < 8; i++) { | |||
| // extract the 5-th bits for x0 and x1 | |||
| const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; | |||
| const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | |||
| // combine the 4-bits from qs with the 5th bit | |||
| const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); | |||
| const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | |||
| reg[i/2][2*(i%2)+0] = d * x0 + md; | |||
| reg[i/2][2*(i%2)+1] = d * x1 + md; | |||
| } | |||
| } | |||
| template <typename type4x4> | |||
| void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { | |||
| device const uint16_t * qs = ((device const uint16_t *)xb + 4); | |||
| const float d = xb->d; | |||
| const float m = xb->m; | |||
| const ushort mask = il ? 0x00F0 : 0x000F; | |||
| const uint32_t qh = *((device const uint32_t *)xb->qh); | |||
| const int x_mv = il ? 4 : 0; | |||
| const int gh_mv = il ? 12 : 0; | |||
| const int gh_bk = il ? 0 : 4; | |||
| for (int i = 0; i < 8; i++) { | |||
| // extract the 5-th bits for x0 and x1 | |||
| const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; | |||
| const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | |||
| // combine the 4-bits from qs with the 5th bit | |||
| const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); | |||
| const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | |||
| reg[i/2][2*(i%2)+0] = d * x0 + m; | |||
| reg[i/2][2*(i%2)+1] = d * x1 + m; | |||
| } | |||
| } | |||
| template <typename type4x4> | |||
| void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { | |||
| device const int8_t * qs = ((device const int8_t *)xb->qs); | |||
| @@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows | |||
| template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>; | |||
| template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>; | |||
| template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; | |||
| template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>; | |||
| template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>; | |||
| template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>; | |||
| template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>; | |||
| template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>; | |||
| @@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f | |||
| template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>; | |||
| template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>; | |||
| template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>; | |||
| template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>; | |||
| template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>; | |||
| template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>; | |||
| template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>; | |||
| template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>; | |||