You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

StreamingTokenDecoder.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. using System.Buffers;
  2. using System.Diagnostics;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using LLama.Extensions;
  7. using LLama.Native;
  namespace LLama
  {
      /// <summary>
      /// Decodes a stream of tokens into a stream of characters.
      /// </summary>
      /// <remarks>
      /// Each token is converted to bytes by the model, and those bytes are fed through a single
      /// <see cref="Decoder"/> obtained from the supplied <see cref="Encoding"/>. The decoder is
      /// called with <c>flush: false</c>, so a multi-byte character whose bytes are spread over
      /// several tokens is held in the decoder's state until the remaining bytes arrive.
      /// Decoded characters accumulate internally until they are read with one of the
      /// <c>Read</c> overloads.
      /// </remarks>
      public sealed class StreamingTokenDecoder
      {
          private readonly SafeLlamaModelHandle _weights;
          private readonly Decoder _decoder;
          private readonly List<char> _characters = new();

          /// <summary>
          /// The number of decoded characters waiting to be read
          /// </summary>
          public int AvailableCharacters => _characters.Count;

          #region constructors
          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="weights">Model weights</param>
          public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
              : this(encoding, weights.NativeHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="context">Context to retrieve encoding and model weights from</param>
          public StreamingTokenDecoder(LLamaContext context)
              : this(context.Encoding, context.NativeHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="context">Context to retrieve model weights from</param>
          public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
              : this(encoding, context.ModelHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="weights">Models weights to use</param>
          /// <exception cref="ArgumentNullException"><paramref name="encoding"/> or <paramref name="weights"/> is null.</exception>
          public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
          {
              // Fail fast here instead of with a NullReferenceException (encoding) or on the
              // first Add call (weights).
              if (encoding is null)
                  throw new ArgumentNullException(nameof(encoding));

              _weights = weights ?? throw new ArgumentNullException(nameof(weights));
              _decoder = encoding.GetDecoder();
          }
          #endregion

          /// <summary>
          /// Add a single token to the decoder. Any characters it completes are appended to the
          /// internal buffer and can be retrieved with <c>Read</c>.
          /// </summary>
          /// <param name="token">The token to decode.</param>
          public void Add(LLamaToken token)
          {
              var charsArr = ArrayPool<char>.Shared.Rent(16);
              var bytesArr = ArrayPool<byte>.Shared.Rent(16);
              try
              {
                  // Convert this token into bytes. This may replace bytesArr with a larger pooled
                  // array; the finally block always returns whichever array bytesArr ends up holding.
                  var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;

                  // Convert those bytes into characters, in chunks of at most charsArr.Length.
                  var bytesOffset = 0;
                  var completed = false;
                  while (!completed)
                  {
                      // flush: false lets the decoder keep an incomplete multi-byte sequence
                      // pending until the bytes of a later token complete it.
                      _decoder.Convert(
                          bytesArr, bytesOffset, bytesAvailable,
                          charsArr, 0, charsArr.Length,
                          false,
                          out var bytesUsed, out var charsUsed, out completed
                      );
                      bytesOffset += bytesUsed;
                      bytesAvailable -= bytesUsed;

                      // Add the decoded characters to the output buffer
                      _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
                  }
              }
              finally
              {
                  ArrayPool<char>.Shared.Return(charsArr);
                  ArrayPool<byte>.Shared.Return(bytesArr);
              }
              return;

              // Converts a single token into bytes, using the `bytes` array as temporary storage.
              // If the `bytes` array is too small, it is swapped for a larger one from the ArrayPool.
              static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model)
              {
                  // Try to get bytes
                  var l = model.TokenToSpan(token, bytes);

                  // If the required length is larger than the buffer, expand the buffer and try again
                  if (l > bytes.Length)
                  {
                      // Rent the replacement before returning the old array. If Rent throws,
                      // `bytes` still refers to an array the caller owns, so the caller's finally
                      // block does not return an already-returned array a second time.
                      var larger = ArrayPool<byte>.Shared.Rent((int)(l * 2));
                      ArrayPool<byte>.Shared.Return(bytes);
                      bytes = larger;

                      // The buffer now has at least twice the reported length, so this call should fit
                      l = model.TokenToSpan(token, bytes);
                  }

                  Debug.Assert(l <= bytes.Length);
                  return new Span<byte>(bytes, 0, (int)l);
              }
          }

          /// <summary>
          /// Add a single token to the decoder
          /// </summary>
          /// <param name="token">The token as an <see cref="int"/>; it is converted to <see cref="LLamaToken"/> with an explicit cast.</param>
          public void Add(int token)
          {
              Add((LLamaToken)token);
          }

          /// <summary>
          /// Add all tokens in the given enumerable
          /// </summary>
          /// <typeparam name="T">Enumerable type; generic so struct enumerables are not boxed.</typeparam>
          /// <param name="tokens">Tokens to decode, in order.</param>
          public void AddRange<T>(T tokens)
              where T : IEnumerable<LLamaToken>
          {
              // Call the LLamaToken overload directly; casting to int and back is redundant.
              foreach (var item in tokens)
                  Add(item);
          }

          /// <summary>
          /// Add all tokens in the given span
          /// </summary>
          /// <param name="tokens">Tokens to decode, in order.</param>
          public void AddRange(ReadOnlySpan<LLamaToken> tokens)
          {
              foreach (var item in tokens)
                  Add(item);
          }

          /// <summary>
          /// Read all decoded characters and clear the buffer
          /// </summary>
          /// <param name="dest">List that the decoded characters are appended to.</param>
          public void Read(List<char> dest)
          {
              dest.AddRange(_characters);
              _characters.Clear();
          }

          /// <summary>
          /// Read all decoded characters as a string and clear the buffer
          /// </summary>
          /// <returns>The decoded characters, or an empty string if none are waiting.</returns>
          public string Read()
          {
              if (_characters.Count == 0)
                  return "";

              // Build the string directly from the chars instead of string.Join, which would
              // call ToString() on every element through a StringBuilder.
              var str = new string(_characters.ToArray());
              _characters.Clear();
              return str;
          }

          /// <summary>
          /// Set the decoder back to its initial state, discarding any pending partial
          /// character state held by the decoder as well as any unread characters.
          /// </summary>
          public void Reset()
          {
              _decoder.Reset();
              _characters.Clear();
          }
      }
  }