
BatchedDecoding.cs

using System.Diagnostics;
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
    private const int n_parallel = 8;
    private const int n_len = 32;

    private const int top_k = 80;
    private const float top_p = 0.8f;
    private const float temp = 0.75f;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);

        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
            await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
            return;
        }

        var batch = new LLamaBatch();

        // evaluate the initial prompt
        for (var i = 0; i < prompt_tokens.Length; i++)
            batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

        if (await context.DecodeAsync(batch) != 0)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }

        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.TokenCount - 1);

        var n_cur = batch.TokenCount;
        var n_decode = 0;

        var streams = new StreamingTokenDecoder[n_parallel];
        for (var i = 0; i < n_parallel; i++)
            streams[i] = new StreamingTokenDecoder(context);

        var eos = model.EndOfSentenceToken;
        var nl = model.NewlineToken;

        var timer = new Stopwatch();
        timer.Start();
        while (n_cur <= n_len)
        {
            batch.Clear();

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                    continue;

                var n_vocab = model.VocabCount;
                LLamaTokenDataArray candidates;
                unsafe
                {
                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
                }

                candidates.TopK(context.NativeHandle, top_k);
                candidates.TopP(context.NativeHandle, top_p);
                candidates.Temperature(context.NativeHandle, temp);
                var new_token_id = candidates.SampleToken(context.NativeHandle);

                if (new_token_id == eos || new_token_id == nl)
                {
                    i_batch[i] = -1;
                    Console.WriteLine($"Completed Stream {i} early");
                    continue;
                }

                streams[i].Add(new_token_id);

                i_batch[i] = batch.TokenCount;

                // push this new token for next evaluation
                batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);

                n_decode++;
            }

            // all streams are finished
            if (batch.TokenCount == 0)
            {
                break;
            }

            n_cur++;

            // evaluate the current batch with the transformer model
            if (await context.DecodeAsync(batch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }

        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

        var index = 0;
        foreach (var stream in streams)
        {
            var text = stream.Read();

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }
}
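
For context, a minimal sketch of how this example might be invoked from a console application. The `Program` class and `Main` method below are illustrative assumptions and not part of the LLama.Examples project itself; the only thing taken from the file above is the `BatchedDecoding.Run()` entry point.

// Hypothetical entry point for running the example above.
// Assumes this file lives in the same project as BatchedDecoding.
using LLama.Examples.Examples;

internal static class Program
{
    private static async Task Main()
    {
        // Prompts for a model path and an optional prompt, then generates
        // n_parallel completions that share the prompt's KV cache.
        await BatchedDecoding.Run();
    }
}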