You can not select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BatchedExecutorRewind.cs 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. using LLama.Batched;
  2. using LLama.Common;
  3. using LLama.Native;
  4. using LLama.Sampling;
  5. namespace LLama.Examples.Examples;
  6. /// <summary>
  7. /// This demonstrates generating tokens and then rewinding to an earlier state
  8. /// </summary>
  9. public class BatchedExecutorRewind
  10. {
  11. private const int n_generate = 24;
  12. private const int n_rewind = 12;
  13. private const int n_repeats = 6;
  14. public static async Task Run()
  15. {
  16. string modelPath = UserSettings.GetModelPath();
  17. var parameters = new ModelParams(modelPath);
  18. using var model = LLamaWeights.LoadFromFile(parameters);
  19. Console.WriteLine("Prompt (leave blank to select automatically):");
  20. var prompt = Console.ReadLine();
  21. if (string.IsNullOrWhiteSpace(prompt))
  22. prompt = "Not many people know that";
  23. // Create an executor that can evaluate a batch of conversations together
  24. var executor = new BatchedExecutor(model, parameters);
  25. // Print some info
  26. var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
  27. Console.WriteLine($"Created executor with model: {name}");
  28. // Evaluate the initial prompt to create one conversation
  29. var conversation = executor.Prompt(prompt);
  30. // Create the start node wrapping the conversation
  31. var node = new Node(executor.Context);
  32. // Print the prompt
  33. Console.ForegroundColor = ConsoleColor.Green;
  34. Console.WriteLine(prompt);
  35. for (var i = 0; i < n_repeats; i++)
  36. {
  37. for (var j = 0; j < n_generate; j++)
  38. {
  39. // Run inference
  40. await executor.Infer();
  41. // Sample a token
  42. var token = node.Sample(conversation);
  43. // Continue conversation with this token
  44. if (j != n_generate - 1)
  45. conversation.Prompt(token);
  46. }
  47. // Write out what we generated
  48. node.Write(n_rewind, i + 1);
  49. // Rewind back a few tokens
  50. conversation.Rewind(n_rewind + 1);
  51. // Prompt with a token
  52. conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));
  53. // Create a new node around the rewound conversation
  54. node = new Node(executor.Context);
  55. }
  56. Console.WriteLine("Press any key to exit demo");
  57. Console.ReadKey(true);
  58. }
  59. private class Node
  60. {
  61. private readonly LLamaContext _context;
  62. private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
  63. private readonly DefaultSamplingPipeline Sampler;
  64. public Node(LLamaContext context)
  65. {
  66. _context = context;
  67. Sampler = new DefaultSamplingPipeline();
  68. }
  69. public LLamaToken Sample(Conversation conversation)
  70. {
  71. var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
  72. _tokens.Add(token);
  73. return token;
  74. }
  75. public void Write(int n_rewind, int depth)
  76. {
  77. var decoder = new StreamingTokenDecoder(_context);
  78. for (var i = 0; i < _tokens.Count - n_rewind; i++)
  79. decoder.Add(_tokens[i]);
  80. Console.ForegroundColor = ConsoleColor.Green;
  81. Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" "));
  82. for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
  83. decoder.Add(_tokens[i]);
  84. Console.ForegroundColor = ConsoleColor.DarkRed;
  85. Console.WriteLine(decoder.Read().ReplaceLineEndings(" "));
  86. }
  87. public LLamaToken GetToken(int index)
  88. {
  89. return _tokens[index];
  90. }
  91. }
  92. }