diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj
index a4ebb604..7fea4562 100644
--- a/LLama.Examples/LLama.Examples.csproj
+++ b/LLama.Examples/LLama.Examples.csproj
@@ -8,6 +8,7 @@
<Platforms>AnyCPU;x64</Platforms>
true
+ <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
diff --git a/LLama.Examples/NewVersion/BatchedBench.cs b/LLama.Examples/NewVersion/BatchedBench.cs
deleted file mode 100644
index 5b21ff72..00000000
--- a/LLama.Examples/NewVersion/BatchedBench.cs
+++ /dev/null
@@ -1,25 +0,0 @@
-using LLama.Common;
-using LLama.Native;
-
-namespace LLama.Examples.NewVersion;
-
-public class BatchedBench
-{
- public static async Task Run()
- {
- Console.Write("Please input your model path: ");
- //todo:var modelPath = Console.ReadLine();
- var modelPath = @"C:\Users\Martin\Documents\Python\oobabooga_windows\text-generation-webui\models\llama-2-7b-chat.Q5_K_M.gguf";
-
- var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
-
- parameters.ContextSize = (uint)model.ContextSize;
- using var context = model.CreateContext(parameters);
-
- var n_kv_max = 1024;
-
- using var batch = LLamaBatchSafeHandle.Create(n_kv_max, 0, 1);
-
- }
-}
\ No newline at end of file
diff --git a/LLama.Examples/NewVersion/BatchedDecoding.cs b/LLama.Examples/NewVersion/BatchedDecoding.cs
new file mode 100644
index 00000000..7d140e81
--- /dev/null
+++ b/LLama.Examples/NewVersion/BatchedDecoding.cs
@@ -0,0 +1,219 @@
+using System.Diagnostics;
+using System.Security.Cryptography;
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+
+namespace LLama.Examples.NewVersion;
+
+public class BatchedDecoding
+{
+ private const int n_parallel = 8;
+ private const int n_len = 32;
+
+ public static async Task Run()
+ {
+ Console.Write("Please input your model path: ");
+ var modelPath = Console.ReadLine();
+
+ Console.WriteLine("Prompt (leave blank to select automatically):");
+ var prompt = Console.ReadLine();
+ if (string.IsNullOrWhiteSpace(prompt))
+ prompt = "I would like to tell you about";
+
+ // Load model
+ var parameters = new ModelParams(modelPath);
+ using var model = LLamaWeights.LoadFromFile(parameters);
+
+ // Tokenize prompt
+ var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
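+ // The prompt is evaluated once and shared by every sequence, so the KV cache must hold the prompt
+ // plus up to (n_len - prompt length) generated tokens for each of the n_parallel sequences.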
+ var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
+
+ // Create a context
+ parameters.ContextSize = (uint)model.ContextSize;
+ parameters.Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MinValue, int.MaxValue));
+ parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
+ using var context = model.CreateContext(parameters);
+
+ var n_ctx = context.ContextSize;
+
+ // make sure the KV cache is big enough to hold all the prompt and generated tokens
+ if (n_kv_req > n_ctx)
+ {
+ await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
+ await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
+ return;
+ }
+
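+ // Allocate a native batch large enough for the whole prompt (or, later, one token per parallel sequence);
+ // the last two arguments presumably mirror llama_batch_init's embd (0 = token ids) and n_seq_max (1).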
+ using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
+
+ // evaluate the initial prompt
+ for (var i = 0; i < prompt_tokens.Length; i++)
+ llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false);
+ Debug.Assert(batch.NativeBatch.n_tokens == (int)prompt_tokens.Length);
+
+ // llama_decode will output logits only for the last token of the prompt
+ unsafe
+ {
+ batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
+ }
+
+ if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
+ {
+ await Console.Error.WriteLineAsync("llama_decode failed");
+ return;
+ }
+
+ // assign the system KV cache to all parallel sequences
+ // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
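+ // (llama_kv_cache_seq_cp copies the cached entries for positions [0, n_tokens) from sequence 0 to sequence i)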
+ for (var i = 1; i < n_parallel; ++i)
+ {
+ NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
+ }
+
+ if (n_parallel > 1)
+ {
+ Console.WriteLine();
+ Console.WriteLine($"generating {n_parallel} sequences...");
+ }
+
+ // remember the batch index of the last token for each parallel sequence
+ // we need this to determine which logits to sample from
+ List<int> i_batch = new();
+ for (var i = 0; i < n_parallel; i++)
+ i_batch.Add(batch.NativeBatch.n_tokens - 1);
+
+ int n_cur = batch.NativeBatch.n_tokens;
+ int n_decode = 0;
+
+ var streams = new List<int>[n_parallel];
+ for (var i = 0; i < n_parallel; i++)
+ streams[i] = new();
+
+ var eos = NativeApi.llama_token_eos(model.NativeHandle);
+ var nl = NativeApi.llama_token_nl(model.NativeHandle);
+
+ var timer = new Stopwatch();
+ timer.Start();
+ while (n_cur <= n_len)
+ {
+ llama_batch_clear(batch);
+
+ for (var i = 0; i < n_parallel; i++)
+ {
+ // Skip completed streams
+ if (i_batch[i] < 0)
+ continue;
+
+ unsafe
+ {
+ var n_vocab = model.VocabCount;
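+ // Fetch the logits row that llama_decode produced for this stream's most recent token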
+ var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]);
+
+ var candidates = new LLamaTokenData[n_vocab];
+ for (var token_id = 0; token_id < n_vocab; token_id++)
+ {
+ candidates[token_id] = new LLamaTokenData
+ {
+ id = token_id,
+ logit = logits[token_id]
+ };
+ }
+
+ var candidates_p = new LLamaTokenDataArray(candidates);
+ using var pin = LLamaTokenDataArrayNative.Create(candidates_p, out var candidates_native);
+
+ const int top_k = 40;
+ const float top_p = 0.9f;
+ const float temp = 0.4f;
+
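+ // Standard sampling chain: keep the top 40 candidates, then the smallest set whose cumulative
+ // probability reaches 0.9, rescale by temperature, and finally draw a token from what remains.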
+ NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
+ NativeApi.llama_sample_top_p(context.NativeHandle, ref candidates_native, top_p, 1);
+ NativeApi.llama_sample_temperature(context.NativeHandle, ref candidates_native, temp);
+
+ var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);
+
+ if (new_token_id == eos || new_token_id == nl)
+ {
+ i_batch[i] = -1;
+ Console.WriteLine($"Completed Stream {i} early");
+ continue;
+ }
+
+ streams[i].Add(new_token_id);
+
+ i_batch[i] = batch.NativeBatch.n_tokens;
+
+ // push this new token for next evaluation
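+ // (queued at position n_cur for this sequence only, with logits requested so it can be sampled from after the next llama_decode)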
+ llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);
+
+ n_decode++;
+ }
+ }
+
+ // all streams are finished
+ if (batch.NativeBatch.n_tokens == 0)
+ {
+ break;
+ }
+
+ n_cur++;
+
+ // evaluate the current batch with the transformer model
+ if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
+ {
+ await Console.Error.WriteLineAsync("failed to eval");
+ return;
+ }
+ }
+
+ timer.Stop();
+ Console.ForegroundColor = ConsoleColor.Yellow;
+ Console.WriteLine();
+ Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
+ Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
+
+ var index = 0;
+ foreach (var stream in streams)
+ {
+ var text = context.DeTokenize(stream);
+
+ Console.ForegroundColor = ConsoleColor.Green;
+ Console.Write($"{index++}. {prompt}");
+ Console.ForegroundColor = ConsoleColor.Red;
+ Console.WriteLine(text);
+ }
+ }
+
+ /// <summary>
+ /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ /// </summary>
+ private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
+ {
+ unsafe
+ {
+ ref var batch = ref batchHandle.NativeBatch;
+
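+ // Write this token into the next free slot of the native batch and advance n_tokens,
+ // mirroring llama_batch_add in llama.cpp's common.cpp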
+ batch.token[batch.n_tokens] = token;
+ batch.pos[batch.n_tokens] = pos;
+ batch.n_seq_id[batch.n_tokens] = sequences.Count;
+
+ for (var i = 0; i < sequences.Count; i++)
+ batch.seq_id[batch.n_tokens][i] = sequences[i];
+
+ batch.logits[batch.n_tokens] = Convert.ToByte(logits);
+
+ batch.n_tokens++;
+ }
+ }
+
+ /// <summary>
+ /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
+ /// </summary>
+ /// <param name="batchHandle"></param>
+ private static void llama_batch_clear(LLamaBatchSafeHandle batchHandle)
+ {
+ batchHandle.NativeBatch.n_tokens = 0;
+ }
+}
\ No newline at end of file
diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs
index 22f51e4b..231a67ca 100644
--- a/LLama.Examples/NewVersion/TestRunner.cs
+++ b/LLama.Examples/NewVersion/TestRunner.cs
@@ -91,7 +91,7 @@
}
else if (choice == 15)
{
- await BatchedBench.Run();
+ await BatchedDecoding.Run();
}
else
{
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 47240dc7..2962cb69 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -305,7 +305,7 @@ namespace LLama
}
// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl(NativeHandle);
+ var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 9e4292ea..80c6f542 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -163,7 +163,7 @@ namespace LLama
}
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
args.WaitForInput = true;
}
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index d3d4a9e3..5078648b 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -30,7 +30,7 @@ namespace LLama
public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
- _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle);
+ _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
}
///
@@ -141,7 +141,7 @@ namespace LLama
return (true, Array.Empty<string>());
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
return (true, new[] { " [end of text]\n" });
}
@@ -202,7 +202,7 @@ namespace LLama
_last_n_tokens.Enqueue(id);
- if (id == NativeApi.llama_token_eos(Context.NativeHandle))
+ if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
id = _llama_token_newline;
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs
index f21c70ea..6b8ec0b6 100644
--- a/LLama/Native/LLamaBatchSafeHandle.cs
+++ b/LLama/Native/LLamaBatchSafeHandle.cs
@@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle
/// <summary>
/// Get the native llama_batch struct
/// </summary>
- public LLamaNativeBatch NativeBatch { get; private set; }
+ public LLamaNativeBatch NativeBatch;
/// <summary>
/// the token ids of the input (used when embd is NULL)
diff --git a/LLama/Native/LLamaNativeBatch.cs b/LLama/Native/LLamaNativeBatch.cs
index 576f8b27..867b8c01 100644
--- a/LLama/Native/LLamaNativeBatch.cs
+++ b/LLama/Native/LLamaNativeBatch.cs
@@ -11,12 +11,12 @@ using llama_token = Int32;
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
/// </summary>
[StructLayout(LayoutKind.Sequential)]
-public readonly unsafe struct LLamaNativeBatch
+public unsafe struct LLamaNativeBatch
{
/// <summary>
/// The number of items pointed at by pos, seq_id and logits.
/// </summary>
- public readonly int n_tokens;
+ public int n_tokens;
/// <summary>
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
@@ -33,13 +33,25 @@ public readonly unsafe struct LLamaNativeBatch
/// </summary>
public readonly LLamaPos* pos;
+ /// <summary>
+ /// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ???
+ /// </summary>
+ public readonly int* n_seq_id;
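+ // (Presumably: n_seq_id[i] is the number of sequence ids attached to token i, i.e. the length of seq_id[i])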
+
/// <summary>
/// the sequence to which the respective token belongs
/// </summary>
- public readonly LLamaSeqId* seq_id;
+ public readonly LLamaSeqId** seq_id;
/// <summary>
/// if zero, the logits for the respective token will not be output
/// </summary>
public readonly byte* logits;
+
+ // Note from llama.cpp:
+ // > helpers for smooth API transition - can be deprecated in the future
+ // > for future-proof code, use the above fields instead and ignore everything below
+ private LLamaPos _all_pos_0;
+ private LLamaPos _all_pos_1;
+ private LLamaSeqId _all_seq_id;
}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index b239a775..3f193730 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -350,7 +350,7 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);
+ public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
/// <summary>
/// Get the embeddings for the input
@@ -366,21 +366,21 @@ namespace LLama.Native
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
+ public static extern llama_token llama_token_bos(SafeLlamaModelHandle model);
/// <summary>
/// Get the "End of sentence" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
+ public static extern llama_token llama_token_eos(SafeLlamaModelHandle model);
/// <summary>
/// Get the "new line" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
+ public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
/// <summary>
/// Print out timing information for this context