using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using LLama.Native;
namespace LLama.Extensions
{
internal static class IReadOnlyListExtensions
{
/// <summary>
/// Find the index of `item` in `list`
/// </summary>
/// <typeparam name="T">Element type of the list; must implement <see cref="IEquatable{T}"/></typeparam>
/// <param name="list">list to search</param>
/// <param name="item">item to search for</param>
/// <returns>Zero-based index of the first element equal to <paramref name="item"/>, or null if no element is equal to it</returns>
public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
    where T : IEquatable<T>
{
    // EqualityComparer<T>.Default uses IEquatable<T>.Equals for such T, and it also
    // tolerates null elements, where calling list[i].Equals(item) directly would throw.
    var comparer = EqualityComparer<T>.Default;

    for (var i = 0; i < list.Count; i++)
    {
        if (comparer.Equals(list[i], item))
            return i;
    }

    // No element matched (this includes the empty list)
    return null;
}
/// <summary>
/// Check if the given set of tokens ends with any of the given strings
/// </summary>
/// <param name="tokens">Tokens to check</param>
/// <param name="queries">Strings to search for</param>
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns>
/// true if the characters produced from <paramref name="tokens"/> end with at least one of the queries;
/// false otherwise, including when <paramref name="queries"/> is null or empty or <paramref name="tokens"/> is empty
/// </returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList
where TQueries : IReadOnlyList
{
// NOTE(review): the generic parameter/argument lists (on the method name, on IReadOnlyList in the
// constraints and on ArrayPool below) appear to be missing from this text - confirm against the original file.

// Nothing to match against, or nothing to match: report "no match" without touching the model
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;
// Find the length of the longest query
var longest = 0;
foreach (var candidate in queries)
longest = Math.Max(longest, candidate.Length);
// Rent an array to detokenize into (Rent may hand back a larger array, hence the explicit slice below)
var builderArray = ArrayPool.Shared.Rent(longest);
try
{
// Convert as many tokens as possible into the builderArray
// Only the first `longest` slots are offered as the destination.
// NOTE(review): presumably TokensToSpan yields the trailing characters of the token sequence - confirm.
var characters = model.TokensToSpan(tokens, builderArray.AsSpan(0, longest), encoding);
// Check every query to see if it's present
foreach (var query in queries)
if (characters.EndsWith(query.AsSpan()))
return true;
return false;
}
finally
{
// Always hand the rented array back to the pool, on both the match and no-match paths
ArrayPool.Shared.Return(builderArray);
}
}
/// <summary>
/// Check if the given set of tokens ends with any of the given strings
/// </summary>
/// <remarks>
/// Overload taking an IList of queries: the list is wrapped in a ReadonlyWrapper and the call is
/// forwarded to the overload that takes a read-only list of queries.
/// </remarks>
/// <param name="tokens">Tokens to check</param>
/// <param name="queries">Strings to search for</param>
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns>
/// false when <paramref name="queries"/> is null or empty or <paramref name="tokens"/> is empty;
/// otherwise the result of the forwarded call
/// </returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString(this TTokens tokens, IList? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList
{
// NOTE(review): the generic parameter/argument lists (on the method name, IList, IReadOnlyList and
// ReadonlyWrapper) appear to be missing from this text - confirm against the original file.

// Checked here so that a null list is never passed to the wrapper constructor
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;
// Wrap the list (no copy is made) and forward to the other overload
return tokens.TokensEndsWithAnyString(new ReadonlyWrapper(queries), model, encoding);
}
/// <summary>
/// Adapter that exposes an <see cref="IList{T}"/> through <see cref="IReadOnlyList{T}"/>.
/// The wrapped list is referenced, not copied, so later changes to it are visible through the wrapper.
/// </summary>
/// <typeparam name="T">Element type of the wrapped list</typeparam>
private readonly struct ReadonlyWrapper<T>
    : IReadOnlyList<T>
{
    private readonly IList<T> _list;

    /// <summary>
    /// Number of items in the wrapped list
    /// </summary>
    public int Count => _list.Count;

    /// <summary>
    /// Get the item at the given index of the wrapped list
    /// </summary>
    /// <param name="index">Zero-based index of the item</param>
    public T this[int index] => _list[index];

    /// <summary>
    /// Wrap the given list
    /// </summary>
    /// <param name="list">List to expose as read-only</param>
    public ReadonlyWrapper(IList<T> list)
    {
        _list = list;
    }

    /// <summary>
    /// Get an enumerator over the wrapped list
    /// </summary>
    public IEnumerator<T> GetEnumerator()
    {
        return _list.GetEnumerator();
    }

    /// <summary>
    /// Get a non-generic enumerator over the wrapped list
    /// </summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((IEnumerable)_list).GetEnumerator();
    }
}
}
}