@@ -8,6 +8,7 @@
     <Platforms>AnyCPU;x64</Platforms>
     <!-- Set IncludeBuiltInRuntimes to false to include your own runtime libraries and not link the defaults -->
     <IncludeBuiltInRuntimes>true</IncludeBuiltInRuntimes>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>
   <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -1,25 +0,0 @@
-using LLama.Common;
-using LLama.Native;
-namespace LLama.Examples.NewVersion;
-public class BatchedBench
-{
-    public static async Task Run()
-    {
-        Console.Write("Please input your model path: ");
-        //todo:var modelPath = Console.ReadLine();
-        var modelPath = @"C:\Users\Martin\Documents\Python\oobabooga_windows\text-generation-webui\models\llama-2-7b-chat.Q5_K_M.gguf";
-        var parameters = new ModelParams(modelPath);
-        using var model = LLamaWeights.LoadFromFile(parameters);
-        parameters.ContextSize = (uint)model.ContextSize;
-        using var context = model.CreateContext(parameters);
-        var n_kv_max = 1024;
-        using var batch = LLamaBatchSafeHandle.Create(n_kv_max, 0, 1);
-    }
-}
@@ -0,0 +1,219 @@
+using System.Diagnostics;
+using System.Security.Cryptography;
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+namespace LLama.Examples.NewVersion;
+public class BatchedDecoding
+{
+    private const int n_parallel = 8;
+    private const int n_len = 32;
+    public static async Task Run()
+    {
| Console.Write("Please input your model path: "); | |||||
| //todo:var modelPath = Console.ReadLine(); | |||||
| var modelPath = @"C:\Users\Martin\Documents\Python\oobabooga_windows\text-generation-webui\models\llama-2-7b-chat.Q5_K_M.gguf"; | |||||
| Console.WriteLine("Prompt (leave blank to select automatically):"); | |||||
| var prompt = Console.ReadLine(); | |||||
| if (string.IsNullOrWhiteSpace(prompt)) | |||||
| prompt = "I would like to tell you about"; | |||||
| // Load model | |||||
| var parameters = new ModelParams(modelPath); | |||||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||||
| // Tokenize prompt | |||||
| var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8); | |||||
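+        // worst-case KV cache requirement: the shared prompt plus up to (n_len - prompt length) generated tokens for each of the n_parallel sequences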
+        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
+        // Create a context
+        parameters.ContextSize = (uint)model.ContextSize;
+        parameters.Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MinValue, int.MaxValue));
+        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
+        using var context = model.CreateContext(parameters);
+        var n_ctx = context.ContextSize;
+        // make sure the KV cache is big enough to hold all the prompt and generated tokens
+        if (n_kv_req > n_ctx)
+        {
+            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
+            await Console.Error.WriteLineAsync("        either reduce n_parallel or increase n_ctx\n");
+            return;
+        }
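+        // allocate a batch big enough to hold the whole prompt, or one token per parallel sequence, whichever is larger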
+        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
+        // evaluate the initial prompt
+        for (var i = 0; i < prompt_tokens.Length; i++)
+            llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false);
+        Debug.Assert(batch.NativeBatch.n_tokens == (int)prompt_tokens.Length);
+        // llama_decode will output logits only for the last token of the prompt
+        unsafe
+        {
+            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
+        }
+        if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
+        {
+            await Console.Error.WriteLineAsync("llama_decode failed");
+            return;
+        }
+        // assign the system KV cache to all parallel sequences
+        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+        for (var i = 1; i < n_parallel; ++i)
+        {
+            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
+        }
+        if (n_parallel > 1)
+        {
+            Console.WriteLine();
+            Console.WriteLine($"generating {n_parallel} sequences...");
+        }
+        // remember the batch index of the last token for each parallel sequence
+        // we need this to determine which logits to sample from
+        List<int> i_batch = new();
+        for (var i = 0; i < n_parallel; i++)
+            i_batch.Add(batch.NativeBatch.n_tokens - 1);
+        int n_cur = batch.NativeBatch.n_tokens;
+        int n_decode = 0;
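+        // one output token list per parallel sequence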
+        var streams = new List<int>[n_parallel];
+        for (var i = 0; i < n_parallel; i++)
+            streams[i] = new();
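+        // token ids used to detect the end of a stream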
+        var eos = NativeApi.llama_token_eos(model.NativeHandle);
+        var nl = NativeApi.llama_token_nl(model.NativeHandle);
+        var timer = new Stopwatch();
+        timer.Start();
+        while (n_cur <= n_len)
+        {
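+            // build a fresh batch for this step; each live stream contributes at most one new token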
+            llama_batch_clear(batch);
+            for (var i = 0; i < n_parallel; i++)
+            {
+                // Skip completed streams
+                if (i_batch[i] < 0)
+                    continue;
+                unsafe
+                {
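+                    // copy the logits of this stream's last token into a candidate array and sample the next token from it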
+                    var n_vocab = model.VocabCount;
+                    var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]);
+                    var candidates = new LLamaTokenData[n_vocab];
+                    for (var token_id = 0; token_id < n_vocab; token_id++)
+                    {
+                        candidates[token_id] = new LLamaTokenData
+                        {
+                            id = token_id,
+                            logit = logits[token_id]
+                        };
+                    }
+                    var candidates_p = new LLamaTokenDataArray(candidates);
+                    using var pin = LLamaTokenDataArrayNative.Create(candidates_p, out var candidates_native);
+                    const int top_k = 40;
+                    const float top_p = 0.9f;
+                    const float temp = 0.4f;
+                    NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
+                    NativeApi.llama_sample_top_p(context.NativeHandle, ref candidates_native, top_p, 1);
+                    NativeApi.llama_sample_temperature(context.NativeHandle, ref candidates_native, temp);
+                    var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);
+                    if (new_token_id == eos || new_token_id == nl)
+                    {
+                        i_batch[i] = -1;
+                        Console.WriteLine($"Completed Stream {i} early");
+                        continue;
+                    }
+                    streams[i].Add(new_token_id);
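+                    // remember where this token will sit in the batch so its logits can be sampled on the next step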
+                    i_batch[i] = batch.NativeBatch.n_tokens;
+                    // push this new token for next evaluation
+                    llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);
+                    n_decode++;
+                }
+            }
+            // all streams are finished
+            if (batch.NativeBatch.n_tokens == 0)
+            {
+                break;
+            }
+            n_cur++;
+            // evaluate the current batch with the transformer model
+            if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
+            {
+                await Console.Error.WriteLineAsync("failed to eval");
+                return;
+            }
+        }
+        timer.Stop();
+        Console.ForegroundColor = ConsoleColor.Yellow;
+        Console.WriteLine();
+        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
+        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
+        var index = 0;
+        foreach (var stream in streams)
+        {
+            var text = context.DeTokenize(stream);
+            Console.ForegroundColor = ConsoleColor.Green;
+            Console.Write($"{index++}. {prompt}");
+            Console.ForegroundColor = ConsoleColor.Red;
+            Console.WriteLine(text);
+        }
+    }
+    /// <summary>
+    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+    /// </summary>
+    private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
+    {
+        unsafe
+        {
+            ref var batch = ref batchHandle.NativeBatch;
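+            // write the token, its position, its sequence memberships, and the logits flag into the next free slot, then bump the token count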
+            batch.token[batch.n_tokens] = token;
+            batch.pos[batch.n_tokens] = pos;
+            batch.n_seq_id[batch.n_tokens] = sequences.Count;
+            for (var i = 0; i < sequences.Count; i++)
+                batch.seq_id[batch.n_tokens][i] = sequences[i];
+            batch.logits[batch.n_tokens] = Convert.ToByte(logits);
+            batch.n_tokens++;
+        }
+    }
+    /// <summary>
+    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
+    /// </summary>
+    /// <param name="batchHandle"></param>
+    private static void llama_batch_clear(LLamaBatchSafeHandle batchHandle)
+    {
+        batchHandle.NativeBatch.n_tokens = 0;
+    }
+}
@@ -91,7 +91,7 @@
                 }
                 else if (choice == 15)
                 {
-                    await BatchedBench.Run();
+                    await BatchedDecoding.Run();
                 }
                 else
                 {
@@ -305,7 +305,7 @@ namespace LLama
             }
             // Save the newline logit value
-            var nl_token = NativeApi.llama_token_nl(NativeHandle);
+            var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
             var nl_logit = logits[nl_token];
             // Convert logits into token candidates
@@ -163,7 +163,7 @@ namespace LLama
                }
            }
-            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
            {
                args.WaitForInput = true;
            }
@@ -30,7 +30,7 @@ namespace LLama
        public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
            : base(context, logger)
        {
-            _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle);
+            _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
        }
        /// <inheritdoc />
@@ -141,7 +141,7 @@ namespace LLama
                return (true, Array.Empty<string>());
            }
-            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
            {
                return (true, new[] { " [end of text]\n" });
            }
@@ -202,7 +202,7 @@ namespace LLama
                _last_n_tokens.Enqueue(id);
-                if (id == NativeApi.llama_token_eos(Context.NativeHandle))
+                if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
                {
                    id = _llama_token_newline;
                    if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
@@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle
    /// <summary>
    /// Get the native llama_batch struct
    /// </summary>
-    public LLamaNativeBatch NativeBatch { get; private set; }
+    public LLamaNativeBatch NativeBatch;
    /// <summary>
    /// the token ids of the input (used when embd is NULL)
@@ -11,12 +11,12 @@ using llama_token = Int32;
 /// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
 /// </summary>
 [StructLayout(LayoutKind.Sequential)]
-public readonly unsafe struct LLamaNativeBatch
+public unsafe struct LLamaNativeBatch
 {
     /// <summary>
     /// The number of items pointed at by pos, seq_id and logits.
     /// </summary>
-    public readonly int n_tokens;
+    public int n_tokens;
     /// <summary>
     /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
@@ -33,13 +33,25 @@ public readonly unsafe struct LLamaNativeBatch
     /// </summary>
     public readonly LLamaPos* pos;
| /// <summary> | |||||
| /// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ??? | |||||
| /// </summary> | |||||
| public readonly int* n_seq_id; | |||||
| /// <summary> | /// <summary> | ||||
| /// the sequence to which the respective token belongs | /// the sequence to which the respective token belongs | ||||
| /// </summary> | /// </summary> | ||||
| public readonly LLamaSeqId* seq_id; | |||||
| public readonly LLamaSeqId** seq_id; | |||||
| /// <summary> | /// <summary> | ||||
| /// if zero, the logits for the respective token will not be output | /// if zero, the logits for the respective token will not be output | ||||
| /// </summary> | /// </summary> | ||||
| public readonly byte* logits; | public readonly byte* logits; | ||||
| // Note from llama.cpp: | |||||
| // > helpers for smooth API transition - can be deprecated in the future | |||||
| // > for future-proof code, use the above fields instead and ignore everything below | |||||
| private LLamaPos _all_pos_0; | |||||
| private LLamaPos _all_pos_1; | |||||
| private LLamaSeqId _all_seq_id; | |||||
| } | } | ||||
@@ -350,7 +350,7 @@ namespace LLama.Native
        /// <param name="ctx"></param>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);
+        public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
        /// <summary>
        /// Get the embeddings for the input
@@ -366,21 +366,21 @@ namespace LLama.Native
        /// </summary>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
+        public static extern llama_token llama_token_bos(SafeLlamaModelHandle model);
        /// <summary>
        /// Get the "End of sentence" token
        /// </summary>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
+        public static extern llama_token llama_token_eos(SafeLlamaModelHandle model);
        /// <summary>
        /// Get the "new line" token
        /// </summary>
        /// <returns></returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
+        public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
        /// <summary>
        /// Print out timing information for this context