You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

BatchedExecutorRewind.cs 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. using LLama.Batched;
  2. using LLama.Common;
  3. using LLama.Native;
  4. using LLama.Sampling;
  5. using Spectre.Console;
  6. namespace LLama.Examples.Examples;
  7. /// <summary>
  8. /// This demonstrates generating tokens and then rewinding to an earlier state
  9. /// </summary>
  10. public class BatchedExecutorRewind
  11. {
  12. private const int n_generate = 24;
  13. private const int n_rewind = 12;
  14. private const int n_repeats = 6;
  15. public static async Task Run()
  16. {
  17. string modelPath = UserSettings.GetModelPath();
  18. var parameters = new ModelParams(modelPath);
  19. using var model = await LLamaWeights.LoadFromFileAsync(parameters);
  20. var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
  21. // Create an executor that can evaluate a batch of conversations together
  22. using var executor = new BatchedExecutor(model, parameters);
  23. // Print some info
  24. var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
  25. Console.WriteLine($"Created executor with model: {name}");
  26. // Evaluate the initial prompt to create one conversation
  27. using var conversation = executor.Create();
  28. conversation.Prompt(executor.Context.Tokenize(prompt));
  29. // Create the start node wrapping the conversation
  30. var node = new Node(executor.Context);
  31. // Print the prompt
  32. Console.ForegroundColor = ConsoleColor.Green;
  33. Console.WriteLine(prompt);
  34. for (var i = 0; i < n_repeats; i++)
  35. {
  36. for (var j = 0; j < n_generate; j++)
  37. {
  38. // Run inference
  39. await executor.Infer();
  40. // Sample a token
  41. var token = node.Sample(conversation);
  42. // Continue conversation with this token
  43. if (j != n_generate - 1)
  44. conversation.Prompt(token);
  45. }
  46. // Write out what we generated
  47. node.Write(n_rewind, i + 1);
  48. // Rewind back a few tokens
  49. conversation.Rewind(n_rewind + 1);
  50. // Prompt with a token
  51. conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));
  52. // Create a new node around the rewound conversation
  53. node = new Node(executor.Context);
  54. }
  55. Console.WriteLine("Press any key to exit demo");
  56. Console.ReadKey(true);
  57. }
  58. private class Node
  59. {
  60. private readonly LLamaContext _context;
  61. private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
  62. private readonly DefaultSamplingPipeline Sampler;
  63. public Node(LLamaContext context)
  64. {
  65. _context = context;
  66. Sampler = new DefaultSamplingPipeline();
  67. }
  68. public LLamaToken Sample(Conversation conversation)
  69. {
  70. var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
  71. _tokens.Add(token);
  72. return token;
  73. }
  74. public void Write(int n_rewind, int depth)
  75. {
  76. var decoder = new StreamingTokenDecoder(_context);
  77. for (var i = 0; i < _tokens.Count - n_rewind; i++)
  78. decoder.Add(_tokens[i]);
  79. AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");
  80. for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
  81. decoder.Add(_tokens[i]);
  82. AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");
  83. }
  84. public LLamaToken GetToken(int index)
  85. {
  86. return _tokens[index];
  87. }
  88. }
  89. }