Fixed out-of-context handling in stateless executor

tags/v0.6.0
Martin Evans, 2 years ago
commit 0d40338692
6 changed files with 18 additions and 9 deletions
  1. LLama.Unittest/StatelessExecutorTest.cs (+1 / -1)
  2. LLama/LLamaContext.cs (+1 / -1)
  3. LLama/LLamaStatelessExecutor.cs (+9 / -4)
  4. LLama/Native/LLamaPos.cs (+1 / -1)
  5. LLama/Native/NativeApi.cs (+2 / -0)
  6. LLama/Native/SafeLLamaContextHandle.cs (+4 / -2)

LLama.Unittest/StatelessExecutorTest.cs (+1 / -1)

@@ -54,7 +54,7 @@ namespace LLama.Unittest
             // with a modified context
             var @params = new InferenceParams()
             {
-                MaxTokens = 100,
+                MaxTokens = 80,
                 TokensKeep = question.Length,
             };




LLama/LLamaContext.cs (+1 / -1)

@@ -382,7 +382,7 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
+        public int Eval(List<llama_token> tokens, int pastTokensCount)
         {
 #if NET5_0_OR_GREATER
             var span = CollectionsMarshal.AsSpan(tokens);
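
The old parameter type compiled only because llama_token is an integer alias, but a count of evaluated tokens is not a token id, so int is the honest type. A stand-alone sketch of the distinction (the alias declaration is an assumption, mirroring how LLamaSharp defines it):

using System.Collections.Generic;
using llama_token = System.Int32; // assumed alias, mirroring LLamaSharp's declaration

class EvalSignatureSketch
{
    // 'tokens' carries token ids; 'pastTokensCount' is a plain count.
    // Declaring the count as llama_token compiled (same underlying type)
    // but misstated its meaning.
    static int Eval(List<llama_token> tokens, int pastTokensCount)
    {
        return pastTokensCount + tokens.Count; // stand-in body: the updated count
    }

    static void Main()
    {
        var tokens = new List<llama_token> { 1, 2, 3 };
        System.Console.WriteLine(Eval(tokens, 10)); // prints 13
    }
}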


LLama/LLamaStatelessExecutor.cs (+9 / -4)

@@ -6,6 +6,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
 using LLama.Extensions;
+using LLama.Native;

 namespace LLama
 {
@@ -115,12 +116,16 @@ namespace LLama
                     break;

                 // when run out of context
-                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
-                if (n_past + tokens.Count > Context.ContextSize)
+                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
+                if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep;
+                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var n_discard = n_left / 2;

-                    n_past = Math.Max(1, inferenceParams.TokensKeep);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+
+                    n_past -= n_discard;

                     tokens.Clear();
                     tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
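
The fix replaces the old "reset n_past and re-feed half the history" approach with llama.cpp's KV-cache shifting: the first cell and the TokensKeep prefix stay put, half of the remaining cells are removed, and the survivors are shifted down so generation can continue. A minimal stand-alone sketch of the arithmetic, with assumed numbers and the two NativeApi calls replaced by printed ranges (the extra -1 matches llama.cpp's handling of the initial BOS token):

using System;

class ContextShiftSketch
{
    static void Main()
    {
        const int contextSize = 512; // assumed context window
        const int tokensKeep = 32;   // prefix pinned in the cache (e.g. the prompt)
        int n_past = 508;            // tokens already evaluated
        const int pending = 8;       // tokens waiting to be evaluated

        if (n_past + pending >= contextSize)
        {
            // Tokens after the kept prefix (and the BOS slot) are eligible
            // for discarding; half of them are dropped.
            var n_left = n_past - tokensKeep - 1;
            var n_discard = n_left / 2;

            // In the executor these ranges become the llama_kv_cache_seq_rm
            // and llama_kv_cache_seq_shift calls on sequence 0.
            Console.WriteLine($"remove cells [{tokensKeep + 1}, {tokensKeep + n_discard + 1})");
            Console.WriteLine($"shift cells [{tokensKeep + 1 + n_discard}, {n_past}) by -{n_discard}");

            n_past -= n_discard;
            Console.WriteLine($"n_past is now {n_past}; room for {contextSize - n_past} more tokens");
        }
    }
}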


LLama/Native/LLamaPos.cs (+1 / -1)

@@ -11,5 +11,5 @@ public record struct LLamaPos

     public static explicit operator int(LLamaPos pos) => pos.Value;

-    public static explicit operator LLamaPos(int value) => new(value);
+    public static implicit operator LLamaPos(int value) => new(value);
 }
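
With the int-to-LLamaPos conversion made implicit, call sites such as the new KV-cache calls above can pass plain integers as positions, while converting back to int still requires a cast. A stand-alone sketch using a trimmed copy of the struct:

// A trimmed copy of LLama.Native.LLamaPos, reduced to the two operators.
public record struct LLamaPos(int Value)
{
    public static explicit operator int(LLamaPos pos) => pos.Value;
    public static implicit operator LLamaPos(int value) => new(value);
}

class LLamaPosSketch
{
    static void Main()
    {
        LLamaPos start = 5;            // implicit: int -> LLamaPos, no cast needed
        int raw = (int)start;          // going back to int stays explicit
        System.Console.WriteLine(raw); // prints 5
    }
}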

LLama/Native/NativeApi.cs (+2 / -0)

@@ -547,6 +547,7 @@ namespace LLama.Native
         /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
         /// - &lt; 0: error<br />
         /// </returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);

         /// <summary>
@@ -556,6 +557,7 @@ namespace LLama.Native
         /// <param name="n_threads">n_threads is the number of threads used for generation (single token)</param>
         /// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
         /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
     }
 }
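
Both additions supply the [DllImport] attribute these extern declarations were missing; without it nothing binds the C# signature to an entry point in the native llama library. A generic P/Invoke sketch (libc and getpid are assumptions for a Unix-like system, not LLamaSharp code):

using System;
using System.Runtime.InteropServices;

static class PInvokeSketch
{
    // [DllImport] names the native library and calling convention the
    // runtime uses to resolve this extern method at call time.
    // On some Linux distros the library name may need to be "libc.so.6".
    [DllImport("libc", CallingConvention = CallingConvention.Cdecl)]
    private static extern int getpid();

    static void Main()
    {
        Console.WriteLine($"native pid: {getpid()}");
    }
}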

LLama/Native/SafeLLamaContextHandle.cs (+4 / -2)

@@ -22,7 +22,7 @@ namespace LLama.Native
         /// <summary>
         /// Total number of tokens in the context
         /// </summary>
-        public int ContextSize => ThrowIfDisposed().ContextSize;
+        public int ContextSize => NativeApi.llama_n_ctx(this);

         /// <summary>
         /// Dimension of embedding vectors
@@ -52,6 +52,8 @@ namespace LLama.Native
             _model.DangerousAddRef(ref success);
             if (!success)
                 throw new RuntimeError("Failed to increment model refcount");
+
         }
@@ -238,7 +240,7 @@ namespace LLama.Native
             }
         }

-        public int Decode(SafeLLamaContextHandle ctx, LLamaBatchSafeHandle batch)
+        public int Decode(LLamaBatchSafeHandle batch)
         {
             return NativeApi.llama_decode(this, batch.Batch);
         }
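
Two small corrections here: ContextSize now asks the live context (llama_n_ctx) rather than the model handle, whose value can differ from the size the context was actually created with, and Decode drops a redundant handle parameter because the instance is already the receiver. A toy sketch (all names invented) of that second point:

// A toy illustration, not LLamaSharp code: an instance method already
// receives 'this', so a separate context parameter duplicated the receiver.
class NativeStub
{
    public static int Decode(ContextHandle ctx, int batch) => ctx.Id + batch;
}

class ContextHandle
{
    public int Id;

    // Before: Decode(ContextHandle ctx, ...) made callers pass the handle
    // twice (as receiver and as argument). After: 'this' is forwarded.
    public int Decode(int batch) => NativeStub.Decode(this, batch);
}

class DecodeSketch
{
    static void Main()
    {
        var ctx = new ContextHandle { Id = 41 };
        System.Console.WriteLine(ctx.Decode(1)); // prints 42
    }
}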

