You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BatchedExecutorFork.md 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Batched executor - multi-output to one input
  2. ```cs
  3. using LLama.Batched;
  4. using LLama.Common;
  5. using LLama.Native;
  6. using LLama.Sampling;
  7. using Spectre.Console;
  8. namespace LLama.Examples.Examples;
  9. /// <summary>
  10. /// This demonstrates generating multiple replies to the same prompt, with a shared cache
  11. /// </summary>
  12. public class BatchedExecutorFork
  13. {
  14. private const int n_split = 16;
  15. private const int n_len = 72;
  16. public static async Task Run()
  17. {
  18. string modelPath = UserSettings.GetModelPath();
  19. var parameters = new ModelParams(modelPath);
  20. using var model = LLamaWeights.LoadFromFile(parameters);
  21. var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
  22. // Create an executor that can evaluate a batch of conversations together
  23. using var executor = new BatchedExecutor(model, parameters);
  24. // Print some info
  25. var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
  26. Console.WriteLine($"Created executor with model: {name}");
  27. // Evaluate the initial prompt to create one conversation
  28. using var start = executor.Create();
  29. start.Prompt(prompt);
  30. await executor.Infer();
  31. // Create the root node of the tree
  32. var root = new Node(start);
  33. await AnsiConsole
  34. .Progress()
  35. .StartAsync(async progress =>
  36. {
  37. var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);
  38. // Run inference loop
  39. for (var i = 0; i < n_len; i++)
  40. {
  41. if (i != 0)
  42. await executor.Infer();
  43. // Occasionally fork all the active conversations
  44. if (i != 0 && i % n_split == 0)
  45. root.Split();
  46. // Sample all active conversations
  47. root.Sample();
  48. // Update progress bar
  49. reporter.Increment(1);
  50. reporter.Description($"Running Inference ({root.ActiveConversationCount})");
  51. }
  52. // Display results
  53. var display = new Tree(prompt);
  54. root.Display(display);
  55. AnsiConsole.Write(display);
  56. });
  57. }
  58. private class Node
  59. {
  60. private readonly StreamingTokenDecoder _decoder;
  61. private readonly DefaultSamplingPipeline _sampler;
  62. private Conversation? _conversation;
  63. private Node? _left;
  64. private Node? _right;
  65. public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;
  66. public Node(Conversation conversation)
  67. {
  68. _sampler = new DefaultSamplingPipeline();
  69. _conversation = conversation;
  70. _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
  71. }
  72. public void Sample()
  73. {
  74. if (_conversation == null)
  75. {
  76. _left?.Sample();
  77. _right?.Sample();
  78. return;
  79. }
  80. if (_conversation.RequiresInference)
  81. return;
  82. // Sample one token
  83. var ctx = _conversation.Executor.Context.NativeHandle;
  84. var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
  85. _sampler.Accept(ctx, token);
  86. _decoder.Add(token);
  87. // Prompt the conversation with this token, to continue generating from there
  88. _conversation.Prompt(token);
  89. }
  90. public void Split()
  91. {
  92. if (_conversation != null)
  93. {
  94. _left = new Node(_conversation.Fork());
  95. _right = new Node(_conversation.Fork());
  96. _conversation.Dispose();
  97. _conversation = null;
  98. }
  99. else
  100. {
  101. _left?.Split();
  102. _right?.Split();
  103. }
  104. }
  105. public void Display<T>(T tree, int depth = 0)
  106. where T : IHasTreeNodes
  107. {
  108. var colors = new[] { "red", "green", "blue", "yellow", "white" };
  109. var color = colors[depth % colors.Length];
  110. var message = Markup.Escape(_decoder.Read().ReplaceLineEndings(""));
  111. var n = tree.AddNode($"[{color}]{message}[/]");
  112. _left?.Display(n, depth + 1);
  113. _right?.Display(n, depth + 1);
  114. }
  115. }
  116. }
  117. ```