@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
@@ -195,11 +195,11 @@ namespace LLama
 // if we run out of context:
 // - take the tokensToKeep first tokens from the original prompt (via n_past)
 // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-var n_left = _pastTokensCount - tokensToKeep - 1;
+var n_left = _pastTokensCount - tokensToKeep;
 var n_discard = n_left / 2;
-NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1, tokensToKeep + 1 + n_discard);
-NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1 + n_discard, _pastTokensCount, -n_discard);
+NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep, tokensToKeep + n_discard);
+NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
 _pastTokensCount -= n_discard;
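
Editor's note (not part of the patch): a minimal sketch of the index arithmetic the updated HandleRunOutOfContext performs, with plain ints in place of the real KV-cache calls; the class and method names are hypothetical.

internal static class ContextShiftSketch
{
    // Mirrors the arithmetic above without touching a real llama_context.
    public static (int removeStart, int removeEnd, int newPastCount) Plan(int pastTokensCount, int tokensToKeep)
    {
        var n_left = pastTokensCount - tokensToKeep; // tokens eligible for eviction
        var n_discard = n_left / 2;                  // evict the older half of them

        // llama_kv_cache_seq_rm drops cells [tokensToKeep, tokensToKeep + n_discard);
        // llama_kv_cache_seq_add then slides [tokensToKeep + n_discard, pastTokensCount) left by n_discard.
        return (tokensToKeep, tokensToKeep + n_discard, pastTokensCount - n_discard);
    }
}

// Example: pastTokensCount = 512, tokensToKeep = 32
//   n_left = 480, n_discard = 240
//   cells [32, 272) are removed, [272, 512) slide down to [32, 272), new past count = 272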
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;
 using System;
@@ -186,7 +186,10 @@ namespace LLama
 _is_prompt_run = false;
 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
 {
-    HandleRunOutOfContext(inferenceParams.TokensKeep);
+    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+    // Instruct always uses input token size.
+    var tokensToKeep = _embed_inps.Count;
+    HandleRunOutOfContext(tokensToKeep);
 }
 TryReuseMatchingPrefix();
@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Common;
 using LLama.Native;
 using LLama.Abstractions;
 using System;
@@ -231,7 +231,19 @@ namespace LLama
 _is_prompt_run = false;
 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
 {
-    HandleRunOutOfContext(inferenceParams.TokensKeep);
+    // number of tokens to keep when resetting context
+    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+    var tokensToKeep = inferenceParams.TokensKeep;
+    if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count)
+    {
+        tokensToKeep = _embed_inps.Count;
+    }
+    else
+    {
+        tokensToKeep += Convert.ToInt32(Context.ShouldAddBosToken()); // always keep the BOS token
+    }
+    HandleRunOutOfContext(tokensToKeep);
 }
 TryReuseMatchingPrefix();
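
Editor's note (not part of the patch): the TokensKeep clamping added above, restated as a standalone sketch; ClampTokensToKeep is a hypothetical helper and addBos stands in for Context.ShouldAddBosToken().

internal static class TokensKeepSketch
{
    public static int ClampTokensToKeep(int requestedKeep, int promptTokenCount, bool addBos)
    {
        // A negative or oversized request falls back to keeping the whole prompt.
        if (requestedKeep < 0 || requestedKeep > promptTokenCount)
            return promptTokenCount;

        // Otherwise reserve one extra slot so the BOS token survives the shift.
        return requestedKeep + (addBos ? 1 : 0);
    }
}

// Example: ClampTokensToKeep(-1, 40, true)  == 40  (keep the entire prompt)
//          ClampTokensToKeep( 8, 40, true)  == 9   (8 requested + BOS)
//          ClampTokensToKeep( 8, 40, false) == 8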
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using System;
 using System.Collections.Generic;
@@ -144,11 +144,25 @@ namespace LLama
 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
 if (n_past + tokens.Count >= Context.ContextSize)
 {
-    var n_left = n_past - inferenceParams.TokensKeep - 1;
+    var canAddBos = Context.ShouldAddBosToken();
+    var tokensKeep = inferenceParams.TokensKeep;
+    // number of tokens to keep when resetting context
+    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+    if (tokensKeep < 0 || tokensKeep > tokens.Count)
+    {
+        tokensKeep = tokens.Count;
+    }
+    else
+    {
+        tokensKeep += Convert.ToInt32(canAddBos);
+    }
+    var n_left = n_past - tokensKeep;
     var n_discard = n_left / 2;
-    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensKeep, tokensKeep + n_discard);
+    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensKeep + n_discard, n_past, -n_discard);
     n_past -= n_discard;
 }
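
Editor's note (not part of the patch): the stateless-executor branch above combines the clamp and the shift inline; a rough end-to-end sketch under the same assumptions, with plain parameters standing in for Context.ContextSize, n_past, tokens.Count, inferenceParams.TokensKeep and Context.ShouldAddBosToken().

internal static class StatelessOverflowSketch
{
    public static int MaybeShiftContext(int contextSize, int nPast, int promptTokenCount, int requestedKeep, bool addBos)
    {
        if (nPast + promptTokenCount < contextSize)
            return nPast; // still fits, no shift needed

        // Clamp TokensKeep exactly as the patched branch does.
        var tokensKeep = (requestedKeep < 0 || requestedKeep > promptTokenCount)
            ? promptTokenCount
            : requestedKeep + (addBos ? 1 : 0);

        var n_left = nPast - tokensKeep;
        var n_discard = n_left / 2;

        // In the real executor llama_kv_cache_seq_rm / llama_kv_cache_seq_add evict
        // [tokensKeep, tokensKeep + n_discard) and slide the remaining cells left by n_discard.
        return nPast - n_discard;
    }
}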