@@ -6,6 +6,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
 using LLama.Extensions;
+using LLama.Native;
 
 namespace LLama
 {
@@ -115,12 +116,16 @@ namespace LLama
                    break;

                // when run out of context
-                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
-                if (n_past + tokens.Count > Context.ContextSize)
+                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
+                if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep;
+                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var n_discard = n_left / 2;

-                    n_past = Math.Max(1, inferenceParams.TokensKeep);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+
+                    n_past -= n_discard;

                     tokens.Clear();
                     tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
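
For orientation, a minimal sketch of the arithmetic behind the new kv-cache shift, using assumed values (TokensKeep = 4 and n_past = 510 are hypothetical, not taken from the patch); the real call sites pass Context.NativeHandle and sequence 0 as shown in the hunk above:

// Sketch only: worked example of the context-shift arithmetic with assumed values.
int tokensKeep = 4;     // hypothetical inferenceParams.TokensKeep
int n_past = 510;       // hypothetical number of cached positions when the context fills

int n_left = n_past - tokensKeep - 1;    // 505 positions are candidates for shifting
int n_discard = n_left / 2;              // 252 positions get dropped

// llama_kv_cache_seq_rm clears cache positions [tokensKeep + 1, tokensKeep + 1 + n_discard) = [5, 257);
// llama_kv_cache_seq_shift then slides positions [tokensKeep + 1 + n_discard, n_past) = [257, 510)
// back by n_discard, so the kept tail lands at [5, 258) right after the preserved leading tokens [0, 5).
n_past -= n_discard;                     // 258 positions stay valid; decoding continues from there

System.Console.WriteLine($"n_left={n_left}, n_discard={n_discard}, n_past={n_past}");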