@@ -54,7 +54,7 @@ namespace LLama.Unittest
             // with a modified context
             var @params = new InferenceParams()
             {
-                MaxTokens = 100,
+                MaxTokens = 80,
                 TokensKeep = question.Length,
             };
@@ -382,7 +382,7 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
+        public int Eval(List<llama_token> tokens, int pastTokensCount)
         {
 #if NET5_0_OR_GREATER
             var span = CollectionsMarshal.AsSpan(tokens);
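Note: the second parameter of `Eval` is a count of positions already evaluated, not a token id, so typing it as `llama_token` was misleading even though both are integers. A minimal sketch of the intended calling pattern, assuming a context wrapper that exposes this overload (the names `context` and `tokenChunks` are illustrative, not the library's exact surface):

```csharp
// Hypothetical usage sketch: feed tokens in chunks and carry the position count forward.
var pastTokensCount = 0;                 // number of positions already in the KV cache
foreach (var chunk in tokenChunks)       // tokenChunks: IEnumerable<List<llama_token>>, assumed
{
    // Eval returns the updated count, which becomes the input for the next call.
    pastTokensCount = context.Eval(chunk, pastTokensCount);
}
```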
@@ -6,6 +6,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
 using LLama.Extensions;
+using LLama.Native;
 
 namespace LLama
 {
| @@ -115,12 +116,16 @@ namespace LLama | |||||
| break; | break; | ||||
| // when run out of context | // when run out of context | ||||
| // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 | |||||
| if (n_past + tokens.Count > Context.ContextSize) | |||||
| // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497 | |||||
| if (n_past + tokens.Count >= Context.ContextSize) | |||||
| { | { | ||||
| var n_left = n_past - inferenceParams.TokensKeep; | |||||
| var n_left = n_past - inferenceParams.TokensKeep - 1; | |||||
| var n_discard = n_left / 2; | |||||
| n_past = Math.Max(1, inferenceParams.TokensKeep); | |||||
| NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1); | |||||
| NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard); | |||||
| n_past -= n_discard; | |||||
| tokens.Clear(); | tokens.Clear(); | ||||
| tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); | tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); | ||||
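Note: the replacement mirrors llama.cpp's context shifting. Rather than resetting `n_past` and re-evaluating recent tokens, it deletes a block of `n_discard` cells from the KV cache just after the protected `TokensKeep` prefix and slides the remaining cells down, so generation continues without a full re-prompt. A worked example of the arithmetic with illustrative numbers, assuming the half-open `[p0, p1)` range convention llama.cpp uses for these calls:

```csharp
// Illustrative numbers only: a 512-token context that has just filled up.
int contextSize = 512;                   // total KV cache capacity in this example
int tokensKeep  = 32;                    // prompt prefix that must survive the shift
int n_past      = 510;                   // n_past + pending tokens >= contextSize, so shift

int n_left    = n_past - tokensKeep - 1; // 477 positions are eligible for discarding
int n_discard = n_left / 2;              // 238 positions get dropped

// llama_kv_cache_seq_rm removes positions [33, 271), i.e. 33..270 (238 cells).
// llama_kv_cache_seq_shift then moves positions 271..509 down by 238, to 33..271.
n_past -= n_discard;                     // 272 positions remain, leaving room to generate

Console.WriteLine($"discarded {n_discard}, new n_past = {n_past}");
```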
@@ -11,5 +11,5 @@ public record struct LLamaPos
     public static explicit operator int(LLamaPos pos) => pos.Value;
 
-    public static explicit operator LLamaPos(int value) => new(value);
+    public static implicit operator LLamaPos(int value) => new(value);
 }
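Note: a plain `int` always fits in `LLamaPos`, which is the usual case for an `implicit` operator, and it removes cast noise at call sites such as the KV-cache calls above. A small sketch of what the change allows:

```csharp
// With the conversion now implicit, plain int expressions flow into position values:
LLamaPos start = 0;            // previously required an explicit cast: (LLamaPos)0
LLamaPos end = 32 + 1;         // int arithmetic converts on assignment

// The reverse direction stays explicit, keeping the unwrap back to int visible:
int raw = (int)end;
```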
@@ -547,6 +547,7 @@ namespace LLama.Native
         /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
         /// - < 0: error<br />
         /// </returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);
 
         /// <summary>
@@ -556,6 +557,7 @@ namespace LLama.Native
         /// <param name="n_threads">n_threads is the number of threads used for generation (single token)</param>
         /// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
         /// <returns></returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
     }
 }
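Note: a `public static extern` method needs a `[DllImport]` attribute to tell the runtime which native library and calling convention to resolve it against; without one the declaration compiles (with a warning) but the call cannot be bound at runtime. A minimal sketch of the pattern the two declarations above follow, with `libllama` standing in for whatever the real `libraryName` constant resolves to:

```csharp
using System;
using System.Runtime.InteropServices;

internal static class NativeApiSketch
{
    // Stand-in for the library's libraryName constant; the real value is platform-specific.
    private const string libraryName = "libllama";

    // [DllImport] binds the extern declaration to a native export of the same name,
    // using the C calling convention that llama.cpp exposes. IntPtr is used here to
    // keep the sketch self-contained; the library itself passes a SafeHandle.
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern int llama_set_n_threads(IntPtr ctx, uint n_threads, uint n_threads_batch);
}
```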
@@ -22,7 +22,7 @@ namespace LLama.Native
         /// <summary>
         /// Total number of tokens in the context
         /// </summary>
-        public int ContextSize => ThrowIfDisposed().ContextSize;
+        public int ContextSize => NativeApi.llama_n_ctx(this);
 
         /// <summary>
         /// Dimension of embedding vectors
| @@ -52,6 +52,8 @@ namespace LLama.Native | |||||
| _model.DangerousAddRef(ref success); | _model.DangerousAddRef(ref success); | ||||
| if (!success) | if (!success) | ||||
| throw new RuntimeError("Failed to increment model refcount"); | throw new RuntimeError("Failed to increment model refcount"); | ||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -238,7 +240,7 @@ namespace LLama.Native | |||||
| } | } | ||||
| } | } | ||||
| public int Decode(SafeLLamaContextHandle ctx, LLamaBatchSafeHandle batch) | |||||
| public int Decode(LLamaBatchSafeHandle batch) | |||||
| { | { | ||||
| return NativeApi.llama_decode(this, batch.Batch); | return NativeApi.llama_decode(this, batch.Batch); | ||||
| } | } | ||||
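Note: `Decode` is an instance method on the context handle, so passing the handle in again as a parameter was redundant; `this` is already the context forwarded to `llama_decode`. A hedged usage sketch, assuming `ctx` is a `SafeLLamaContextHandle` and `batchHandle` a `LLamaBatchSafeHandle` obtained elsewhere:

```csharp
// Before the change the caller had to hand the context back to its own method:
//     var status = ctx.Decode(ctx, batchHandle);
// After the change only the batch is needed:
var status = ctx.Decode(batchHandle);

// The return value mirrors llama_decode: 0 on success, 1 when no KV slot was
// available for the batch, negative values on error.
if (status != 0)
    throw new RuntimeError($"llama_decode failed with status {status}");
```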