
Fixed out-of-context handling in stateless executor

tags/v0.6.0
Martin Evans, 2 years ago · commit 0d40338692
6 changed files with 18 additions and 9 deletions
  1. LLama.Unittest/StatelessExecutorTest.cs (+1 −1)
  2. LLama/LLamaContext.cs (+1 −1)
  3. LLama/LLamaStatelessExecutor.cs (+9 −4)
  4. LLama/Native/LLamaPos.cs (+1 −1)
  5. LLama/Native/NativeApi.cs (+2 −0)
  6. LLama/Native/SafeLLamaContextHandle.cs (+4 −2)

LLama.Unittest/StatelessExecutorTest.cs (+1 −1)

@@ -54,7 +54,7 @@ namespace LLama.Unittest
             // with a modified context
             var @params = new InferenceParams()
             {
-                MaxTokens = 100,
+                MaxTokens = 80,
                 TokensKeep = question.Length,
             };



LLama/LLamaContext.cs (+1 −1)

@@ -382,7 +382,7 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
+        public int Eval(List<llama_token> tokens, int pastTokensCount)
         {
 #if NET5_0_OR_GREATER
             var span = CollectionsMarshal.AsSpan(tokens);
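
Note: `pastTokensCount` counts evaluated positions rather than naming a token, so `int` is the correct parameter type (the old `llama_token` alias is an integer type, which is presumably why this compiled before). A minimal call sketch, with `context` and `tokens` assumed to exist:

    // Eval feeds the tokens through the context and returns the
    // updated past-token count, which is carried into the next call.
    var pastTokensCount = 0;
    pastTokensCount = context.Eval(tokens, pastTokensCount);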


LLama/LLamaStatelessExecutor.cs (+9 −4)

@@ -6,6 +6,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
 using LLama.Extensions;
+using LLama.Native;
 
 namespace LLama
 {
@@ -115,12 +116,16 @@ namespace LLama
                     break;
 
                 // when run out of context
-                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
-                if (n_past + tokens.Count > Context.ContextSize)
+                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
+                if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep;
+                    var n_left = n_past - inferenceParams.TokensKeep - 1;
                     var n_discard = n_left / 2;
 
-                    n_past = Math.Max(1, inferenceParams.TokensKeep);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+                    NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+
+                    n_past -= n_discard;
 
                     tokens.Clear();
                     tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
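
This is the core of the fix: instead of resetting `n_past` and re-evaluating, the executor now edits the KV cache in place, mirroring llama.cpp's main example. Once `n_past` plus the pending tokens would reach the context size, the first `TokensKeep` positions are preserved, half of the remaining cells are discarded, and the survivors shift left. A sketch of just that arithmetic, as a hypothetical standalone helper (not part of the library):

    // Plan a context shift: keep the first `tokensKeep` cells, discard
    // half of the rest, and slide the survivors left by the same amount.
    static (int removeStart, int removeEnd, int delta) PlanContextShift(int nPast, int tokensKeep)
    {
        var nLeft = nPast - tokensKeep - 1;   // cells eligible for eviction
        var nDiscard = nLeft / 2;             // evict half of them

        // Cells [removeStart, removeEnd) are removed from the KV cache;
        // cells [removeEnd, nPast) then slide left by `delta` positions.
        return (tokensKeep + 1, tokensKeep + 1 + nDiscard, -nDiscard);
    }

After the shift, `n_past -= n_discard` keeps the executor's bookkeeping in step with the compacted cache.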


LLama/Native/LLamaPos.cs (+1 −1)

@@ -11,5 +11,5 @@ public record struct LLamaPos
 
     public static explicit operator int(LLamaPos pos) => pos.Value;
 
-    public static explicit operator LLamaPos(int value) => new(value);
+    public static implicit operator LLamaPos(int value) => new(value);
 }
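
Making the `int` → `LLamaPos` conversion implicit lets integer position expressions like those built above (e.g. `inferenceParams.TokensKeep + 1`) flow into `LLamaPos` parameters without a cast, while the `LLamaPos` → `int` direction stays explicit. Illustrative only:

    LLamaPos start = 5;      // previously needed an explicit cast: (LLamaPos)5
    int raw = (int)start;    // the reverse conversion remains explicit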

LLama/Native/NativeApi.cs (+2 −0)

@@ -547,6 +547,7 @@ namespace LLama.Native
         /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
         /// - &lt; 0: error<br />
         /// </returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);
 
         /// <summary>
@@ -556,6 +557,7 @@ namespace LLama.Native
         /// <param name="n_threads">n_threads is the number of threads used for generation (single token)</param>
         /// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
         /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
     }
 }
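
Both `extern` declarations were missing their `[DllImport]` attributes, so the runtime had no native entry point to bind them to. For reference, the general shape of a binding in this file (assuming `libraryName` is the file's existing library-name constant):

    // Without [DllImport], an extern method has no native entry point
    // and fails when invoked.
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern int llama_n_ctx(SafeLLamaContextHandle ctx);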

LLama/Native/SafeLLamaContextHandle.cs (+4 −2)

@@ -22,7 +22,7 @@ namespace LLama.Native
         /// <summary>
         /// Total number of tokens in the context
         /// </summary>
-        public int ContextSize => ThrowIfDisposed().ContextSize;
+        public int ContextSize => NativeApi.llama_n_ctx(this);
 
         /// <summary>
         /// Dimension of embedding vectors
@@ -52,6 +52,8 @@ namespace LLama.Native
             _model.DangerousAddRef(ref success);
             if (!success)
                 throw new RuntimeError("Failed to increment model refcount");
+
         }
 
         /// <inheritdoc />
@@ -238,7 +240,7 @@ namespace LLama.Native
             }
         }
 
-        public int Decode(SafeLLamaContextHandle ctx, LLamaBatchSafeHandle batch)
+        public int Decode(LLamaBatchSafeHandle batch)
         {
             return NativeApi.llama_decode(this, batch.Batch);
         }
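
With the redundant `ctx` parameter removed, `Decode` is invoked directly on the handle. A usage sketch, assuming a prepared `LLamaBatchSafeHandle` named `batch` and a context handle `ctxHandle`:

    // Per the llama_decode docs above: 0 = success, 1 = no KV slot
    // found for the batch, negative = error.
    var status = ctxHandle.Decode(batch);
    if (status != 0)
        throw new RuntimeError($"llama_decode failed: {status}");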

