# Batch decoding

```cs
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.Threading.Tasks;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

public class BatchedDecoding
{
    private const int n_parallel = 8;
    private const int n_len = 32;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
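        // Worked example of the line above (hypothetical numbers): with a 6-token
        // prompt, n_len = 32 and n_parallel = 8, n_kv_req = 6 + (32 - 6) * 8 = 214.
        // The prompt is stored in the KV cache once, and each of the 8 sequences
        // can append up to (n_len - prompt length) further tokens.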
        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);
        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
            await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
            return;
        }

        var batch = new LLamaBatch();

        // evaluate the initial prompt
        batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);
        if (await context.DecodeAsync(batch) != DecodeResult.Ok)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }

        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
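            // Copy the cached prompt (positions 0 up to batch.TokenCount) from
            // sequence 0 into sequence i, so every parallel stream starts from the
            // same evaluated prompt without decoding it again.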
            context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.TokenCount - 1);

        // Create per-stream decoder and sampler
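        // Each stream gets its own sampler; the temperature below rises from 0.1
        // towards ~1.0 across the streams, so the 8 completions diverge rather
        // than all sampling the same tokens.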
        var decoders = new StreamingTokenDecoder[n_parallel];
        var samplers = new ISamplingPipeline[n_parallel];
        for (var i = 0; i < n_parallel; i++)
        {
            decoders[i] = new StreamingTokenDecoder(context);
            samplers[i] = new DefaultSamplingPipeline
            {
                Temperature = 0.1f + (float)i / n_parallel,
                MinP = 0.25f,
            };
        }

        var n_cur = batch.TokenCount;
        var n_decode = 0;

        var timer = new Stopwatch();
        timer.Start();
        while (n_cur <= n_len)
        {
            batch.Clear();

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                    continue;

                // Use the sampling pipeline to select a token
                var new_token_id = samplers[i].Sample(
                    context.NativeHandle,
                    context.NativeHandle.GetLogitsIth(i_batch[i]),
                    Array.Empty<LLamaToken>()
                );

                // Finish this stream early if necessary
                if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
                {
                    i_batch[i] = -1;
                    Console.WriteLine($"Completed Stream {i} early");
                    continue;
                }

                // Add this token to the decoder, so it will be turned into text
                decoders[i].Add(new_token_id);

                i_batch[i] = batch.TokenCount;

                // push this new token for next evaluation
                batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);

                n_decode++;
            }

            // Check if all streams are finished
            if (batch.TokenCount == 0)
            {
                break;
            }

            n_cur++;
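            // The batch now holds one new token per active stream, all at the same
            // position but tagged with different sequence ids, so the single decode
            // call below evaluates every stream in one pass.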
            // evaluate the current batch with the transformer model
            if (await context.DecodeAsync(batch) != DecodeResult.Ok)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }

        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

        var index = 0;
        foreach (var stream in decoders)
        {
            var text = stream.Read();

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }
}
```
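
To try the example, call `Run` from an async entry point. A minimal sketch (assuming the class above is part of a console project that references LLamaSharp):

```cs
// Hypothetical top-level entry point for a console project containing the class above.
await BatchedDecoding.Run();
```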