
Merge pull request #714 from ksanman/infinite-context

Implement context shifting in executor base
Martin Evans committed 1 year ago
commit 9906871f84
6 changed files with 59 additions and 16 deletions
  1. LLama/LLamaContext.cs (+12, -1)
  2. LLama/LLamaExecutorBase.cs (+7, -6)
  3. LLama/LLamaInstructExecutor.cs (+5, -2)
  4. LLama/LLamaInteractExecutor.cs (+14, -2)
  5. LLama/LLamaStatelessExecutor.cs (+18, -4)
  6. LLama/Native/SafeLLamaContextHandle.cs (+3, -1)
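The common idea across all six files: when the KV cache fills up, keep the first tokensToKeep tokens, discard half of the remainder, and slide the surviving tail back so generation continues without re-evaluating the whole prompt. A minimal sketch of the arithmetic, with illustrative names that are not part of the LLamaSharp API:

    using System;

    // Illustrative sketch of the shift arithmetic used throughout this commit;
    // not LLamaSharp API. Keep the first tokensToKeep positions, evict half of
    // the rest, and shift the surviving tail back.
    static (int Discard, int NewPast) PlanShift(int pastTokensCount, int tokensToKeep)
    {
        var n_left = pastTokensCount - tokensToKeep; // tokens eligible for eviction
        var n_discard = n_left / 2;                  // evict half of them
        return (n_discard, pastTokensCount - n_discard);
    }

    // Worked example: a full 512-token cache with 4 tokens kept.
    var (discard, newPast) = PlanShift(512, 4);
    Console.WriteLine($"discard={discard}, newPast={newPast}"); // discard=254, newPast=258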

LLama/LLamaContext.cs (+12, -1)

@@ -1,4 +1,4 @@
-using LLama.Exceptions;
+using LLama.Exceptions;
 using LLama.Native;
 using System;
 using System.Collections.Generic;
@@ -521,6 +521,17 @@ namespace LLama
             return candidates_p;
         }

+        /// <summary>
+        /// Gets whether or not the Bos token should be added.
+        /// From common.cpp https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/common/common.cpp#L2417
+        /// </summary>
+        /// <returns></returns>
+        public bool ShouldAddBosToken()
+        {
+            var addBos = NativeApi.llama_add_bos_token(NativeHandle.ModelHandle);
+            return addBos != -1 ? Convert.ToBoolean(addBos) : NativeHandle.LLamaVocabType == LLamaVocabType.SentencePiece;
+        }
+
         #region eval overloads
         /// <summary>
         /// </summary>
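A hedged usage sketch for the new helper. The call site below is assumed, not part of this diff, and Tokenize's addBos parameter is taken from the library's public surface:

    // Assumed call site, not part of this diff: consult the model before
    // prepending BOS at tokenization time.
    LLamaToken[] TokenizePrompt(LLamaContext context, string prompt)
    {
        var addBos = context.ShouldAddBosToken(); // model metadata, else SentencePiece fallback
        return context.Tokenize(prompt, addBos);
    }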


LLama/LLamaExecutorBase.cs (+7, -6)

@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
@@ -195,13 +195,14 @@ namespace LLama
             // if we run out of context:
             // - take the tokensToKeep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-            int n_left = _pastTokensCount - tokensToKeep;
+            var n_left = _pastTokensCount - tokensToKeep;
+            var n_discard = n_left / 2;

-            _pastTokensCount = Math.Max(1, tokensToKeep);
-
-            // insert n_left/2 tokens at the start of embed from last_n_tokens
-            _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip((int)Context.ContextSize - n_left / 2 - _embeds.Count));
+            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep, tokensToKeep + n_discard);
+            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + n_discard, _pastTokensCount, -n_discard);

+            _pastTokensCount -= n_discard;
             // stop saving session if we run out of context
             _pathSession = string.Empty;
         }
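To make the two native calls concrete, a numeric walk-through (illustrative values only; this does not reimplement the native functions):

    using System;

    // With _pastTokensCount = 100 and tokensToKeep = 10:
    const int pastTokensCount = 100, tokensToKeep = 10;
    var n_left = pastTokensCount - tokensToKeep; // 90 tokens eligible for eviction
    var n_discard = n_left / 2;                  // 45 tokens dropped
    // llama_kv_cache_seq_rm  deletes cache cells [10, 55)
    // llama_kv_cache_seq_add shifts cells [55, 100) by -45, onto [10, 55)
    Console.WriteLine($"kept=[0,{tokensToKeep}) dropped=[{tokensToKeep},{tokensToKeep + n_discard}) newPast={pastTokensCount - n_discard}");
    // prints: kept=[0,10) dropped=[10,55) newPast=55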


LLama/LLamaInstructExecutor.cs (+5, -2)

@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;
 using System;
@@ -186,7 +186,10 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    // Instruct always uses input token size.
+                    var tokensToKeep = _embed_inps.Count;
+                    HandleRunOutOfContext(tokensToKeep);
                 }

                 TryReuseMatchingPrefix();
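The instruct-mode effect in numbers (illustrative): with tokensToKeep pinned to the full prompt length, a shift can only evict generated tokens, never the instruction template.

    // Illustrative: a 30-token instruction prompt in a 512-token context.
    var tokensToKeep = 30;            // _embed_inps.Count in the diff above
    var n_discard = (512 - 30) / 2;   // 241 generated tokens are evicted
    // Cells [0, 30), the templated instruction, always survive the shift.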


LLama/LLamaInteractExecutor.cs (+14, -2)

@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Common;
 using LLama.Native;
 using LLama.Abstractions;
 using System;
@@ -231,7 +231,19 @@ namespace LLama
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    var tokensToKeep = inferenceParams.TokensKeep;
+                    if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count)
+                    {
+                        tokensToKeep = _embed_inps.Count;
+                    }
+                    else
+                    {
+                        tokensToKeep += Convert.ToInt32(Context.ShouldAddBosToken()); // always keep the BOS token
+                    }
+
+                    HandleRunOutOfContext(tokensToKeep);
                 }

                 TryReuseMatchingPrefix();
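The clamping rule, extracted as a standalone sketch (names illustrative, not LLamaSharp API):

    using System;

    // Mirror of the branch above: negative or oversized requests fall back to
    // keeping the whole prompt; otherwise reserve one extra slot for BOS.
    static int ClampTokensToKeep(int requested, int promptLength, bool addsBos)
    {
        if (requested < 0 || requested > promptLength)
            return promptLength;
        return requested + (addsBos ? 1 : 0);
    }

    Console.WriteLine(ClampTokensToKeep(-1, 24, true));  // 24: negative means keep all
    Console.WriteLine(ClampTokensToKeep(8, 24, true));   // 9: 8 requested + the BOS token
    Console.WriteLine(ClampTokensToKeep(99, 24, false)); // 24: clamped to prompt length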


LLama/LLamaStatelessExecutor.cs (+18, -4)

@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using System;
 using System.Collections.Generic;
@@ -144,11 +144,25 @@ namespace LLama
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                 if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var canAddBos = Context.ShouldAddBosToken();
+                    var tokensKeep = inferenceParams.TokensKeep;
+
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
+                    {
+                        tokensKeep = tokens.Count;
+                    }
+                    else
+                    {
+                        tokensKeep += Convert.ToInt32(canAddBos);
+                    }
+
+                    var n_left = n_past - tokensKeep;
                     var n_discard = n_left / 2;

-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensKeep, tokensKeep + n_discard);
+                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensKeep + n_discard, n_past, -n_discard);

                     n_past -= n_discard;
                 }
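End to end, the stateless executor can now generate past the context size. A hedged usage sketch; the model path and parameter values are placeholders:

    using System;
    using LLama;
    using LLama.Common;

    // Placeholder model path; any GGUF model works for the sketch.
    var parameters = new ModelParams("model.gguf") { ContextSize = 512 };
    using var weights = LLamaWeights.LoadFromFile(parameters);
    var executor = new StatelessExecutor(weights, parameters);

    // Ask for more tokens than the context holds; TokensKeep pins the start
    // of the prompt across shifts.
    var inferenceParams = new InferenceParams { MaxTokens = 2048, TokensKeep = 16 };
    await foreach (var text in executor.InferAsync("Tell me a long story.", inferenceParams))
        Console.Write(text); // continues past 512 tokens via context shifting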


LLama/Native/SafeLLamaContextHandle.cs (+3, -1)

@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using System.Text;
@@ -19,6 +19,8 @@ namespace LLama.Native
         /// </summary>
         public int VocabCount => ThrowIfDisposed().VocabCount;

+        public LLamaVocabType LLamaVocabType => ThrowIfDisposed().VocabType;
+
         /// <summary>
         /// Total number of tokens in the context
         /// </summary>

