| @@ -15,6 +15,8 @@ namespace LLama.Common | |||||
| private readonly int _maxSize; | private readonly int _maxSize; | ||||
| private readonly List<T> _storage; | private readonly List<T> _storage; | ||||
/// <summary>
/// Read-only view over the list that backs this queue (the list itself, not a copy)
/// </summary>
internal IReadOnlyList<T> Items
{
    get { return _storage; }
}
| /// <summary> | /// <summary> | ||||
| /// Number of items in this queue | /// Number of items in this queue | ||||
| /// </summary> | /// </summary> | ||||
| @@ -57,6 +59,7 @@ namespace LLama.Common | |||||
| if (_storage.Count > _maxSize) | if (_storage.Count > _maxSize) | ||||
| throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); | throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); | ||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Replace every item in the queue with the given value | /// Replace every item in the queue with the given value | ||||
| /// </summary> | /// </summary> | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | using System.Buffers; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using LLama.Native; | using LLama.Native; | ||||
| @@ -28,10 +29,11 @@ namespace LLama.Extensions | |||||
| /// <param name="model">Model to use to convert tokens into bytes</param> | /// <param name="model">Model to use to convert tokens into bytes</param> | ||||
| /// <param name="encoding">Encoding to use to convert bytes into characters</param> | /// <param name="encoding">Encoding to use to convert bytes into characters</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| internal static bool TokensEndsWithAnyString<TList>(this TList tokens, IReadOnlyList<string> queries, SafeLlamaModelHandle model, Encoding encoding) | |||||
| where TList : IReadOnlyList<int> | |||||
| internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding) | |||||
| where TTokens : IReadOnlyList<int> | |||||
| where TQueries : IReadOnlyList<string> | |||||
| { | { | ||||
| if (queries.Count == 0 || tokens.Count == 0) | |||||
| if (queries == null || queries.Count == 0 || tokens.Count == 0) | |||||
| return false; | return false; | ||||
| // Find the length of the longest query | // Find the length of the longest query | ||||
| @@ -58,5 +60,39 @@ namespace LLama.Extensions | |||||
| ArrayPool<char>.Shared.Return(builderArray); | ArrayPool<char>.Shared.Return(builderArray); | ||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Overload of <c>TokensEndsWithAnyString</c> for callers that hold their queries as an
/// <see cref="IList{T}"/> instead of an <see cref="IReadOnlyList{T}"/>.
/// The list is wrapped (not copied) and handed to the read-only-list overload.
/// </summary>
/// <typeparam name="TTokens">Type of the token list</typeparam>
/// <param name="tokens">Tokens to check</param>
/// <param name="queries">Strings to look for; a null or empty list yields false</param>
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns>False when either input is empty (or queries is null), otherwise the result of the read-only-list overload</returns>
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
    where TTokens : IReadOnlyList<int>
{
    // Nothing to match against, or nothing to match in
    if (queries is null)
        return false;
    if (queries.Count == 0)
        return false;
    if (tokens.Count == 0)
        return false;

    // Explicit type arguments make it clear that this forwards to the two-type-parameter overload
    var readOnlyQueries = new ReadonlyWrapper<string>(queries);
    return TokensEndsWithAnyString<TTokens, ReadonlyWrapper<string>>(tokens, readOnlyQueries, model, encoding);
}
/// <summary>
/// Thin adapter that presents an <see cref="IList{T}"/> as an <see cref="IReadOnlyList{T}"/>.
/// Every member forwards to the wrapped list; nothing is copied.
/// </summary>
/// <typeparam name="T">Type of the items in the wrapped list</typeparam>
private readonly struct ReadonlyWrapper<T>
    : IReadOnlyList<T>
{
    private readonly IList<T> _list;

    /// <summary>
    /// Wrap the given list
    /// </summary>
    /// <param name="list">List that all members forward to</param>
    public ReadonlyWrapper(IList<T> list) => _list = list;

    /// <summary>
    /// Number of items in the wrapped list
    /// </summary>
    public int Count
    {
        get { return _list.Count; }
    }

    /// <summary>
    /// Item of the wrapped list at the given index
    /// </summary>
    public T this[int index]
    {
        get { return _list[index]; }
    }

    /// <inheritdoc />
    public IEnumerator<T> GetEnumerator() => _list.GetEnumerator();

    /// <inheritdoc />
    IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_list).GetEnumerator();
}
| } | } | ||||
| } | } | ||||
| @@ -8,6 +8,7 @@ using System.Linq; | |||||
| using System.Text.Json; | using System.Text.Json; | ||||
| using System.Text.Json.Serialization; | using System.Text.Json.Serialization; | ||||
| using System.Text; | using System.Text; | ||||
| using LLama.Extensions; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| @@ -128,27 +129,11 @@ namespace LLama | |||||
| extraOutputs = null; | extraOutputs = null; | ||||
| if (_embed_inps.Count <= _consumedTokensCount) | if (_embed_inps.Count <= _consumedTokensCount) | ||||
| { | { | ||||
| if (args.Antiprompts is not null && args.Antiprompts.Count > 0) | |||||
| { | |||||
| var last_output_builder = new StringBuilder(); | |||||
| foreach (var token in _last_n_tokens) | |||||
| Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder); | |||||
| var last_output = last_output_builder.ToString(); | |||||
| foreach (var antiprompt in args.Antiprompts) | |||||
| { | |||||
| if (last_output.EndsWith(antiprompt)) | |||||
| { | |||||
| args.WaitForInput = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) | |||||
| args.WaitForInput = true; | |||||
| if (_pastTokensCount > 0 && args.WaitForInput) | if (_pastTokensCount > 0 && args.WaitForInput) | ||||
| { | |||||
| return true; | return true; | ||||
| } | |||||
| } | } | ||||
| if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) | if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) | ||||
| @@ -1,11 +1,9 @@ | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | using System; | ||||
| using System.Buffers; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Text; | |||||
| using System.Threading; | using System.Threading; | ||||
| using LLama.Extensions; | using LLama.Extensions; | ||||