
Merge pull request #223 from martindevans/batch_decoding

New Binaries, Improved Sampling API, Batch Decoding Prototype
tags/v0.7.0
Martin Evans (GitHub) committed 2 years ago
Parent commit: 5a9e13c689
42 changed files with 814 additions and 222 deletions
  1. +1 -0 LLama.Examples/LLama.Examples.csproj
  2. +177 -0 LLama.Examples/NewVersion/BatchedDecoding.cs
  3. +1 -4 LLama.Examples/NewVersion/SemanticKernelChat.cs
  4. +1 -4 LLama.Examples/NewVersion/SemanticKernelPrompt.cs
  5. +1 -4 LLama.Examples/NewVersion/TalkToYourself.cs
  6. +5 -0 LLama.Examples/NewVersion/TestRunner.cs
  7. +8 -0 LLama.Unittest/StatelessExecutorTest.cs
  8. +1 -1 LLama.Web/Common/InferenceOptions.cs
  9. +2 -2 LLama.Web/Common/ModelOptions.cs
  10. +4 -4 LLama/Abstractions/IContextParams.cs
  11. +1 -1 LLama/Abstractions/IInferenceParams.cs
  12. +1 -1 LLama/Common/InferenceParams.cs
  13. +3 -3 LLama/Common/ModelParams.cs
  14. +2 -2 LLama/Extensions/IContextParamsExtensions.cs
  15. +16 -21 LLama/LLamaContext.cs
  16. +1 -1 LLama/LLamaInstructExecutor.cs
  17. +3 -3 LLama/LLamaInteractExecutor.cs
  18. +30 -33 LLama/LLamaStatelessExecutor.cs
  19. +15 -0 LLama/LLamaWeights.cs
  20. +32 -3 LLama/Native/LLamaBatchSafeHandle.cs
  21. +3 -3 LLama/Native/LLamaBeamView.cs
  22. +5 -5 LLama/Native/LLamaBeamsState.cs
  23. +3 -3 LLama/Native/LLamaGrammarElement.cs
  24. +2 -0 LLama/Native/LLamaModelQuantizeParams.cs
  25. +19 -7 LLama/Native/LLamaNativeBatch.cs
  26. +15 -3 LLama/Native/LLamaPos.cs
  27. +15 -4 LLama/Native/LLamaSeqId.cs
  28. +17 -7 LLama/Native/LLamaTokenData.cs
  29. +193 -0 LLama/Native/LLamaTokenDataArray.cs
  30. +9 -13 LLama/Native/NativeApi.Sampling.cs
  31. +8 -5 LLama/Native/NativeApi.cs
  32. +16 -0 LLama/Native/SafeLLamaContextHandle.cs
  33. +10 -0 LLama/Native/SafeLLamaGrammarHandle.cs
  34. +23 -83 LLama/Native/SamplingApi.cs
  35. +171 -2 LLama/runtimes/ggml-metal.metal
  36. BIN LLama/runtimes/libllama-cuda11.dll
  37. BIN LLama/runtimes/libllama-cuda11.so
  38. BIN LLama/runtimes/libllama-cuda12.dll
  39. BIN LLama/runtimes/libllama-cuda12.so
  40. BIN LLama/runtimes/libllama.dll
  41. BIN LLama/runtimes/libllama.dylib
  42. BIN LLama/runtimes/libllama.so

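The headline API change here is that sampling operations now live as instance methods on LLamaTokenDataArray (see the LLamaTokenDataArray.cs and SamplingApi.cs diffs below). A minimal sketch of the new flow, mirroring BatchedDecoding.cs; `context` and `lastTokenIndex` are placeholders for an existing LLamaContext and the batch index to sample from:

// Sketch only: chains the new LLamaTokenDataArray instance methods added in this PR.
// `lastTokenIndex` is a placeholder for the batch index whose logits you want to sample.
var logits = context.NativeHandle.GetLogitsIth(lastTokenIndex); // new helper on SafeLLamaContextHandle
var candidates = LLamaTokenDataArray.Create(logits);

candidates.TopK(context.NativeHandle, 40);
candidates.TopP(context.NativeHandle, 0.95f);
candidates.Temperature(context.NativeHandle, 0.8f);
var token = candidates.SampleToken(context.NativeHandle); // next token id for this stream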
+1 -0 LLama.Examples/LLama.Examples.csproj

@@ -8,6 +8,7 @@
<Platforms>AnyCPU;x64</Platforms>
<!-- Set IncludeBuiltInRuntimes to false to include your own runtime libraries and not link the defaults -->
<IncludeBuiltInRuntimes>true</IncludeBuiltInRuntimes>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">


+177 -0 LLama.Examples/NewVersion/BatchedDecoding.cs

@@ -0,0 +1,177 @@
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.NewVersion;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this currently uses the low-level API directly; future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
private const int n_parallel = 8;
private const int n_len = 32;

private const int top_k = 80;
private const float top_p = 0.8f;
private const float temp = 0.5f;

public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

Console.WriteLine("Prompt (leave blank to select automatically):");
var prompt = Console.ReadLine();
if (string.IsNullOrWhiteSpace(prompt))
prompt = "Not many people know that";

// Load model
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

// Tokenize prompt
var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

// Create a context
parameters.ContextSize = (uint)model.ContextSize;
parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
using var context = model.CreateContext(parameters);

var n_ctx = context.ContextSize;

// make sure the KV cache is big enough to hold all the prompt and generated tokens
if (n_kv_req > n_ctx)
{
await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
return;
}

using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);

// llama_decode will output logits only for the last token of the prompt
unsafe
{
batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
}

if (context.NativeHandle.Decode(batch) != 0)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
}

// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i)
{
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
}

if (n_parallel > 1)
{
Console.WriteLine();
Console.WriteLine($"generating {n_parallel} sequences...");
}

// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
List<int> i_batch = new();
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.NativeBatch.n_tokens - 1);

var n_cur = batch.NativeBatch.n_tokens;
var n_decode = 0;

var streams = new List<int>[n_parallel];
for (var i = 0; i < n_parallel; i++)
streams[i] = new();

var eos = model.EndOfSentenceToken;
var nl = model.NewlineToken;

var timer = new Stopwatch();
timer.Start();
while (n_cur <= n_len)
{
batch.LLamaBatchClear();

for (var i = 0; i < n_parallel; i++)
{
// Skip completed streams
if (i_batch[i] < 0)
continue;

var n_vocab = model.VocabCount;
LLamaTokenDataArray candidates;
unsafe
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}

candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
candidates.Temperature(context.NativeHandle, temp);
var new_token_id = candidates.SampleToken(context.NativeHandle);

if (new_token_id == eos || new_token_id == nl)
{
i_batch[i] = -1;
Console.WriteLine($"Completed Stream {i} early");
continue;
}

streams[i].Add(new_token_id);

i_batch[i] = batch.NativeBatch.n_tokens;

// push this new token for next evaluation
batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);

n_decode++;
}

// all streams are finished
if (batch.NativeBatch.n_tokens == 0)
{
break;
}

n_cur++;

// evaluate the current batch with the transformer model
if (context.NativeHandle.Decode(batch) != 0)
{
await Console.Error.WriteLineAsync("failed to eval");
return;
}
}

timer.Stop();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine();
Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

var index = 0;
foreach (var stream in streams)
{
var text = context.DeTokenize(stream);

Console.ForegroundColor = ConsoleColor.Green;
Console.Write($"{index++}. {prompt}");
Console.ForegroundColor = ConsoleColor.Red;
Console.WriteLine(text);
}
}
}

+1 -4 LLama.Examples/NewVersion/SemanticKernelChat.cs

@@ -14,10 +14,7 @@ namespace LLama.Examples.NewVersion
var modelPath = Console.ReadLine();

// Load weights into memory
var parameters = new ModelParams(modelPath)
{
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
};
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);


+1 -4 LLama.Examples/NewVersion/SemanticKernelPrompt.cs

@@ -16,10 +16,7 @@ namespace LLama.Examples.NewVersion
var modelPath = Console.ReadLine();

// Load weights into memory
var parameters = new ModelParams(modelPath)
{
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
};
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);



+1 -4 LLama.Examples/NewVersion/TalkToYourself.cs

@@ -13,10 +13,7 @@ namespace LLama.Examples.NewVersion
var modelPath = Console.ReadLine();

// Load weights into memory
var @params = new ModelParams(modelPath)
{
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
};
var @params = new ModelParams(modelPath);
using var weights = LLamaWeights.LoadFromFile(@params);

// Create 2 contexts sharing the same weights


+5 -0 LLama.Examples/NewVersion/TestRunner.cs

@@ -22,6 +22,7 @@
Console.WriteLine("12: Semantic Kernel Chat.");
Console.WriteLine("13: Semantic Kernel Memory.");
Console.WriteLine("14: Coding Assistant.");
Console.WriteLine("15: Batch Decoding.");

while (true)
{
@@ -88,6 +89,10 @@
{
await CodingAssistant.Run();
}
else if (choice == 15)
{
await BatchedDecoding.Run();
}
else
{
Console.WriteLine("Cannot parse your choice. Please select again.");


+8 -0 LLama.Unittest/StatelessExecutorTest.cs

@@ -1,3 +1,4 @@
using System.Diagnostics;
using LLama.Common;
using Xunit.Abstractions;

@@ -34,10 +35,17 @@ namespace LLama.Unittest
const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

var timer = new Stopwatch();
timer.Start();

var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

timer.Stop();
_testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");

_testOutputHelper.WriteLine(result1);
_testOutputHelper.WriteLine(result2);

// Check that it produced the exact same result both times
Assert.Equal(result1, result2);


+1 -1 LLama.Web/Common/InferenceOptions.cs

@@ -23,7 +23,7 @@ namespace LLama.Web.Common
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
/// <summary>
/// path to file for saving/loading model eval state
/// </summary>


+2 -2 LLama.Web/Common/ModelOptions.cs

@@ -111,12 +111,12 @@ namespace LLama.Web.Common
/// <summary>
/// RoPE base frequency
/// </summary>
public float RopeFrequencyBase { get; set; } = 10000.0f;
public float? RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// </summary>
public float RopeFrequencyScale { get; set; } = 1.0f;
public float? RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels


+4 -4 LLama/Abstractions/IContextParams.cs

@@ -39,14 +39,14 @@ public interface IContextParams
bool EmbeddingMode { get; set; }

/// <summary>
/// RoPE base frequency
/// RoPE base frequency (null to fetch from the model)
/// </summary>
float RopeFrequencyBase { get; set; }
float? RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// RoPE frequency scaling factor (null to fetch from the model)
/// </summary>
float RopeFrequencyScale { get; set; }
float? RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels


+1 -1 LLama/Abstractions/IInferenceParams.cs

@@ -29,7 +29,7 @@ namespace LLama.Abstractions
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IEnumerable<string> AntiPrompts { get; set; }
public IReadOnlyList<string> AntiPrompts { get; set; }

/// <summary>
/// 0 or lower to use vocab size


+1 -1 LLama/Common/InferenceParams.cs

@@ -28,7 +28,7 @@ namespace LLama.Common
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();

/// <summary>
/// 0 or lower to use vocab size


+3 -3 LLama/Common/ModelParams.cs

@@ -91,12 +91,12 @@ namespace LLama.Common
/// <summary>
/// RoPE base frequency
/// </summary>
public float RopeFrequencyBase { get; set; } = 10000.0f;
public float? RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// </summary>
public float RopeFrequencyScale { get; set; } = 1.0f;
public float? RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
@@ -156,7 +156,7 @@ namespace LLama.Common
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512,
bool embeddingMode = false,
float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
float? ropeFrequencyBase = null, float? ropeFrequencyScale = null, bool mulMatQ = false,
string encoding = "UTF-8")
{
ContextSize = contextSize;


+2 -2 LLama/Extensions/IContextParamsExtensions.cs

@@ -27,8 +27,8 @@ namespace LLama.Extensions
result.f16_kv = @params.UseFp16Memory;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
result.rope_freq_base = @params.RopeFrequencyBase;
result.rope_freq_scale = @params.RopeFrequencyScale;
result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;
result.mul_mat_q = @params.MulMatQ;

result.n_threads = Threads(@params.Threads);
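RopeFrequencyBase and RopeFrequencyScale are now nullable across the abstractions; as the mapping above shows, null becomes 0 in the native struct, which the interface doc comments describe as "fetch from the model". A small illustrative sketch; the model path is a placeholder:

// Sketch: leave the new nullable RoPE properties unset to use the values stored in the model.
// Null is mapped to 0 by the extension method above.
var parameters = new ModelParams("path/to/model.gguf") // placeholder path
{
    RopeFrequencyBase = null,  // fetch from the model
    RopeFrequencyScale = null, // fetch from the model
};
using var weights = LLamaWeights.LoadFromFile(parameters);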


+16 -21 LLama/LLamaContext.cs

@@ -235,13 +235,13 @@ namespace LLama

if (grammar != null)
{
SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
candidates.ApplyGrammar(NativeHandle, grammar);
}

if (temperature <= 0)
{
// Greedy sampling
id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
id = candidates.SampleTokenGreedy(NativeHandle);
}
else
{
@@ -250,32 +250,28 @@ namespace LLama
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
id = SamplingApi.llama_sample_token(NativeHandle, candidates);
candidates.TopK(NativeHandle, topK);
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
}
}
mirostat_mu = mu;
}

if (grammar != null)
{
NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
}
grammar?.AcceptToken(NativeHandle, id);

return id;
}
@@ -305,7 +301,7 @@ namespace LLama
}

// Save the newline logit value
var nl_token = NativeApi.llama_token_nl(NativeHandle);
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];

// Convert logits into token candidates
@@ -316,8 +312,7 @@ namespace LLama
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();

// Apply penalties to candidates
SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);

// Restore newline token logit value if necessary
if (!penalizeNL)
@@ -369,7 +364,7 @@ namespace LLama
try
{
tokens.CopyTo(rented, 0);
return Eval(rented, pastTokensCount);
return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
}
finally
{


+1 -1 LLama/LLamaInstructExecutor.cs

@@ -163,7 +163,7 @@ namespace LLama
}
}

if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
args.WaitForInput = true;
}


+3 -3 LLama/LLamaInteractExecutor.cs

@@ -30,7 +30,7 @@ namespace LLama
public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
_llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle);
_llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
}

/// <inheritdoc />
@@ -141,7 +141,7 @@ namespace LLama
return (true, Array.Empty<string>());
}

if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
return (true, new[] { " [end of text]\n" });
}
@@ -202,7 +202,7 @@ namespace LLama

_last_n_tokens.Enqueue(id);

if (id == NativeApi.llama_token_eos(Context.NativeHandle))
if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
id = _llama_token_newline;
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)


+30 -33 LLama/LLamaStatelessExecutor.cs

@@ -6,7 +6,6 @@ using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;

@@ -47,68 +46,66 @@ namespace LLama
}

/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var context = _weights.CreateContext(_params, _logger);
Context = context;

// Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();
Context = _weights.CreateContext(Context.Params, _logger);

var decoder = new StreamingTokenDecoder(Context);
var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty<string>());

if (inferenceParams != null)
{
if (inferenceParams.TokensKeep > Context.ContextSize)
throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
}

cancellationToken.ThrowIfCancellationRequested();
// Create an inference context which will be disposed when this method exits
using var context = _weights.CreateContext(_params, _logger);
Context = context;

// Sanity check inference params
inferenceParams ??= new InferenceParams();
if (inferenceParams.TokensKeep > Context.ContextSize)
throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");

// Create decoders for the token stream
var decoder = new StreamingTokenDecoder(Context);
var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);

var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
// Keep track of the last N tokens emitted
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<llama_token>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);

var tokens = Context.Tokenize(text).ToList();
// Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;

// Evaluate the prompt
await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);

lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;

// Begin loop, evaluating one token at a time
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens; i++)
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
if (cancellationToken.IsCancellationRequested)
break;

var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;

// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);

lastTokens.Add(id);

// Decode this token into text
decoder.Add(id);
var decoded = decoder.Read();
yield return decoded;

tokens.Clear();
tokens.Add(id);

// Check if any of the antiprompts have been generated
if (antiprocessor.Add(decoded))
break;

lastTokens.Add(id);
tokens.Clear();
tokens.Add(id);

// when run out of context
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
if (n_past + tokens.Count >= Context.ContextSize)


+15 -0 LLama/LLamaWeights.cs

@@ -38,6 +38,21 @@ namespace LLama
/// </summary>
public ulong ParameterCount => NativeHandle.ParameterCount;

/// <summary>
/// Get the newline token for this model
/// </summary>
public int NewlineToken => NativeApi.llama_token_nl(NativeHandle);

/// <summary>
/// Get the "end of sentence" token for this model
/// </summary>
public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle);

/// <summary>
/// Get the "beginning of sentence" token for this model
/// </summary>
public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle);

/// <summary>
/// Dimension of embedding vectors
/// </summary>


+32 -3 LLama/Native/LLamaBatchSafeHandle.cs

@@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle
/// <summary>
/// Get the native llama_batch struct
/// </summary>
public LLamaNativeBatch NativeBatch { get; private set; }
public LLamaNativeBatch NativeBatch;

/// <summary>
/// the token ids of the input (used when embd is NULL)
@@ -113,10 +113,11 @@ public sealed class LLamaBatchSafeHandle
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="embd"></param>
/// <param name="n_seq_max"></param>
/// <returns></returns>
public static LLamaBatchSafeHandle Create(int n_tokens, int embd)
public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
{
var batch = NativeApi.llama_batch_init(n_tokens, embd);
var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
return new LLamaBatchSafeHandle(batch, embd);
}

@@ -128,4 +129,32 @@ public sealed class LLamaBatchSafeHandle
SetHandle(IntPtr.Zero);
return true;
}

/// <summary>
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
/// </summary>
public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
unsafe
{
NativeBatch.token[NativeBatch.n_tokens] = token;
NativeBatch.pos[NativeBatch.n_tokens] = pos;
NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;

for (var i = 0; i < sequences.Length; i++)
NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];

NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);

NativeBatch.n_tokens++;
}
}

/// <summary>
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
/// </summary>
public void LLamaBatchClear()
{
NativeBatch.n_tokens = 0;
}
}
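The new LLamaBatchAdd / LLamaBatchClear helpers follow the common.cpp helpers linked in their doc comments. A condensed sketch of filling and decoding a single-sequence batch with them; `tokens` and `context` are placeholders, and BatchedDecoding.cs above shows the complete flow:

// Sketch: build a single-sequence batch from prompt tokens and decode it.
using var batch = LLamaBatchSafeHandle.Create(tokens.Length, 0, 1); // n_seq_max = 1

batch.LLamaBatchClear();
for (var i = 0; i < tokens.Length; i++)
{
    // only the final token needs logits, since that is the one sampled from
    var wantLogits = i == tokens.Length - 1;
    batch.LLamaBatchAdd(tokens[i], i, new[] { (LLamaSeqId)0 }, wantLogits);
}

if (context.NativeHandle.Decode(batch) != 0)
    Console.Error.WriteLine("llama_decode failed"); // placeholder error handling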

+3 -3 LLama/Native/LLamaBeamView.cs

@@ -11,13 +11,13 @@ using llama_token = Int32;
[StructLayout(LayoutKind.Sequential)]
public struct LLamaBeamView
{
private readonly unsafe llama_token* tokens;
private readonly nint n_tokens;
private unsafe llama_token* tokens;
private nint n_tokens;

/// <summary>
/// Cumulative beam probability (renormalized relative to all beams)
/// </summary>
public readonly float CumulativeProbability;
public float CumulativeProbability;

/// <summary>
/// Callback should set this to true when a beam is at end-of-beam.


+5 -5 LLama/Native/LLamaBeamsState.cs

@@ -9,27 +9,27 @@ namespace LLama.Native;
/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly struct LLamaBeamsState
public struct LLamaBeamsState
{
/// <summary>
/// The state of each individual beam
/// </summary>
private readonly unsafe LLamaBeamView* beam_views;
private unsafe LLamaBeamView* beam_views;

/// <summary>
/// Number of elements in beam_views
/// </summary>
private readonly nint n_beams;
private nint n_beams;

/// <summary>
/// Current max length of prefix tokens shared by all beams.
/// </summary>
public readonly ulong CommonPrefixLength;
public ulong CommonPrefixLength;

/// <summary>
/// True iff this is the last callback invocation.
/// </summary>
public readonly bool LastCall;
public bool LastCall;

/// <summary>
/// The current state of each beam


+3 -3 LLama/Native/LLamaGrammarElement.cs

@@ -52,18 +52,18 @@ namespace LLama.Native
/// </summary>
[StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("{Type} {Value}")]
public readonly struct LLamaGrammarElement
public struct LLamaGrammarElement
: IEquatable<LLamaGrammarElement>
{
/// <summary>
/// The type of this element
/// </summary>
public readonly LLamaGrammarElementType Type;
public LLamaGrammarElementType Type;

/// <summary>
/// Unicode code point or rule ID
/// </summary>
public readonly uint Value;
public uint Value;

/// <summary>
/// Construct a new LLamaGrammarElement


+2 -0 LLama/Native/LLamaModelQuantizeParams.cs

@@ -1,10 +1,12 @@
using System;
using System.Runtime.InteropServices;

namespace LLama.Native
{
/// <summary>
/// Quantizer parameters used in the native API
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaModelQuantizeParams
{
/// <summary>


+19 -7 LLama/Native/LLamaNativeBatch.cs

@@ -11,35 +11,47 @@ using llama_token = Int32;
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly unsafe struct LLamaNativeBatch
public unsafe struct LLamaNativeBatch
{
/// <summary>
/// The number of items pointed at by pos, seq_id and logits.
/// </summary>
public readonly int n_tokens;
public int n_tokens;

/// <summary>
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
/// </summary>
public readonly llama_token* token;
public llama_token* token;

/// <summary>
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
/// </summary>
public readonly float* embd;
public float* embd;

/// <summary>
/// the positions of the respective token in the sequence
/// </summary>
public readonly LLamaPos* pos;
public LLamaPos* pos;

/// <summary>
/// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ???
/// </summary>
public int* n_seq_id;

/// <summary>
/// the sequence to which the respective token belongs
/// </summary>
public readonly LLamaSeqId* seq_id;
public LLamaSeqId** seq_id;

/// <summary>
/// if zero, the logits for the respective token will not be output
/// </summary>
public readonly byte* logits;
public byte* logits;

// Note from llama.cpp:
// > helpers for smooth API transition - can be deprecated in the future
// > for future-proof code, use the above fields instead and ignore everything below
private LLamaPos _all_pos_0;
private LLamaPos _all_pos_1;
private LLamaSeqId _all_seq_id;
}

+15 -3 LLama/Native/LLamaPos.cs

@@ -1,14 +1,26 @@
namespace LLama.Native;
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// Indicates position in a sequence
/// </summary>
public readonly record struct LLamaPos(int Value)
[StructLayout(LayoutKind.Sequential)]
public struct LLamaPos
{
/// <summary>
/// The raw value
/// </summary>
public readonly int Value = Value;
public int Value;

/// <summary>
/// Create a new LLamaPos
/// </summary>
/// <param name="value"></param>
public LLamaPos(int value)
{
Value = value;
}

/// <summary>
/// Convert a LLamaPos into an integer (extract the raw value)


+15 -4 LLama/Native/LLamaSeqId.cs

@@ -1,15 +1,26 @@
namespace LLama.Native;
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// ID for a sequence in a batch
/// </summary>
/// <param name="Value"></param>
public record struct LLamaSeqId(int Value)
[StructLayout(LayoutKind.Sequential)]
public struct LLamaSeqId
{
/// <summary>
/// The raw value
/// </summary>
public int Value = Value;
public int Value;

/// <summary>
/// Create a new LLamaSeqId
/// </summary>
/// <param name="value"></param>
public LLamaSeqId(int value)
{
Value = value;
}

/// <summary>
/// Convert a LLamaSeqId into an integer (extract the raw value)


+17 -7 LLama/Native/LLamaTokenData.cs

@@ -5,24 +5,34 @@ namespace LLama.Native;
/// <summary>
/// A single token along with probability of this token being selected
/// </summary>
/// <param name="id"></param>
/// <param name="logit"></param>
/// <param name="p"></param>
[StructLayout(LayoutKind.Sequential)]
public record struct LLamaTokenData(int id, float logit, float p)
public struct LLamaTokenData
{
/// <summary>
/// token id
/// </summary>
public int id = id;
public int id;

/// <summary>
/// log-odds of the token
/// </summary>
public float logit = logit;
public float logit;

/// <summary>
/// probability of the token
/// </summary>
public float p = p;
public float p;

/// <summary>
/// Create a new LLamaTokenData
/// </summary>
/// <param name="id"></param>
/// <param name="logit"></param>
/// <param name="p"></param>
public LLamaTokenData(int id, float logit, float p)
{
this.id = id;
this.logit = logit;
this.p = p;
}
}

+193 -0 LLama/Native/LLamaTokenDataArray.cs

@@ -45,6 +45,199 @@ namespace LLama.Native

return new LLamaTokenDataArray(candidates);
}

#region sampling
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx"></param>
/// <param name="grammar"></param>
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
sorted = st.sorted;
}
}

/// <summary>
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
/// <param name="context"></param>
/// <param name="k">Number of tokens to keep</param>
/// <param name="minKeep">Minimum number to keep</param>
public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
sorted = st.sorted;
}
}

/// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="minKeep"></param>
public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
sorted = st.sorted;
}
}

/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
/// <param name="context"></param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
sorted = st.sorted;
}
}

/// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_typical(context, ref st, p, min_keep);
sorted = st.sorted;
}
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="context"></param>
/// <param name="last_tokens"></param>
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
using (var last_tokens_handle = last_tokens.Pin())
{
NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
}
}
}

/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
/// </summary>
/// <param name="context"></param>
/// <param name="temp"></param>
public void Temperature(SafeLLamaContextHandle context, float temp)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_temperature(context, ref st, temp);
sorted = st.sorted;
}
}

/// <summary>
/// Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
/// </summary>
/// <param name="context"></param>
public void Softmax(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_softmax(context, ref st);
sorted = st.sorted;
}
}

/// <summary>
/// Randomly selects a token from the candidates based on their probabilities.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleToken(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token(context, ref st);
sorted = st.sorted;
return token;
}
}

/// <summary>
/// Selects the token with the highest probability.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public int SampleTokenGreedy(SafeLLamaContextHandle context)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_greedy(context, ref st);
sorted = st.sorted;
return token;
}
}

/// <summary>
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// </summary>
/// <param name="context"></param>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
sorted = st.sorted;
return token;
}
}

/// <summary>
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// </summary>
/// <param name="context"></param>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
sorted = st.sorted;
return token;
}
}
#endregion
}

/// <summary>


+9 -13 LLama/Native/NativeApi.Sampling.cs

@@ -9,26 +9,22 @@ namespace LLama.Native
{
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
/// <param name="penalty_repeat">Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.</param>
/// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
/// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
ref LLamaTokenDataArrayNative candidates,
llama_token* last_tokens, ulong last_tokens_size,
float penalty_repeat,
float penalty_freq,
float penalty_present);

/// <summary>
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806


+8 -5 LLama/Native/NativeApi.cs

@@ -348,9 +348,10 @@ namespace LLama.Native
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
/// </summary>
/// <param name="ctx"></param>
/// <param name="i"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);
public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);

/// <summary>
/// Get the embeddings for the input
@@ -366,21 +367,21 @@ namespace LLama.Native
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
public static extern llama_token llama_token_bos(SafeLlamaModelHandle model);

/// <summary>
/// Get the "End of sentence" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
public static extern llama_token llama_token_eos(SafeLlamaModelHandle model);

/// <summary>
/// Get the "new line" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);

/// <summary>
/// Print out timing information for this context
@@ -530,6 +531,7 @@ namespace LLama.Native

/// <summary>
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
/// The batch has to be freed with llama_batch_free()
/// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
/// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
@@ -538,8 +540,9 @@ namespace LLama.Native
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="embd"></param>
/// <param name="n_seq_max">Each token can be assigned up to n_seq_max sequence ids</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd);
public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd, int n_seq_max);

/// <summary>
/// Frees a batch of tokens allocated with llama_batch_init()


+16 -0 LLama/Native/SafeLLamaContextHandle.cs

@@ -114,6 +114,22 @@ namespace LLama.Native
}
}

/// <summary>
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
/// </summary>
/// <param name="i"></param>
/// <returns></returns>
public Span<float> GetLogitsIth(int i)
{
var model = ThrowIfDisposed();

unsafe
{
var logits = NativeApi.llama_get_logits_ith(this, i);
return new Span<float>(logits, model.VocabCount);
}
}

#region tokens
/// <summary>
/// Convert the given text into tokens


+10 -0 LLama/Native/SafeLLamaGrammarHandle.cs

@@ -102,5 +102,15 @@ namespace LLama.Native
return new(grammar_ptr);
}
#endregion

/// <summary>
/// Accepts the sampled token into the grammar
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
NativeApi.llama_grammar_accept_token(ctx, this, token);
}
}
}

+23 -83 LLama/Native/SamplingApi.cs

@@ -9,7 +9,7 @@ namespace LLama.Native
/// <summary>
/// Direct translation of the llama.cpp sampling API
/// </summary>
public unsafe class SamplingApi
public class SamplingApi
{
/// <summary>
/// Apply grammar rules to candidate tokens
@@ -17,70 +17,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="grammar"></param>
[Obsolete("use LLamaTokenDataArray ApplyGrammar method")]
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
{
llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
using var last_tokens_handle = last_tokens.Pin();

NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
}

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{
llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
}

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
using var last_tokens_handle = last_tokens.Pin();

NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
candidates.ApplyGrammar(ctx, grammar);
}

/// <summary>
@@ -88,10 +28,10 @@ namespace LLama.Native
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
[Obsolete("use LLamaTokenDataArray Softmax method")]
public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_softmax(ctx, ref st);
candidates.Softmax(ctx);
}

/// <summary>
@@ -101,10 +41,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="k"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TopK method")]
public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
candidates.TopK(ctx, k, min_keep);
}

/// <summary>
@@ -114,10 +54,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TopP method")]
public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
candidates.TopP(ctx, p, min_keep);
}

/// <summary>
@@ -127,10 +67,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray TailFree method")]
public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
candidates.TailFree(ctx, z, min_keep);
}

/// <summary>
@@ -140,10 +80,10 @@ namespace LLama.Native
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
[Obsolete("use LLamaTokenDataArray LocallyTypical method")]
public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
candidates.LocallyTypical(ctx, p, min_keep);
}

/// <summary>
@@ -153,10 +93,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="temp"></param>
[Obsolete("use LLamaTokenDataArray Temperature() method")]
public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
NativeApi.llama_sample_temperature(ctx, ref st, temp);
candidates.Temperature(ctx, temp);
}

/// <summary>
@@ -169,10 +109,10 @@ namespace LLama.Native
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);
}

/// <summary>
@@ -184,10 +124,10 @@ namespace LLama.Native
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
}

/// <summary>
@@ -196,10 +136,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token_greedy(ctx, ref st);
return candidates.SampleTokenGreedy(ctx);
}

/// <summary>
@@ -208,10 +148,10 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <returns></returns>
[Obsolete("use LLamaTokenDataArray SampleToken() method")]
public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
return NativeApi.llama_sample_token(ctx, ref st);
return candidates.SampleToken(ctx);
}
}
}
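Most of the static helpers in SamplingApi are now [Obsolete] wrappers over the corresponding LLamaTokenDataArray methods (the old repetition-penalty overloads are removed outright), so existing call sites keep compiling. A rough before/after, assuming `ctx` is a SafeLLamaContextHandle and `candidates` a LLamaTokenDataArray:

// Before (still compiles, but now marked [Obsolete]):
SamplingApi.llama_sample_top_k(ctx, candidates, 40, 1);
SamplingApi.llama_sample_temperature(ctx, candidates, 0.8f);
var id = SamplingApi.llama_sample_token(ctx, candidates);

// After (instance methods introduced in this PR):
candidates.TopK(ctx, 40);
candidates.Temperature(ctx, 0.8f);
var id2 = candidates.SampleToken(ctx);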

+171 -2 LLama/runtimes/ggml-metal.metal

@@ -18,6 +18,21 @@ typedef struct {
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;

#define QK5_0 32
typedef struct {
half d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;

#define QK5_1 32
typedef struct {
half d; // delta
half m; // min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;

#define QK8_0 32
typedef struct {
half d; // delta
@@ -110,9 +125,17 @@ kernel void kernel_mul_row(
}

kernel void kernel_scale(
device const float * src0,
device float * dst,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}

kernel void kernel_scale_4(
device const float4 * src0,
device float4 * dst,
constant float & scale,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
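kernel_scale_4 is simply the float4 counterpart of kernel_scale: each thread scales four elements at once, which presumably lets the host dispatch a quarter as many threads when the element count is a multiple of four. A purely illustrative CPU-side analogue in C# (names and types are not from the PR):

using System.Numerics;

static void Scale(float[] src, float[] dst, float scale)
{
    for (int i = 0; i < src.Length; i++)
        dst[i] = src[i] * scale;            // kernel_scale: one element per "thread"
}

static void Scale4(Vector4[] src, Vector4[] dst, float scale)
{
    for (int i = 0; i < src.Length; i++)
        dst[i] = src[i] * scale;            // kernel_scale_4: four elements per "thread"
}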
@@ -399,8 +422,11 @@ kernel void kernel_rms_norm(
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;

float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
return d * (acc[0] + acc[1]) + sumy * m;
}

// function for calculating the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;

float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
}
return d * (sumy * -16.f + acc[0] + acc[1]);
}

// function for calculating the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_1/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;

float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
}
return d * (acc[0] + acc[1]) + sumy * m;
}
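The two return expressions use the same rearrangement as the existing q4_0/q4_1 versions: since every q5_0 weight is d * (q - 16), the constant -16 can be factored out of the sum as -16 * d * sum(y), and for q5_1 the + m term becomes + m * sum(y). A scalar C# sketch of that identity follows; it deliberately ignores the yl pre-scaling trick described in the comments above, and none of the names come from the PR.

// Scalar illustration of the identity behind the q5_0 return expression.
static float DotQ5_0(float d, ReadOnlySpan<int> q, ReadOnlySpan<float> y)
{
    float acc = 0f, sumy = 0f;
    for (int i = 0; i < q.Length; i++)
    {
        acc  += q[i] * y[i];   // raw 5-bit quants (0..31) times activations
        sumy += y[i];
    }
    // sum_i d*(q_i - 16)*y_i  ==  d*(acc - 16*sumy)  ==  d*(sumy * -16 + acc)
    return d * (sumy * -16f + acc);
    // For q5_1 the weight is d*q + m, so the same step gives d*acc + m*sumy,
    // matching "d * (acc[0] + acc[1]) + sumy * m" above.
}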

// putting them in the kernel causes a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32(
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mv_q5_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]],
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mv_q5_1_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]],
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}


#define NB_Q8_0 8

kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
}
}

template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = il ? 0x00F0 : 0x000F;

const uint32_t qh = *((device const uint32_t *)xb->qh);

const int x_mv = il ? 4 : 0;

const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;

for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

reg[i/2][2*(i%2)+0] = d * x0 + md;
reg[i/2][2*(i%2)+1] = d * x1 + md;
}
}

template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = il ? 0x00F0 : 0x000F;

const uint32_t qh = *((device const uint32_t *)xb->qh);

const int x_mv = il ? 4 : 0;

const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;

for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

reg[i/2][2*(i%2)+0] = d * x0 + m;
reg[i/2][2*(i%2)+1] = d * x1 + m;
}
}
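For readers more comfortable in C#, the same q5_0 decode is sketched below in plain scalar form. The byte offsets are assumptions derived from the block_q5_0 struct near the top of this file (2-byte fp16 delta, 4 bytes of high bits, 16 bytes of nibbles); the helper is illustrative and is not part of the PR.

// Illustrative scalar q5_0 dequantizer (not from this PR).
static void DequantizeQ5_0Block(ReadOnlySpan<byte> block, Span<float> dst)
{
    float d  = (float)BitConverter.ToHalf(block.Slice(0, 2));  // fp16 delta
    uint  qh = BitConverter.ToUInt32(block.Slice(2, 4));       // 32 "5th" bits
    ReadOnlySpan<byte> qs = block.Slice(6, 16);                // 32 x 4-bit quants

    for (int j = 0; j < 16; j++)
    {
        int xh0 = (int)(((qh >> j) << 4) & 0x10);      // 5th bit of weight j
        int xh1 = (int)((qh >> (j + 12)) & 0x10);      // 5th bit of weight j + 16

        int x0 = ((qs[j] & 0x0F) | xh0) - 16;          // low nibble, re-centred
        int x1 = ((qs[j] >> 4)   | xh1) - 16;          // high nibble, re-centred

        dst[j]      = x0 * d;
        dst[j + 16] = x1 * d;
    }
    // q5_1 differs only in the final step: value = x * d + m instead of (x - 16) * d.
}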

template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;


BIN
LLama/runtimes/libllama-cuda11.dll View File


BIN
LLama/runtimes/libllama-cuda11.so View File


BIN
LLama/runtimes/libllama-cuda12.dll View File


BIN
LLama/runtimes/libllama-cuda12.so View File


BIN
LLama/runtimes/libllama.dll View File


BIN
LLama/runtimes/libllama.dylib View File


BIN
LLama/runtimes/libllama.so View File

