better_instruct_antiprompt_checking (tags/v0.6.0)
@@ -9,6 +9,13 @@ namespace LLama.Extensions
 {
     internal static class IReadOnlyListExtensions
     {
+        /// <summary>
+        /// Find the index of `item` in `list`
+        /// </summary>
+        /// <typeparam name="T"></typeparam>
+        /// <param name="list">list to search</param>
+        /// <param name="item">item to search for</param>
+        /// <returns></returns>
         public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
             where T : IEquatable<T>
         {
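This first hunk only documents the existing `IndexOf<T>` extension. The signature is worth a note: it returns `int?`, so the natural reading is that a missing item comes back as `null` rather than the `-1` sentinel of `List<T>.IndexOf`. A minimal call-site sketch under that assumption (hypothetical example code, not part of the diff; the class is `internal`, so this only compiles inside the LLama assembly or with `InternalsVisibleTo`):

#nullable enable
using System;
using System.Collections.Generic;
using LLama.Extensions;

internal static class IndexOfSketch
{
    internal static void Demo()
    {
        IReadOnlyList<int> tokens = new[] { 10, 20, 30 };

        int? hit = tokens.IndexOf(20);   // 1
        int? miss = tokens.IndexOf(99);  // expected: null (nullable return signals "not found")

        Console.WriteLine($"hit={hit}, miss={miss}");
    }
}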
@@ -61,6 +68,14 @@ namespace LLama.Extensions
             }
         }
 
+        /// <summary>
+        /// Check if the given set of tokens ends with any of the given strings
+        /// </summary>
+        /// <param name="tokens">Tokens to check</param>
+        /// <param name="queries">Strings to search for</param>
+        /// <param name="model">Model to use to convert tokens into bytes</param>
+        /// <param name="encoding">Encoding to use to convert bytes into characters</param>
+        /// <returns></returns>
         internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
             where TTokens : IReadOnlyList<int>
         {
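This hunk documents `TokensEndsWithAnyString`, the helper the executor change below switches to. Its body is not shown in this diff; the contract described by the new comments is: detokenize `tokens` through `model` and `encoding`, then report whether the resulting text ends with any of `queries`. A standalone sketch of that contract, with a plain delegate standing in for `SafeLlamaModelHandle` plus `Encoding` (all names here are illustrative, not the library's internals):

#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

internal static class AntipromptCheckSketch
{
    // True when the text produced by `tokens` ends with any of `queries`.
    internal static bool EndsWithAnyString(IReadOnlyList<int> tokens,
                                           IList<string>? queries,
                                           Func<int, string> decodeToken)
    {
        // Mirror the guard the executor used to perform by hand: no antiprompts, no match.
        if (queries == null || queries.Count == 0)
            return false;

        // Decode the whole token window into text, then do a plain ordinal suffix check.
        var builder = new StringBuilder();
        foreach (var token in tokens)
            builder.Append(decodeToken(token));

        var text = builder.ToString();
        return queries.Any(q => text.EndsWith(q, StringComparison.Ordinal));
    }
}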
@@ -489,6 +489,16 @@ namespace LLama
             return NativeHandle.TokenToString(token, Encoding);
         }
 
+        /// <summary>
+        /// Append a single token to a string builder
+        /// </summary>
+        /// <param name="token">Token to decode</param>
+        /// <param name="dest">string builder to append the result to</param>
+        public void TokenToString(llama_token token, StringBuilder dest)
+        {
+            NativeHandle.TokenToString(token, Encoding, dest);
+        }
+
        /// <inheritdoc />
        public void Dispose()
        {
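The new `TokenToString(token, dest)` overload on the context appends the decoded token into a caller-owned `StringBuilder`, so a loop over many tokens no longer allocates one intermediate string per token. A usage sketch, assuming an already-loaded `LLamaContext` (`llama_token` is an `int` alias in this codebase; a real context needs model weights, so it is taken as a parameter here rather than constructed):

using System.Collections.Generic;
using System.Text;
using LLama;

internal static class DetokenizeSketch
{
    // `context` must already be loaded elsewhere; creating one requires real model weights.
    internal static string Detokenize(LLamaContext context, IEnumerable<int> generatedTokens)
    {
        // One shared StringBuilder instead of a new string per token.
        var builder = new StringBuilder();
        foreach (var token in generatedTokens)
            context.TokenToString(token, builder);
        return builder.ToString();
    }
}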
@@ -8,6 +8,7 @@ using System.Linq;
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using LLama.Extensions;
 
 namespace LLama
 {
@@ -139,21 +140,10 @@ namespace LLama
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
-                if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
+                if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
-                    var last_output_builder = new StringBuilder();
-                    foreach (var token in _last_n_tokens)
-                        Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
-                    var last_output = last_output_builder.ToString();
-
-                    foreach (var antiprompt in args.Antiprompts)
-                    {
-                        if (last_output.EndsWith(antiprompt))
-                        {
-                            args.WaitForInput = true;
-                            return true;
-                        }
-                    }
+                    args.WaitForInput = true;
+                    return true;
                 }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
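This last hunk is the point of the branch: the executor's hand-rolled antiprompt check (build the whole recent-output string, then call `EndsWith` per antiprompt) collapses into one call to `TokensEndsWithAnyString` on the `_last_n_tokens` window, and the null/empty-antiprompt guard now lives inside the helper. From the caller's side nothing changes: antiprompts are still supplied through the inference parameters, and generation still pauses for input once the recent tokens end with one of them. A hedged end-to-end sketch of that caller pattern (model path and prompt are placeholders; member names follow the LLamaSharp 0.6-era API and may differ in other versions):

using System.Collections.Generic;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

internal static class AntipromptUsageSketch
{
    internal static async Task RunAsync()
    {
        // Placeholder path: point this at a real GGUF model file.
        var parameters = new ModelParams("models/example-model.gguf");

        using var weights = LLamaWeights.LoadFromFile(parameters);
        using var context = weights.CreateContext(parameters);
        var executor = new InstructExecutor(context);

        var inferenceParams = new InferenceParams
        {
            // Generation sets WaitForInput once the recent tokens end with this string.
            AntiPrompts = new List<string> { "User:" },
            MaxTokens = 256
        };

        await foreach (var piece in executor.InferAsync("Write a haiku about llamas.", inferenceParams))
            System.Console.Write(piece);
    }
}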