diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj
index a4ebb604..7fea4562 100644
--- a/LLama.Examples/LLama.Examples.csproj
+++ b/LLama.Examples/LLama.Examples.csproj
@@ -8,6 +8,7 @@
     AnyCPU;x64
     true
+    true
diff --git a/LLama.Examples/NewVersion/BatchedBench.cs b/LLama.Examples/NewVersion/BatchedBench.cs
deleted file mode 100644
index 5b21ff72..00000000
--- a/LLama.Examples/NewVersion/BatchedBench.cs
+++ /dev/null
@@ -1,25 +0,0 @@
-using LLama.Common;
-using LLama.Native;
-
-namespace LLama.Examples.NewVersion;
-
-public class BatchedBench
-{
-    public static async Task Run()
-    {
-        Console.Write("Please input your model path: ");
-        //todo:var modelPath = Console.ReadLine();
-        var modelPath = @"C:\Users\Martin\Documents\Python\oobabooga_windows\text-generation-webui\models\llama-2-7b-chat.Q5_K_M.gguf";
-
-        var parameters = new ModelParams(modelPath);
-        using var model = LLamaWeights.LoadFromFile(parameters);
-
-        parameters.ContextSize = (uint)model.ContextSize;
-        using var context = model.CreateContext(parameters);
-
-        var n_kv_max = 1024;
-
-        using var batch = LLamaBatchSafeHandle.Create(n_kv_max, 0, 1);
-
-    }
-}
\ No newline at end of file
diff --git a/LLama.Examples/NewVersion/BatchedDecoding.cs b/LLama.Examples/NewVersion/BatchedDecoding.cs
new file mode 100644
index 00000000..7d140e81
--- /dev/null
+++ b/LLama.Examples/NewVersion/BatchedDecoding.cs
@@ -0,0 +1,219 @@
+using System.Diagnostics;
+using System.Security.Cryptography;
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+
+namespace LLama.Examples.NewVersion;
+
+public class BatchedDecoding
+{
+    private const int n_parallel = 8;
+    private const int n_len = 32;
+
+    public static async Task Run()
+    {
+        Console.Write("Please input your model path: ");
+        //todo:var modelPath = Console.ReadLine();
+        var modelPath = @"C:\Users\Martin\Documents\Python\oobabooga_windows\text-generation-webui\models\llama-2-7b-chat.Q5_K_M.gguf";
+
+        Console.WriteLine("Prompt (leave blank to select automatically):");
+        var prompt = Console.ReadLine();
+        if (string.IsNullOrWhiteSpace(prompt))
+            prompt = "I would like to tell you about";
+
+        // Load model
+        var parameters = new ModelParams(modelPath);
+        using var model = LLamaWeights.LoadFromFile(parameters);
+
+        // Tokenize prompt
+        var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
+        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
+
+        // Create a context
+        parameters.ContextSize = (uint)model.ContextSize;
+        parameters.Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MinValue, int.MaxValue));
+        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
+        using var context = model.CreateContext(parameters);
+
+        var n_ctx = context.ContextSize;
+
+        // make sure the KV cache is big enough to hold all the prompt and generated tokens
+        if (n_kv_req > n_ctx)
+        {
+            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
+            await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
+            return;
+        }
+
+        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
+
+        // evaluate the initial prompt
+        for (var i = 0; i < prompt_tokens.Length; i++)
+            llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false);
+        Debug.Assert(batch.NativeBatch.n_tokens == (int)prompt_tokens.Length);
+
+        // llama_decode will output logits only for the last token of the prompt
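+        // (each entry in `logits` is a per-token flag; a non-zero value asks llama.cpp to
+        //  compute logits for that position, so only the final prompt token is marked here)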
+        unsafe
+        {
+            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
+        }
+
+        if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
+        {
+            await Console.Error.WriteLineAsync("llama_decode failed");
+            return;
+        }
+
+        // assign the system KV cache to all parallel sequences
+        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+        for (var i = 1; i < n_parallel; ++i)
+        {
+            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
+        }
+
+        if (n_parallel > 1)
+        {
+            Console.WriteLine();
+            Console.WriteLine($"generating {n_parallel} sequences...");
+        }
+
+        // remember the batch index of the last token for each parallel sequence
+        // we need this to determine which logits to sample from
+        List<int> i_batch = new();
+        for (var i = 0; i < n_parallel; i++)
+            i_batch.Add(batch.NativeBatch.n_tokens - 1);
+
+        int n_cur = batch.NativeBatch.n_tokens;
+        int n_decode = 0;
+
+        var streams = new List<int>[n_parallel];
+        for (var i = 0; i < n_parallel; i++)
+            streams[i] = new();
+
+        var eos = NativeApi.llama_token_eos(model.NativeHandle);
+        var nl = NativeApi.llama_token_nl(model.NativeHandle);
+
+        var timer = new Stopwatch();
+        timer.Start();
+        while (n_cur <= n_len)
+        {
+            llama_batch_clear(batch);
+
+            for (var i = 0; i < n_parallel; i++)
+            {
+                // Skip completed streams
+                if (i_batch[i] < 0)
+                    continue;
+
+                unsafe
+                {
+                    var n_vocab = model.VocabCount;
+                    var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]);
+
+                    var candidates = new LLamaTokenData[n_vocab];
+                    for (var token_id = 0; token_id < n_vocab; token_id++)
+                    {
+                        candidates[token_id] = new LLamaTokenData
+                        {
+                            id = token_id,
+                            logit = logits[token_id]
+                        };
+                    }
+
+                    var candidates_p = new LLamaTokenDataArray(candidates);
+                    using var pin = LLamaTokenDataArrayNative.Create(candidates_p, out var candidates_native);
+
+                    const int top_k = 40;
+                    const float top_p = 0.9f;
+                    const float temp = 0.4f;
+
+                    NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
+                    NativeApi.llama_sample_top_p(context.NativeHandle, ref candidates_native, top_p, 1);
+                    NativeApi.llama_sample_temperature(context.NativeHandle, ref candidates_native, temp);
+
+                    var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);
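+                    // (the candidate list above was truncated by top-k and top-p and rescaled by
+                    //  temperature, so this samples the next token for stream i from that distribution)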
{prompt}"); + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine(text); + } + } + + /// + /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 + /// + private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List sequences, bool logits) + { + unsafe + { + ref var batch = ref batchHandle.NativeBatch; + + batch.token[batch.n_tokens] = token; + batch.pos[batch.n_tokens] = pos; + batch.n_seq_id[batch.n_tokens] = sequences.Count; + + for (var i = 0; i < sequences.Count; i++) + batch.seq_id[batch.n_tokens][i] = sequences[i]; + + batch.logits[batch.n_tokens] = Convert.ToByte(logits); + + batch.n_tokens++; + } + } + + /// + /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825 + /// + /// + private static void llama_batch_clear(LLamaBatchSafeHandle batchHandle) + { + batchHandle.NativeBatch.n_tokens = 0; + } +} \ No newline at end of file diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index 22f51e4b..231a67ca 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -91,7 +91,7 @@ } else if (choice == 15) { - await BatchedBench.Run(); + await BatchedDecoding.Run(); } else { diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 47240dc7..2962cb69 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -305,7 +305,7 @@ namespace LLama } // Save the newline logit value - var nl_token = NativeApi.llama_token_nl(NativeHandle); + var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); var nl_logit = logits[nl_token]; // Convert logits into token candidates diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 9e4292ea..80c6f542 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -163,7 +163,7 @@ namespace LLama } } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) + if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) { args.WaitForInput = true; } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index d3d4a9e3..5078648b 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -30,7 +30,7 @@ namespace LLama public InteractiveExecutor(LLamaContext context, ILogger? 
+
+            batch.logits[batch.n_tokens] = Convert.ToByte(logits);
+
+            batch.n_tokens++;
+        }
+    }
+
+    /// <summary>
+    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
+    /// </summary>
+    /// <param name="batchHandle"></param>
+    private static void llama_batch_clear(LLamaBatchSafeHandle batchHandle)
+    {
+        batchHandle.NativeBatch.n_tokens = 0;
+    }
+}
\ No newline at end of file
diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs
index 22f51e4b..231a67ca 100644
--- a/LLama.Examples/NewVersion/TestRunner.cs
+++ b/LLama.Examples/NewVersion/TestRunner.cs
@@ -91,7 +91,7 @@
                 }
                 else if (choice == 15)
                 {
-                    await BatchedBench.Run();
+                    await BatchedDecoding.Run();
                 }
                 else
                 {
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 47240dc7..2962cb69 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -305,7 +305,7 @@ namespace LLama
             }
 
             // Save the newline logit value
-            var nl_token = NativeApi.llama_token_nl(NativeHandle);
+            var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
             var nl_logit = logits[nl_token];
 
             // Convert logits into token candidates
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 9e4292ea..80c6f542 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -163,7 +163,7 @@ namespace LLama
                 }
             }
 
-            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
             {
                 args.WaitForInput = true;
             }
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index d3d4a9e3..5078648b 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -30,7 +30,7 @@ namespace LLama
         public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
             : base(context, logger)
         {
-            _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle);
+            _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
         }
 
         /// <inheritdoc />
@@ -141,7 +141,7 @@ namespace LLama
                 return (true, Array.Empty<string>());
             }
 
-            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
             {
                 return (true, new[] { " [end of text]\n" });
             }
@@ -202,7 +202,7 @@ namespace LLama
 
             _last_n_tokens.Enqueue(id);
 
-            if (id == NativeApi.llama_token_eos(Context.NativeHandle))
+            if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
             {
                 id = _llama_token_newline;
                 if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs
index f21c70ea..6b8ec0b6 100644
--- a/LLama/Native/LLamaBatchSafeHandle.cs
+++ b/LLama/Native/LLamaBatchSafeHandle.cs
@@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle
     /// <summary>
     /// Get the native llama_batch struct
     /// </summary>
-    public LLamaNativeBatch NativeBatch { get; private set; }
+    public LLamaNativeBatch NativeBatch;
 
     /// <summary>
     /// the token ids of the input (used when embd is NULL)
diff --git a/LLama/Native/LLamaNativeBatch.cs b/LLama/Native/LLamaNativeBatch.cs
index 576f8b27..867b8c01 100644
--- a/LLama/Native/LLamaNativeBatch.cs
+++ b/LLama/Native/LLamaNativeBatch.cs
@@ -11,12 +11,12 @@ using llama_token = Int32;
 /// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
 /// </summary>
 [StructLayout(LayoutKind.Sequential)]
-public readonly unsafe struct LLamaNativeBatch
+public unsafe struct LLamaNativeBatch
 {
     /// <summary>
     /// The number of items pointed at by pos, seq_id and logits.
     /// </summary>
-    public readonly int n_tokens;
+    public int n_tokens;
 
     /// <summary>
     /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
@@ -33,13 +33,25 @@
     /// the positions of the respective token in the sequence
     /// </summary>
     public readonly LLamaPos* pos;
 
+    /// <summary>
+    /// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ???
+    /// </summary>
+    public readonly int* n_seq_id;
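+    // (mirrors llama_batch in llama.cpp: the number of sequence ids supplied for each token)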
+
     /// <summary>
     /// the sequence to which the respective token belongs
     /// </summary>
-    public readonly LLamaSeqId* seq_id;
+    public readonly LLamaSeqId** seq_id;
 
     /// <summary>
     /// if zero, the logits for the respective token will not be output
     /// </summary>
     public readonly byte* logits;
+
+    // Note from llama.cpp:
+    // > helpers for smooth API transition - can be deprecated in the future
+    // > for future-proof code, use the above fields instead and ignore everything below
+    private LLamaPos _all_pos_0;
+    private LLamaPos _all_pos_1;
+    private LLamaSeqId _all_seq_id;
 }
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index b239a775..3f193730 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -350,7 +350,7 @@ namespace LLama.Native
         /// </summary>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);
+        public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
 
         /// <summary>
         /// Get the embeddings for the input
@@ -366,21 +366,21 @@
         /// </summary>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
+        public static extern llama_token llama_token_bos(SafeLlamaModelHandle model);
 
         /// <summary>
         /// Get the "End of sentence" token
         /// </summary>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
+        public static extern llama_token llama_token_eos(SafeLlamaModelHandle model);
 
         /// <summary>
         /// Get the "new line" token
        /// </summary>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
+        public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
 
         /// <summary>
         /// Print out timing information for this context
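
Note on the native API changes above: the special-token lookups (llama_token_bos / llama_token_eos / llama_token_nl) now take the model handle rather than a context handle, and per-position logits come from llama_get_logits_ith. A minimal caller-side sketch (hypothetical `model`, `context` and `i` variables, assuming a loaded LLamaWeights and an LLamaContext created from it):

    // token id lookups now go through the model handle...
    var eos = NativeApi.llama_token_eos(model.NativeHandle);
    // ...or through an existing context's ModelHandle
    var nl = NativeApi.llama_token_nl(context.NativeHandle.ModelHandle);

    unsafe
    {
        // logits for the token at batch index i (the new second parameter)
        float* logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i);
    }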