From 77bd090150e196bddcbdddf7bbf253c98e1aca35 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 6 Sep 2023 22:26:36 +0100 Subject: [PATCH] Simplified `LLamaInteractExecutor` antiprompt matching by using new extension method --- LLama/Common/FixedSizeQueue.cs | 3 ++ LLama/Extensions/IReadOnlyListExtensions.cs | 42 +++++++++++++++++++-- LLama/LLamaInteractExecutor.cs | 21 ++--------- LLama/LLamaStatelessExecutor.cs | 2 - 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index 2c331e5a..97a4d6ee 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -15,6 +15,8 @@ namespace LLama.Common private readonly int _maxSize; private readonly List _storage; + internal IReadOnlyList Items => _storage; + /// /// Number of items in this queue /// @@ -57,6 +59,7 @@ namespace LLama.Common if (_storage.Count > _maxSize) throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); } + /// /// Replace every item in the queue with the given value /// diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs index c1e5eb57..b07d90cf 100644 --- a/LLama/Extensions/IReadOnlyListExtensions.cs +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Collections; using System.Collections.Generic; using System.Text; using LLama.Native; @@ -28,10 +29,11 @@ namespace LLama.Extensions /// Model to use to convert tokens into bytes /// Encoding to use to convert bytes into characters /// - internal static bool TokensEndsWithAnyString(this TList tokens, IReadOnlyList queries, SafeLlamaModelHandle model, Encoding encoding) - where TList : IReadOnlyList + internal static bool TokensEndsWithAnyString(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding) + where TTokens : IReadOnlyList + where TQueries : IReadOnlyList { - if (queries.Count == 0 || tokens.Count == 0) + if (queries == null || queries.Count == 0 || tokens.Count == 0) return false; // Find the length of the longest query @@ -58,5 +60,39 @@ namespace LLama.Extensions ArrayPool.Shared.Return(builderArray); } } + + internal static bool TokensEndsWithAnyString(this TTokens tokens, IList? queries, SafeLlamaModelHandle model, Encoding encoding) + where TTokens : IReadOnlyList + { + if (queries == null || queries.Count == 0 || tokens.Count == 0) + return false; + + return tokens.TokensEndsWithAnyString(new ReadonlyWrapper(queries), model, encoding); + } + + private readonly struct ReadonlyWrapper + : IReadOnlyList + { + private readonly IList _list; + + public int Count => _list.Count; + + public T this[int index] => _list[index]; + + public ReadonlyWrapper(IList list) + { + _list = list; + } + + public IEnumerator GetEnumerator() + { + return _list.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)_list).GetEnumerator(); + } + } } } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 4f29b998..1dafbfb6 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; using System.Text; +using LLama.Extensions; namespace LLama { @@ -128,27 +129,11 @@ namespace LLama extraOutputs = null; if (_embed_inps.Count <= _consumedTokensCount) { - if (args.Antiprompts is not null && args.Antiprompts.Count > 0) - { - var last_output_builder = new StringBuilder(); - foreach (var token in _last_n_tokens) - Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder); - var last_output = last_output_builder.ToString(); - - foreach (var antiprompt in args.Antiprompts) - { - if (last_output.EndsWith(antiprompt)) - { - args.WaitForInput = true; - break; - } - } - } + if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) + args.WaitForInput = true; if (_pastTokensCount > 0 && args.WaitForInput) - { return true; - } } if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 6b213f16..5c496037 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -1,11 +1,9 @@ using LLama.Abstractions; using LLama.Common; using System; -using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; -using System.Text; using System.Threading; using LLama.Extensions;