using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
public class BatchedExecutorFork
{
    private const int n_split = 16;
    private const int n_len = 72;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        using var start = executor.Create();
        start.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree
        var root = new Node(start);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);

                // Run inference loop
                for (var i = 0; i < n_len; i++)
                {
                    if (i != 0)
                        await executor.Infer();

                    // Occasionally fork all the active conversations
                    if (i != 0 && i % n_split == 0)
                        root.Split();

                    // Sample all active conversations
                    root.Sample();

                    // Update progress bar
                    reporter.Increment(1);
                    reporter.Description($"Running Inference ({root.ActiveConversationCount})");
                }

                // Display results
                var display = new Tree(prompt);
                root.Display(display);
                AnsiConsole.Write(display);
            });
    }

    /// <summary>
    /// A node in the binary tree of conversations. Holds either a single live conversation,
    /// or two child nodes created by forking that conversation.
    /// </summary>
    private class Node
    {
        private readonly StreamingTokenDecoder _decoder;
        private readonly DefaultSamplingPipeline _sampler;

        private Conversation? _conversation;

        private Node? _left;
        private Node? _right;

        public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        public void Sample()
        {
            // Interior nodes have no conversation of their own, so recurse into the children
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }

            // Can't sample from a conversation that is still waiting on inference
            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        public void Split()
        {
            if (_conversation != null)
            {
                // Fork this conversation into two children, then dispose the parent
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());

                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        public void Display<T>(T tree, int depth = 0)
            where T : IHasTreeNodes
        {
            var colors = new[] { "red", "green", "blue", "yellow", "white" };
            var color = colors[depth % colors.Length];

            var message = Markup.Escape(_decoder.Read().ReplaceLineEndings(""));

            var n = tree.AddNode($"[{color}]{message}[/]");

            _left?.Display(n, depth + 1);
            _right?.Display(n, depth + 1);
        }
    }
}