using System.Diagnostics;
using System.Text;
using LLama.Common;
using LLama.Native;
namespace LLama.Examples.Examples;
/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>
/// Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!
/// </remarks>
public class BatchedDecoding
{
    // Number of parallel sequences to generate from the single shared prompt
    private const int n_parallel = 8;

    // Total length of each sequence, including the prompt tokens
    private const int n_len = 32;

    // Sampling parameters
    private const int top_k = 80;
    private const float top_p = 0.8f;
    private const float temp = 0.75f;
    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
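        // Each parallel sequence shares the prompt tokens once in the KV cache, then needs
        // its own cells for the (n_len - prompt) tokens it will go on to generate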
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
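        // The batch must hold the whole prompt on the first decode and one token per sequence
        // on every step after that, hence max(n_len, n_parallel) (the prompt is assumed to fit within n_len)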
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);

        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
            await Console.Error.WriteLineAsync("        either reduce n_parallel or increase n_ctx\n");
            return;
        }
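        // A single batch is reused: first it carries the whole prompt, then one token per live sequence each step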
        var batch = new LLamaBatch();

        // evaluate the initial prompt
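        // all prompt tokens go into sequence 0; logits are only requested for the final token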
        for (var i = 0; i < prompt_tokens.Length; i++)
            batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

        if (await context.DecodeAsync(batch) != 0)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }

        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.TokenCount - 1);

        var n_cur = batch.TokenCount;
        var n_decode = 0;
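        // one detokenizer per sequence, so partially decoded multi-byte characters are buffered per stream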
        var streams = new StreamingTokenDecoder[n_parallel];
        for (var i = 0; i < n_parallel; i++)
            streams[i] = new StreamingTokenDecoder(context);
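        // a sequence is considered finished as soon as it samples an end-of-sequence or newline token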
        var eos = model.EndOfSentenceToken;
        var nl = model.NewlineToken;

        var timer = new Stopwatch();
        timer.Start();
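        // main loop: sample one new token for every unfinished sequence, then decode them all in a single call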
        while (n_cur <= n_len)
        {
            batch.Clear();

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                    continue;
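                // grab the logits computed for this sequence's last token and sample with top-k/top-p/temperature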
                var n_vocab = model.VocabCount;
                LLamaTokenDataArray candidates;
                unsafe
                {
                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
                }

                candidates.TopK(context.NativeHandle, top_k);
                candidates.TopP(context.NativeHandle, top_p);
                candidates.Temperature(context.NativeHandle, temp);
                var new_token_id = candidates.SampleToken(context.NativeHandle);

                if (new_token_id == eos || new_token_id == nl)
                {
                    i_batch[i] = -1;
                    Console.WriteLine($"Completed Stream {i} early");
                    continue;
                }

                streams[i].Add(new_token_id);
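                // remember where this token will sit in the next batch so its logits can be found after the next decode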
                i_batch[i] = batch.TokenCount;

                // push this new token for next evaluation
                batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);

                n_decode++;
            }

            // all streams are finished
            if (batch.TokenCount == 0)
            {
                break;
            }

            n_cur++;

            // evaluate the current batch with the transformer model
            if (await context.DecodeAsync(batch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }

        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;

        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
        var index = 0;
        foreach (var stream in streams)
        {
            var text = stream.Read();

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }
}