
BatchedDecoding.cs

using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.NewVersion;

public class BatchedDecoding
{
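    // n_parallel: how many sequences to decode in parallel from the same prompt.
    // n_len: total length of each sequence, including the prompt tokens.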
    private const int n_parallel = 8;
    private const int n_len = 32;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "I would like to tell you about";
        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
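        // The prompt is stored once in the KV cache and shared by all sequences;
        // each sequence then needs room for its own (n_len - prompt length) generated tokens.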
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
        parameters.Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MinValue, int.MaxValue));
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);
        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough");
            await Console.Error.WriteLineAsync("       either reduce n_parallel or increase n_ctx");
            return;
        }
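        // The batch must be large enough for the whole prompt on the first decode,
        // and for one new token per sequence on every later decode.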
        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);

        // evaluate the initial prompt
        for (var i = 0; i < prompt_tokens.Length; i++)
            llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false);
        Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);

        // llama_decode will output logits only for the last token of the prompt
        unsafe
        {
            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
        }
        if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }
        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
        for (var i = 1; i < n_parallel; ++i)
        {
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }
        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.NativeBatch.n_tokens - 1);

        int n_cur = batch.NativeBatch.n_tokens;
        int n_decode = 0;

        var streams = new List<int>[n_parallel];
        for (var i = 0; i < n_parallel; i++)
            streams[i] = new();

        var eos = NativeApi.llama_token_eos(model.NativeHandle);
        var nl = NativeApi.llama_token_nl(model.NativeHandle);
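        // a stream finishes early when it samples either the end-of-sequence or the newline token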
        var timer = new Stopwatch();
        timer.Start();

        while (n_cur <= n_len)
        {
            llama_batch_clear(batch);

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                    continue;
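                // sample the next token for stream i from the logits produced at its last batch position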
                unsafe
                {
                    var n_vocab = model.VocabCount;
                    var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]);

                    var candidates = new LLamaTokenData[n_vocab];
                    for (var token_id = 0; token_id < n_vocab; token_id++)
                    {
                        candidates[token_id] = new LLamaTokenData
                        {
                            id = token_id,
                            logit = logits[token_id]
                        };
                    }

                    var candidates_p = new LLamaTokenDataArray(candidates);
                    using var pin = LLamaTokenDataArrayNative.Create(candidates_p, out var candidates_native);
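                    // sampling chain: keep the top_k most likely tokens, apply nucleus (top_p) filtering,
                    // scale by temperature, then sample a token from what remains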
                    const int top_k = 40;
                    const float top_p = 0.9f;
                    const float temp = 0.4f;

                    NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
                    NativeApi.llama_sample_top_p(context.NativeHandle, ref candidates_native, top_p, 1);
                    NativeApi.llama_sample_temperature(context.NativeHandle, ref candidates_native, temp);

                    var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);

                    if (new_token_id == eos || new_token_id == nl)
                    {
                        i_batch[i] = -1;
                        Console.WriteLine($"Completed Stream {i} early");
                        continue;
                    }
                    streams[i].Add(new_token_id);

                    i_batch[i] = batch.NativeBatch.n_tokens;

                    // push this new token for next evaluation
                    llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);

                    n_decode++;
                }
            }

            // all streams are finished
            if (batch.NativeBatch.n_tokens == 0)
            {
                break;
            }
            n_cur++;

            // evaluate the current batch with the transformer model
            if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }
        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

        var index = 0;
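        // print each stream: the shared prompt in green, followed by that stream's generated text in red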
        foreach (var stream in streams)
        {
            var text = context.DeTokenize(stream);

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }
    }
    /// <summary>
    /// Append a token to the batch. Port of llama_batch_add from
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
    /// </summary>
    private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
    {
        unsafe
        {
            ref var batch = ref batchHandle.NativeBatch;

            batch.token[batch.n_tokens] = token;
            batch.pos[batch.n_tokens] = pos;
            batch.n_seq_id[batch.n_tokens] = sequences.Count;
            for (var i = 0; i < sequences.Count; i++)
                batch.seq_id[batch.n_tokens][i] = sequences[i];
            batch.logits[batch.n_tokens] = Convert.ToByte(logits);

            batch.n_tokens++;
        }
    }
    /// <summary>
    /// Reset the batch so it can be reused for the next decode. Port of llama_batch_clear from
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
    /// </summary>
    /// <param name="batchHandle">Handle to the batch to clear</param>
    private static void llama_batch_clear(LLamaBatchSafeHandle batchHandle)
    {
        batchHandle.NativeBatch.n_tokens = 0;
    }
}