@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
@@ -195,11 +195,11 @@ namespace LLama
             // if we run out of context:
             // - take the tokensToKeep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-            var n_left = _pastTokensCount - tokensToKeep - 1;
+            var n_left = _pastTokensCount - tokensToKeep;
             var n_discard = n_left / 2;

-            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1, tokensToKeep + 1 + n_discard);
-            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1 + n_discard, _pastTokensCount, -n_discard);
+            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep, tokensToKeep + n_discard);
+            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + n_discard, _pastTokensCount, -n_discard);

            _pastTokensCount -= n_discard;
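The arithmetic this first hunk changes can be read in isolation. Below is a minimal standalone sketch (a hypothetical helper, not part of LLamaSharp) of the discard-range computation: with the "- 1" / "+ 1" offsets removed, the BOS token is assumed to already be counted inside tokensToKeep, so the removal window starts exactly at tokensToKeep.

using System;

// Hypothetical standalone illustration of the KV cache shift above; it only
// computes the ranges that would be passed to llama_kv_cache_seq_rm / _seq_add.
static class ContextShiftSketch
{
    public static void Main()
    {
        int pastTokensCount = 512;  // context is full
        int tokensToKeep = 32;      // tokens preserved at the front (BOS included)

        var n_left = pastTokensCount - tokensToKeep;  // tokens eligible for discarding
        var n_discard = n_left / 2;                   // drop half of them

        // Cells [tokensToKeep, tokensToKeep + n_discard) are removed,
        // cells [tokensToKeep + n_discard, pastTokensCount) slide back by n_discard.
        Console.WriteLine($"remove [{tokensToKeep}, {tokensToKeep + n_discard})");                    // remove [32, 272)
        Console.WriteLine($"shift [{tokensToKeep + n_discard}, {pastTokensCount}) by {-n_discard}");  // shift [272, 512) by -240
        Console.WriteLine($"new past token count: {pastTokensCount - n_discard}");                    // 272
    }
}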
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;
 using System;
@@ -186,7 +186,10 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    // Instruct always uses the input token count.
+                    var tokensToKeep = _embed_inps.Count;
+                    HandleRunOutOfContext(tokensToKeep);
                 }

                 TryReuseMatchingPrefix();
@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Common;
 using LLama.Native;
 using LLama.Abstractions;
 using System;
@@ -231,7 +231,19 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    var tokensToKeep = inferenceParams.TokensKeep;
+                    if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count)
+                    {
+                        tokensToKeep = _embed_inps.Count;
+                    }
+                    else
+                    {
+                        tokensToKeep += Convert.ToInt32(Context.ShouldAddBosToken()); // always keep the BOS token
+                    }
+
+                    HandleRunOutOfContext(tokensToKeep);
                 }

                 TryReuseMatchingPrefix();
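The keep-count clamping added in this hunk can also be exercised on its own. The following is a rough sketch under the assumption of a hypothetical Resolve helper (not a LLamaSharp API): a negative or oversized TokensKeep falls back to the full prompt length; otherwise one extra slot is reserved for the BOS token when the model adds one.

using System;

static class TokensKeepSketch
{
    // Hypothetical helper mirroring the clamping logic above.
    public static int Resolve(int requestedTokensKeep, int promptTokenCount, bool addsBosToken)
    {
        if (requestedTokensKeep < 0 || requestedTokensKeep > promptTokenCount)
            return promptTokenCount;              // invalid request: keep the whole prompt

        return requestedTokensKeep
             + Convert.ToInt32(addsBosToken);     // always keep the BOS token
    }

    public static void Main()
    {
        Console.WriteLine(Resolve(-1, 100, true));  // 100: invalid request, keep everything
        Console.WriteLine(Resolve(40, 100, true));  // 41:  requested 40 plus the BOS token
        Console.WriteLine(Resolve(40, 100, false)); // 40:  model adds no BOS token
    }
}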
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using System;
 using System.Collections.Generic;
@@ -144,11 +144,25 @@ namespace LLama
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                 if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var canAddBos = Context.ShouldAddBosToken();
+                    var tokensKeep = inferenceParams.TokensKeep;
+
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
+                    {
+                        tokensKeep = tokens.Count;
+                    }
+                    else
+                    {
+                        tokensKeep += Convert.ToInt32(canAddBos);
+                    }
+
+                    var n_left = n_past - tokensKeep;
                     var n_discard = n_left / 2;

-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensKeep, tokensKeep + n_discard);
+                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensKeep + n_discard, n_past, -n_discard);

                     n_past -= n_discard;
                 }
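For this last hunk, a hypothetical before/after comparison of the boundary values (standalone arithmetic only, not the executor itself) shows what dropping the +1/-1 offsets changes: the discard window now starts at the first non-kept cell instead of one past it, and the BOS adjustment lives in the clamping step above.

using System;

static class OffsetComparisonSketch
{
    public static void Main()
    {
        int n_past = 1024, tokensKeep = 64;

        // Old arithmetic: BOS handled with explicit +1/-1 offsets.
        int oldLeft = n_past - tokensKeep - 1;  // 959
        int oldDiscard = oldLeft / 2;           // 479
        Console.WriteLine($"old: rm [{tokensKeep + 1}, {tokensKeep + oldDiscard + 1})"); // rm [65, 544)

        // New arithmetic: BOS already folded into tokensKeep by the clamping step.
        int newLeft = n_past - tokensKeep;      // 960
        int newDiscard = newLeft / 2;           // 480
        Console.WriteLine($"new: rm [{tokensKeep}, {tokensKeep + newDiscard})");         // rm [64, 544)
    }
}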