
Refactor executors

Branch: pull/714/head
Author: ksanchez
Commit: 0bbbf171ed
4 changed files with 41 additions and 12 deletions

  1. LLama/LLamaExecutorBase.cs (+4 -4)
  2. LLama/LLamaInstructExecutor.cs (+5 -2)
  3. LLama/LLamaInteractExecutor.cs (+14 -2)
  4. LLama/LLamaStatelessExecutor.cs (+18 -4)

LLama/LLamaExecutorBase.cs (+4 -4)

@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
@@ -195,11 +195,11 @@ namespace LLama
             // if we run out of context:
             // - take the tokensToKeep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-            var n_left = _pastTokensCount - tokensToKeep - 1;
+            var n_left = _pastTokensCount - tokensToKeep;
             var n_discard = n_left / 2;
 
-            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1, tokensToKeep + 1 + n_discard);
-            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + 1 + n_discard, _pastTokensCount, -n_discard);
+            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep, tokensToKeep + n_discard);
+            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
 
             _pastTokensCount -= n_discard;


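Aside: the kv-cache calls take half-open position ranges, which is why dropping the old "+ 1" offsets removes exactly n_discard cells starting right after the kept prefix. Below is a minimal standalone C# sketch of the new window arithmetic, with illustrative numbers only (ContextShiftSketch is not part of the library; a plain list stands in for KV-cache positions):

    using System;
    using System.Collections.Generic;
    using System.Linq;

    static class ContextShiftSketch
    {
        static void Main()
        {
            var pastTokensCount = 8;  // tokens currently in the cache
            var tokensToKeep = 2;     // prefix that must survive the shift

            // Positions 0..7 stand in for cached tokens.
            List<int> cache = Enumerable.Range(0, pastTokensCount).ToList();

            var n_left = pastTokensCount - tokensToKeep;  // 6
            var n_discard = n_left / 2;                   // 3

            // llama_kv_cache_seq_rm removes the half-open range
            // [tokensToKeep, tokensToKeep + n_discard), i.e. positions 2, 3, 4.
            cache.RemoveRange(tokensToKeep, n_discard);

            // llama_kv_cache_seq_add then shifts the surviving tail back by
            // n_discard; RemoveRange already left the tail contiguous here,
            // so only the bookkeeping counter needs updating.
            pastTokensCount -= n_discard;

            Console.WriteLine(string.Join(" ", cache));  // 0 1 5 6 7
            Console.WriteLine(pastTokensCount);          // 5
        }
    }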
LLama/LLamaInstructExecutor.cs (+5 -2)

@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;
 using System;
@@ -186,7 +186,10 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    // Instruct always uses input token size.
+                    var tokensToKeep = _embed_inps.Count;
+                    HandleRunOutOfContext(tokensToKeep);
                 }
 
                 TryReuseMatchingPrefix();


LLama/LLamaInteractExecutor.cs (+14 -2)

@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Common;
 using LLama.Native;
 using LLama.Abstractions;
 using System;
@@ -231,7 +231,19 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    var tokensToKeep = inferenceParams.TokensKeep;
+                    if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count)
+                    {
+                        tokensToKeep = _embed_inps.Count;
+                    }
+                    else
+                    {
+                        tokensToKeep += Convert.ToInt32(Context.ShouldAddBosToken()); // always keep the BOS token
+                    }
+
+                    HandleRunOutOfContext(tokensToKeep);
                 }
 
                 TryReuseMatchingPrefix();


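The clamp above behaves like this small pure function, shown here as a sketch for illustration (ClampTokensToKeep is a hypothetical name, not part of LLamaSharp, and addBos stands in for Context.ShouldAddBosToken()):

    // Sketch of the TokensKeep clamp as a pure function.
    static int ClampTokensToKeep(int requested, int promptTokenCount, bool addBos)
    {
        // Out-of-range requests fall back to keeping the whole prompt.
        if (requested < 0 || requested > promptTokenCount)
            return promptTokenCount;

        // In-range requests reserve one extra slot for the BOS token when
        // the model expects one, so it is never shifted out of the cache.
        return requested + (addBos ? 1 : 0);
    }

    // ClampTokensToKeep(-1, 10, true)  -> 10  (negative: keep the whole prompt)
    // ClampTokensToKeep(42, 10, false) -> 10  (too large: keep the whole prompt)
    // ClampTokensToKeep( 4, 10, true)  ->  5  (4 requested tokens plus BOS)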
LLama/LLamaStatelessExecutor.cs (+18 -4)

@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using System;
 using System.Collections.Generic;
@@ -144,11 +144,25 @@ namespace LLama
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                 if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var canAddBos = Context.ShouldAddBosToken();
+                    var tokensKeep = inferenceParams.TokensKeep;
+
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
+                    {
+                        tokensKeep = tokens.Count;
+                    }
+                    else
+                    {
+                        tokensKeep += Convert.ToInt32(canAddBos);
+                    }
+
+                    var n_left = n_past - tokensKeep;
                     var n_discard = n_left / 2;
 
-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensKeep, tokensKeep + n_discard);
+                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensKeep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
                 }


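Putting the clamp and the shift together, here is a runnable trace of the stateless path using illustrative numbers only (none of the values below come from the commit):

    using System;

    class StatelessShiftTrace
    {
        static void Main()
        {
            var n_past = 2040;        // tokens already in the KV cache
            var tokensCount = 10;     // prompt tokens for this call
            var tokensKeepParam = 4;  // inferenceParams.TokensKeep
            var addBos = true;        // stand-in for Context.ShouldAddBosToken()

            // Clamp, as in the diff above: 4 is in range, so keep 4 + BOS = 5.
            var tokensKeep = (tokensKeepParam < 0 || tokensKeepParam > tokensCount)
                ? tokensCount
                : tokensKeepParam + Convert.ToInt32(addBos);

            var n_left = n_past - tokensKeep;  // 2035
            var n_discard = n_left / 2;        // 1017

            Console.WriteLine($"rm  positions [{tokensKeep}, {tokensKeep + n_discard})");          // [5, 1022)
            Console.WriteLine($"add positions [{tokensKeep + n_discard}, {n_past}) by -{n_discard}");
            Console.WriteLine($"n_past: {n_past} -> {n_past - n_discard}");                        // 2040 -> 1023
        }
    }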