BatchedExecutorFork.cs 4.3 kB

using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
public class BatchedExecutorFork
{
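    // Fork all active conversations once every n_split tokens, generating n_len tokens in total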
    private const int n_split = 16;
    private const int n_len = 64;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Create an executor that can evaluate a batch of conversations together
        var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        var start = executor.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree
        var root = new Node(start);

        // Run inference loop
        for (var i = 0; i < n_len; i++)
        {
            if (i != 0)
                await executor.Infer();

            // Occasionally fork all the active conversations
            if (i != 0 && i % n_split == 0)
                root.Split();

            // Sample all active conversations
            root.Sample();
        }

        Console.WriteLine($"{prompt}...");
        root.Print(1);

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }
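
    /// <summary>
    /// A node in a binary tree of conversations. A leaf node owns a live conversation;
    /// once it has been split, it keeps only the text it has decoded so far, for printing.
    /// </summary>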
    class Node
    {
        private readonly StreamingTokenDecoder _decoder;
        private readonly DefaultSamplingPipeline _sampler;

        private Conversation? _conversation;

        private Node? _left;
        private Node? _right;

        public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        public void Sample()
        {
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }
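
            // Skip this conversation if it still needs inference to be run before it can be sampled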
            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var logitsCopy = _conversation.Sample().ToArray();
            var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        public void Split()
        {
            if (_conversation != null)
            {
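                // Fork this conversation into two children, which share the cache built up so far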
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());
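
                // This node is no longer a leaf, so it no longer needs its own conversation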
                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        public void Print(int indentation)
        {
            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
            Console.ForegroundColor = colors[indentation % colors.Length];

            var message = _decoder.Read().ReplaceLineEndings("");

            var prefix = new string(' ', indentation * 3);
            var suffix = _conversation == null ? "..." : "";
            Console.WriteLine($"{prefix}...{message}{suffix}");
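
            // Recurse into the children (if any), printing them with extra indentation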
            _left?.Print(indentation + 2);
            _right?.Print(indentation + 2);
        }
    }
}