|
|
|
@@ -1,5 +1,6 @@ |
|
|
|
using System; |
|
|
|
using System.Buffers; |
|
|
|
using System.Collections; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Text; |
|
|
|
using LLama.Native; |
|
|
|
@@ -28,10 +29,11 @@ namespace LLama.Extensions |
|
|
|
/// <param name="model">Model to use to convert tokens into bytes</param> |
|
|
|
/// <param name="encoding">Encoding to use to convert bytes into characters</param> |
|
|
|
/// <returns></returns> |
|
|
|
internal static bool TokensEndsWithAnyString<TList>(this TList tokens, IReadOnlyList<string> queries, SafeLlamaModelHandle model, Encoding encoding) |
|
|
|
where TList : IReadOnlyList<int> |
|
|
|
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding) |
|
|
|
where TTokens : IReadOnlyList<int> |
|
|
|
where TQueries : IReadOnlyList<string> |
|
|
|
{ |
|
|
|
if (queries.Count == 0 || tokens.Count == 0) |
|
|
|
if (queries == null || queries.Count == 0 || tokens.Count == 0) |
|
|
|
return false; |
|
|
|
|
|
|
|
// Find the length of the longest query |
|
|
|
@@ -58,5 +60,39 @@ namespace LLama.Extensions |
|
|
|
ArrayPool<char>.Shared.Return(builderArray); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding) |
|
|
|
where TTokens : IReadOnlyList<int> |
|
|
|
{ |
|
|
|
if (queries == null || queries.Count == 0 || tokens.Count == 0) |
|
|
|
return false; |
|
|
|
|
|
|
|
return tokens.TokensEndsWithAnyString(new ReadonlyWrapper<string>(queries), model, encoding); |
|
|
|
} |
|
|
|
|
|
|
|
private readonly struct ReadonlyWrapper<T> |
|
|
|
: IReadOnlyList<T> |
|
|
|
{ |
|
|
|
private readonly IList<T> _list; |
|
|
|
|
|
|
|
public int Count => _list.Count; |
|
|
|
|
|
|
|
public T this[int index] => _list[index]; |
|
|
|
|
|
|
|
public ReadonlyWrapper(IList<T> list) |
|
|
|
{ |
|
|
|
_list = list; |
|
|
|
} |
|
|
|
|
|
|
|
public IEnumerator<T> GetEnumerator() |
|
|
|
{ |
|
|
|
return _list.GetEnumerator(); |
|
|
|
} |
|
|
|
|
|
|
|
IEnumerator IEnumerable.GetEnumerator() |
|
|
|
{ |
|
|
|
return ((IEnumerable)_list).GetEnumerator(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |