using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.NewVersion;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
    // Number of sequences generated in parallel from the shared prompt.
    private const int n_parallel = 8;
    // Maximum length of every sequence in tokens, *including* the prompt tokens.
    private const int n_len = 32;

    private const int top_k = 40;
    private const float top_p = 0.9f;
    private const float temp = 0.4f;

    /// <summary>
    /// Asks for a model path and a prompt on the console, then generates <c>n_parallel</c> continuations of the
    /// prompt in one batch. The prompt is evaluated once and its KV cache entries are shared by every sequence.
    /// </summary>
    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(modelPath))
        {
            await Console.Error.WriteLineAsync("error: no model path was given");
            return;
        }

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
        {
            prompt = "Not many people know that";
        }

        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);

        // n_len bounds the total sequence length, so the prompt must leave room for at least one generated token.
        // This also guarantees the prompt fits in the BatchSize chosen below; llama_decode rejects larger batches.
        if (prompt_tokens.Length >= n_len)
        {
            await Console.Error.WriteLineAsync($"error: the prompt ({prompt_tokens.Length} tokens) does not fit within n_len ({n_len} tokens)");
            await Console.Error.WriteLineAsync("    either shorten the prompt or increase n_len");
            return;
        }

        // The prompt is cached once; each parallel sequence then adds (n_len - prompt length) tokens of its own.
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
        parameters.Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MinValue, int.MaxValue));
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);

        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
            await Console.Error.WriteLineAsync("        either reduce n_parallel or increase n_ctx\n");
            return;
        }

        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);

        // evaluate the initial prompt (all prompt tokens belong to sequence 0)
        var prompt_sequence = new List<LLamaSeqId> { (LLamaSeqId)0 };
        for (var i = 0; i < prompt_tokens.Length; i++)
        {
            llama_batch_add(batch, prompt_tokens[i], i, prompt_sequence, false);
        }
        Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);

        // llama_decode will output logits only for the last token of the prompt
        unsafe
        {
            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
        }

        if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }

        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from (-1 marks a finished stream)
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
        {
            i_batch.Add(batch.NativeBatch.n_tokens - 1);
        }

        int n_cur = batch.NativeBatch.n_tokens;
        int n_decode = 0;

        // Generated tokens per stream, and the (constant) sequence-id list each stream's tokens are added with.
        var streams = new List<int>[n_parallel];
        var stream_sequences = new List<LLamaSeqId>[n_parallel];
        for (var i = 0; i < n_parallel; i++)
        {
            streams[i] = new();
            stream_sequences[i] = new List<LLamaSeqId> { (LLamaSeqId)i };
        }

        var eos = model.EndOfSentenceToken;
        var nl = model.NewlineToken;
        var n_vocab = model.VocabCount;

        var timer = Stopwatch.StartNew();

        // Each generated token is placed at position n_cur. Stopping before n_len keeps every sequence at most
        // n_len tokens long, which is exactly the budget verified through n_kv_req above.
        while (n_cur < n_len)
        {
            llama_batch_clear(batch);

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                {
                    continue;
                }

                LLamaTokenDataArray candidates;
                unsafe
                {
                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
                }

                candidates.TopK(context.NativeHandle, top_k);
                candidates.TopP(context.NativeHandle, top_p);
                candidates.Temperature(context.NativeHandle, temp);
                var new_token_id = candidates.SampleToken(context.NativeHandle);

                if (new_token_id == eos || new_token_id == nl)
                {
                    i_batch[i] = -1;
                    Console.WriteLine($"Completed Stream {i} early");
                    continue;
                }

                streams[i].Add(new_token_id);

                i_batch[i] = batch.NativeBatch.n_tokens;

                // push this new token for next evaluation
                llama_batch_add(batch, new_token_id, n_cur, stream_sequences[i], true);

                n_decode++;
            }

            // all streams are finished
            if (batch.NativeBatch.n_tokens == 0)
            {
                break;
            }

            n_cur++;

            // The tokens just sampled are the last ones kept, so evaluating them would only produce unused logits.
            if (n_cur >= n_len)
            {
                break;
            }

            // evaluate the current batch with the transformer model
            if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }

        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

        var index = 0;
        foreach (var stream in streams)
        {
            var text = context.DeTokenize(stream);

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }

        // Restore the console colours so that any later output is not printed in red.
        Console.ResetColor();
    }

    /// <summary>
    /// Append one token to the end of a batch (C# port of llama.cpp's common helper of the same name).
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
    /// </summary>
    /// <param name="batchHandle">Batch to append to. NOTE(review): there is no capacity check here; the caller must
    /// not add more tokens (or more sequence ids per token) than the batch was created for.</param>
    /// <param name="token">Token to append</param>
    /// <param name="pos">Position of the token within its sequence(s)</param>
    /// <param name="sequences">Ids of the sequences this token belongs to</param>
    /// <param name="logits">Whether llama_decode should output logits for this token</param>
    private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
    {
        unsafe
        {
            ref var batch = ref batchHandle.NativeBatch;

            batch.token[batch.n_tokens] = token;
            batch.pos[batch.n_tokens] = pos;
            batch.n_seq_id[batch.n_tokens] = sequences.Count;

            for (var i = 0; i < sequences.Count; i++)
            {
                batch.seq_id[batch.n_tokens][i] = sequences[i];
            }

            batch.logits[batch.n_tokens] = Convert.ToByte(logits);

            batch.n_tokens++;
        }
    }

    /// <summary>
    /// Empty a batch by resetting its token count (the underlying buffers are reused, not cleared).
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
    /// </summary>
    /// <param name="batchHandle">Batch to clear</param>
    private static void llama_batch_clear(LLamaBatchSafeHandle batchHandle)
    {
        batchHandle.NativeBatch.n_tokens = 0;
    }
}