You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BatchedExecutorRewind.md 3.8 kB

# Batched executor - rewinding to an earlier state

```cs
  3. using LLama.Batched;
  4. using LLama.Common;
  5. using LLama.Native;
  6. using LLama.Sampling;
  7. using Spectre.Console;
  8. namespace LLama.Examples.Examples;
  9. /// <summary>
  10. /// This demonstrates generating tokens and then rewinding to an earlier state
  11. /// </summary>
  12. public class BatchedExecutorRewind
  13. {
  14. private const int n_generate = 24;
  15. private const int n_rewind = 12;
  16. private const int n_repeats = 6;
  17. public static async Task Run()
  18. {
  19. string modelPath = UserSettings.GetModelPath();
  20. var parameters = new ModelParams(modelPath);
  21. using var model = LLamaWeights.LoadFromFile(parameters);
  22. var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
  23. // Create an executor that can evaluate a batch of conversations together
  24. using var executor = new BatchedExecutor(model, parameters);
  25. // Print some info
  26. var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
  27. Console.WriteLine($"Created executor with model: {name}");
  28. // Evaluate the initial prompt to create one conversation
  29. using var conversation = executor.Create();
  30. conversation.Prompt(prompt);
  31. // Create the start node wrapping the conversation
  32. var node = new Node(executor.Context);
  33. // Print the prompt
  34. Console.ForegroundColor = ConsoleColor.Green;
  35. Console.WriteLine(prompt);
  36. for (var i = 0; i < n_repeats; i++)
  37. {
  38. for (var j = 0; j < n_generate; j++)
  39. {
  40. // Run inference
  41. await executor.Infer();
  42. // Sample a token
  43. var token = node.Sample(conversation);
  44. // Continue conversation with this token
  45. if (j != n_generate - 1)
  46. conversation.Prompt(token);
  47. }
  48. // Write out what we generated
  49. node.Write(n_rewind, i + 1);
  50. // Rewind back a few tokens
  51. conversation.Rewind(n_rewind + 1);
  52. // Prompt with a token
  53. conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));
  54. // Create a new node around the rewound conversation
  55. node = new Node(executor.Context);
  56. }
  57. Console.WriteLine("Press any key to exit demo");
  58. Console.ReadKey(true);
  59. }
  60. private class Node
  61. {
  62. private readonly LLamaContext _context;
  63. private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
  64. private readonly DefaultSamplingPipeline Sampler;
  65. public Node(LLamaContext context)
  66. {
  67. _context = context;
  68. Sampler = new DefaultSamplingPipeline();
  69. }
  70. public LLamaToken Sample(Conversation conversation)
  71. {
  72. var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
  73. _tokens.Add(token);
  74. return token;
  75. }
  76. public void Write(int n_rewind, int depth)
  77. {
  78. var decoder = new StreamingTokenDecoder(_context);
  79. for (var i = 0; i < _tokens.Count - n_rewind; i++)
  80. decoder.Add(_tokens[i]);
  81. AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");
  82. for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
  83. decoder.Add(_tokens[i]);
  84. AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");
  85. }
  86. public LLamaToken GetToken(int index)
  87. {
  88. return _tokens[index];
  89. }
  90. }
  91. }
```