You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

IReadOnlyListExtensions.cs 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. using System;
  2. using System.Buffers;
  3. using System.Collections;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using LLama.Native;
  7. namespace LLama.Extensions
  8. {
  9. internal static class IReadOnlyListExtensions
  10. {
  11. /// <summary>
  12. /// Find the index of `item` in `list`
  13. /// </summary>
  14. /// <typeparam name="T"></typeparam>
  15. /// <param name="list">list to search</param>
  16. /// <param name="item">item to search for</param>
  17. /// <returns></returns>
  18. public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
  19. where T : IEquatable<T>
  20. {
  21. for (var i = 0; i < list.Count; i++)
  22. {
  23. if (list[i].Equals(item))
  24. return i;
  25. }
  26. return null;
  27. }
  28. /// <summary>
  29. /// Check if the given set of tokens ends with any of the given strings
  30. /// </summary>
  31. /// <param name="tokens">Tokens to check</param>
  32. /// <param name="queries">Strings to search for</param>
  33. /// <param name="model">Model to use to convert tokens into bytes</param>
  34. /// <param name="encoding">Encoding to use to convert bytes into characters</param>
  35. /// <returns></returns>
  36. internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
  37. where TTokens : IReadOnlyList<int>
  38. where TQueries : IReadOnlyList<string>
  39. {
  40. if (queries == null || queries.Count == 0 || tokens.Count == 0)
  41. return false;
  42. // Find the length of the longest query
  43. var longest = 0;
  44. foreach (var candidate in queries)
  45. longest = Math.Max(longest, candidate.Length);
  46. // Rent an array to detokenize into
  47. var builderArray = ArrayPool<char>.Shared.Rent(longest);
  48. try
  49. {
  50. // Convert as many tokens as possible into the builderArray
  51. var characters = model.TokensToSpan(tokens, builderArray.AsSpan(0, longest), encoding);
  52. // Check every query to see if it's present
  53. foreach (var query in queries)
  54. if (characters.EndsWith(query.AsSpan()))
  55. return true;
  56. return false;
  57. }
  58. finally
  59. {
  60. ArrayPool<char>.Shared.Return(builderArray);
  61. }
  62. }
  63. internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, LLamaContext context)
  64. where TTokens : IReadOnlyList<int>
  65. where TQueries : IReadOnlyList<string>
  66. {
  67. return TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding);
  68. }
  69. /// <summary>
  70. /// Check if the given set of tokens ends with any of the given strings
  71. /// </summary>
  72. /// <param name="tokens">Tokens to check</param>
  73. /// <param name="queries">Strings to search for</param>
  74. /// <param name="model">Model to use to convert tokens into bytes</param>
  75. /// <param name="encoding">Encoding to use to convert bytes into characters</param>
  76. /// <returns></returns>
  77. internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
  78. where TTokens : IReadOnlyList<int>
  79. {
  80. if (queries == null || queries.Count == 0 || tokens.Count == 0)
  81. return false;
  82. return tokens.TokensEndsWithAnyString(new ReadonlyWrapper<string>(queries), model, encoding);
  83. }
  84. private readonly struct ReadonlyWrapper<T>
  85. : IReadOnlyList<T>
  86. {
  87. private readonly IList<T> _list;
  88. public int Count => _list.Count;
  89. public T this[int index] => _list[index];
  90. public ReadonlyWrapper(IList<T> list)
  91. {
  92. _list = list;
  93. }
  94. public IEnumerator<T> GetEnumerator()
  95. {
  96. return _list.GetEnumerator();
  97. }
  98. IEnumerator IEnumerable.GetEnumerator()
  99. {
  100. return ((IEnumerable)_list).GetEnumerator();
  101. }
  102. }
  103. }
  104. }