You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

BatchedExecutorRewind.cs 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using LLama.Batched;
  2. using LLama.Common;
  3. using LLama.Native;
  4. using LLama.Sampling;
  5. using Spectre.Console;
  6. namespace LLama.Examples.Examples;
  7. /// <summary>
  8. /// This demonstrates generating tokens and then rewinding to an earlier state
  9. /// </summary>
  10. public class BatchedExecutorRewind
  11. {
  12. private const int n_generate = 24;
  13. private const int n_rewind = 12;
  14. private const int n_repeats = 6;
  15. public static async Task Run()
  16. {
  17. string modelPath = UserSettings.GetModelPath();
  18. var parameters = new ModelParams(modelPath);
  19. using var model = LLamaWeights.LoadFromFile(parameters);
  20. var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
  21. // Create an executor that can evaluate a batch of conversations together
  22. using var executor = new BatchedExecutor(model, parameters);
  23. // Print some info
  24. var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
  25. Console.WriteLine($"Created executor with model: {name}");
  26. // Evaluate the initial prompt to create one conversation
  27. using var conversation = executor.Prompt(prompt);
  28. // Create the start node wrapping the conversation
  29. var node = new Node(executor.Context);
  30. // Print the prompt
  31. Console.ForegroundColor = ConsoleColor.Green;
  32. Console.WriteLine(prompt);
  33. for (var i = 0; i < n_repeats; i++)
  34. {
  35. for (var j = 0; j < n_generate; j++)
  36. {
  37. // Run inference
  38. await executor.Infer();
  39. // Sample a token
  40. var token = node.Sample(conversation);
  41. // Continue conversation with this token
  42. if (j != n_generate - 1)
  43. conversation.Prompt(token);
  44. }
  45. // Write out what we generated
  46. node.Write(n_rewind, i + 1);
  47. // Rewind back a few tokens
  48. conversation.Rewind(n_rewind + 1);
  49. // Prompt with a token
  50. conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));
  51. // Create a new node around the rewound conversation
  52. node = new Node(executor.Context);
  53. }
  54. Console.WriteLine("Press any key to exit demo");
  55. Console.ReadKey(true);
  56. }
  57. private class Node
  58. {
  59. private readonly LLamaContext _context;
  60. private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
  61. private readonly DefaultSamplingPipeline Sampler;
  62. public Node(LLamaContext context)
  63. {
  64. _context = context;
  65. Sampler = new DefaultSamplingPipeline();
  66. }
  67. public LLamaToken Sample(Conversation conversation)
  68. {
  69. var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
  70. _tokens.Add(token);
  71. return token;
  72. }
  73. public void Write(int n_rewind, int depth)
  74. {
  75. var decoder = new StreamingTokenDecoder(_context);
  76. for (var i = 0; i < _tokens.Count - n_rewind; i++)
  77. decoder.Add(_tokens[i]);
  78. AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");
  79. for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
  80. decoder.Add(_tokens[i]);
  81. AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");
  82. }
  83. public LLamaToken GetToken(int index)
  84. {
  85. return _tokens[index];
  86. }
  87. }
  88. }