From 0bbbf171ede7fed8121d47bf817b06817ffe4bdc Mon Sep 17 00:00:00 2001
From: ksanchez
Date: Thu, 2 May 2024 23:30:16 -0600
Subject: [PATCH] Refactor executors

---
 LLama/LLamaExecutorBase.cs      |  8 ++++----
 LLama/LLamaInstructExecutor.cs  |  7 +++++--
 LLama/LLamaInteractExecutor.cs  | 16 ++++++++++++++--
 LLama/LLamaStatelessExecutor.cs | 22 ++++++++++++++++++----
 4 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index e8b2ead5..9700eb0e 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
@@ -195,11 +195,11 @@ namespace LLama
            // if we run out of context:
            // - take the tokensToKeep first tokens from the original prompt (via n_past)
            // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-            var n_left = _pastTokensCount - tokensToKeep - 1;
+            var n_left = _pastTokensCount - tokensToKeep;
             var n_discard = n_left / 2;
 
-            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1, tokensToKeep + 1 + n_discard);
-            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1 + n_discard, _pastTokensCount, -n_discard);
+            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep, tokensToKeep + n_discard);
+            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
 
             _pastTokensCount -= n_discard;
 
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 65d2d6c7..5b253096 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;
 using System;
@@ -186,7 +186,10 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    // Instruct always keeps the full input token count.
+                    var tokensToKeep = _embed_inps.Count;
+                    HandleRunOutOfContext(tokensToKeep);
                 }
 
                 TryReuseMatchingPrefix();
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index fec4f9c4..226b18ef 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Common;
 using LLama.Native;
 using LLama.Abstractions;
 using System;
@@ -231,7 +231,19 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    var tokensToKeep = inferenceParams.TokensKeep;
+                    if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count)
+                    {
+                        tokensToKeep = _embed_inps.Count;
+                    }
+                    else
+                    {
+                        tokensToKeep += Convert.ToInt32(Context.ShouldAddBosToken()); // always keep the BOS token
+                    }
+
+                    HandleRunOutOfContext(tokensToKeep);
                 }
 
                 TryReuseMatchingPrefix();
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 39d74f90..ab5f4146 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using System;
 using System.Collections.Generic;
@@ -144,11 +144,25 @@ namespace LLama
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                 if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var canAddBos = Context.ShouldAddBosToken();
+                    var tokensKeep = inferenceParams.TokensKeep;
+
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
+                    {
+                        tokensKeep = tokens.Count;
+                    }
+                    else
+                    {
+                        tokensKeep += Convert.ToInt32(canAddBos);
+                    }
+
+                    var n_left = n_past - tokensKeep;
                     var n_discard = n_left / 2;
 
-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensKeep, tokensKeep + n_discard);
+                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensKeep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
                 }
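
Note: the context-shift arithmetic this patch touches is easiest to check in isolation. The standalone sketch below is not part of the patch; ContextShiftSketch, NormalizeTokensKeep, and ComputeContextShift are hypothetical names. It mirrors the updated logic: clamp TokensKeep the way the interactive and stateless executors now do, keep the first tokensToKeep cells, discard half of what remains, and shift the survivors back, i.e. the exact ranges handed to llama_kv_cache_seq_rm/llama_kv_cache_seq_add.

using System;

static class ContextShiftSketch
{
    // Mirrors the patch's clamping of InferenceParams.TokensKeep: out-of-range
    // values fall back to the full input size; otherwise the BOS token (when
    // the model adds one) is kept as well.
    public static int NormalizeTokensKeep(int tokensKeep, int inputTokenCount, bool addsBos)
    {
        if (tokensKeep < 0 || tokensKeep > inputTokenCount)
            return inputTokenCount;
        return tokensKeep + (addsBos ? 1 : 0);
    }

    // Computes the half-open range of KV cells to remove and the (negative)
    // shift applied to the cells that survive after it, matching the patched
    // HandleRunOutOfContext.
    public static (int RemoveStart, int RemoveEnd, int Shift, int NewPastCount)
        ComputeContextShift(int pastTokensCount, int tokensToKeep)
    {
        var n_left = pastTokensCount - tokensToKeep; // cells past the keep region
        var n_discard = n_left / 2;                  // drop half of them

        // llama_kv_cache_seq_rm removes [tokensToKeep, tokensToKeep + n_discard);
        // llama_kv_cache_seq_add moves [tokensToKeep + n_discard, pastTokensCount)
        // back by n_discard.
        return (tokensToKeep, tokensToKeep + n_discard, -n_discard,
                pastTokensCount - n_discard);
    }

    public static void Main()
    {
        // Example: 4096 cells in use, a 512-token prompt, TokensKeep = 127,
        // and a model that adds a BOS token, so 128 cells are kept.
        var keep = NormalizeTokensKeep(tokensKeep: 127, inputTokenCount: 512, addsBos: true);
        var (rmStart, rmEnd, shift, newPast) = ComputeContextShift(4096, keep);
        Console.WriteLine($"remove [{rmStart}, {rmEnd}), shift {shift}, n_past -> {newPast}");
        // Prints: remove [128, 2112), shift -1984, n_past -> 2112
    }
}

Dropping the old +1 offsets makes the keep region and the removal range tile the cache exactly, matching the upstream main.cpp logic cited in the patch comments.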