You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

StreamingTokenDecoder.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. using System.Buffers;
  2. using System.Diagnostics;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using LLama.Extensions;
  7. using LLama.Native;
  namespace LLama
  {
      /// <summary>
      /// Decodes a stream of tokens into a stream of characters.
      /// </summary>
      /// <remarks>
      /// Each token is converted to bytes by the model and fed through a single, long-lived
      /// <see cref="Decoder"/> with <c>flush: false</c>. An incomplete multi-byte sequence at the
      /// end of one token is therefore kept inside the decoder and completed by the bytes of a
      /// later token, instead of being emitted as replacement characters.
      /// </remarks>
      public sealed class StreamingTokenDecoder
      {
          // Initial size requested from the ArrayPool for the per-token scratch buffers.
          private const int ScratchBufferSize = 16;

          private readonly SafeLlamaModelHandle _weights;
          private readonly Decoder _decoder;
          private readonly List<char> _characters = new();

          /// <summary>
          /// The number of decoded characters waiting to be read
          /// </summary>
          public int AvailableCharacters => _characters.Count;

          #region constructors
          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="weights">Model weights</param>
          public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
              : this(encoding, weights.NativeHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="context">Context to retrieve encoding and model weights from</param>
          public StreamingTokenDecoder(LLamaContext context)
              : this(context.Encoding, context.NativeHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="context">Context to retrieve model weights from</param>
          public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
              : this(encoding, context.ModelHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="weights">Models weights to use</param>
          public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
          {
              _weights = weights;
              _decoder = encoding.GetDecoder();
          }
          #endregion

          /// <summary>
          /// Add a single token to the decoder. Any characters that can be fully decoded are
          /// appended to the internal buffer; a trailing partial character stays in the decoder
          /// until later tokens complete it.
          /// </summary>
          /// <param name="token">The token to decode</param>
          public void Add(LLamaToken token)
          {
              var charsArr = ArrayPool<char>.Shared.Rent(ScratchBufferSize);
              var bytesArr = ArrayPool<byte>.Shared.Rent(ScratchBufferSize);
              try
              {
                  // Convert this token into bytes. This may replace bytesArr with a larger pooled
                  // array, which is why it is passed by ref and returned in the finally block.
                  var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;

                  // Convert those bytes into characters, in chunks no larger than the char buffer,
                  // until the decoder reports that every byte has been consumed.
                  var bytesOffset = 0;
                  var completed = false;
                  while (!completed)
                  {
                      // flush: false so an incomplete byte sequence is retained by the decoder
                      // and finished by a later call to Add.
                      _decoder.Convert(
                          bytesArr, bytesOffset, bytesAvailable,
                          charsArr, 0, charsArr.Length,
                          false,
                          out var bytesUsed, out var charsUsed, out completed
                      );
                      bytesOffset += bytesUsed;
                      bytesAvailable -= bytesUsed;

                      // Add the decoded characters to the output buffer
                      _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
                  }
              }
              finally
              {
                  ArrayPool<char>.Shared.Return(charsArr);
                  ArrayPool<byte>.Shared.Return(bytesArr);
              }

              return;

              // Converts a single token into bytes, using the `bytes` array as temporary storage.
              // If the `bytes` array is too small it is swapped for a larger one from the ArrayPool.
              static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model)
              {
                  // Try to get bytes
                  var l = model.TokenToSpan(token, bytes);

                  // Negative length indicates that the output was too small. Expand it to twice
                  // that size and try again.
                  if (l < 0)
                  {
                      // Rent the replacement before returning the old array. If Rent threw after
                      // the old array had already been returned, `bytes` would still reference it
                      // and the caller's finally block would return it to the pool a second time.
                      var larger = ArrayPool<byte>.Shared.Rent(-l * 2);
                      ArrayPool<byte>.Shared.Return(bytes);
                      bytes = larger;

                      // Get bytes, this time it can't fail
                      l = model.TokenToSpan(token, bytes);
                  }

                  Debug.Assert(l >= 0);
                  return new Span<byte>(bytes, 0, l);
              }
          }

          /// <summary>
          /// Add a single token to the decoder
          /// </summary>
          /// <param name="token">The raw token id to decode</param>
          public void Add(int token)
          {
              Add((LLamaToken)token);
          }

          /// <summary>
          /// Add all tokens in the given enumerable
          /// </summary>
          /// <param name="tokens">Raw token ids to decode, in order</param>
          /// <exception cref="ArgumentNullException"><paramref name="tokens"/> is null</exception>
          public void AddRange(IEnumerable<int> tokens)
          {
              if (tokens == null)
                  throw new ArgumentNullException(nameof(tokens));

              foreach (var item in tokens)
                  Add(item);
          }

          /// <summary>
          /// Add all tokens in the given enumerable
          /// </summary>
          /// <param name="tokens">Tokens to decode, in order</param>
          /// <exception cref="ArgumentNullException"><paramref name="tokens"/> is null</exception>
          public void AddRange(IEnumerable<LLamaToken> tokens)
          {
              if (tokens == null)
                  throw new ArgumentNullException(nameof(tokens));

              // Call the LLamaToken overload directly rather than round-tripping through int.
              foreach (var item in tokens)
                  Add(item);
          }

          /// <summary>
          /// Read all decoded characters and clear the buffer
          /// </summary>
          /// <param name="dest">List that the decoded characters are appended to</param>
          /// <exception cref="ArgumentNullException"><paramref name="dest"/> is null</exception>
          public void Read(List<char> dest)
          {
              if (dest == null)
                  throw new ArgumentNullException(nameof(dest));

              dest.AddRange(_characters);
              _characters.Clear();
          }

          /// <summary>
          /// Read all decoded characters as a string and clear the buffer
          /// </summary>
          /// <returns>The decoded characters, or an empty string if none are waiting</returns>
          public string Read()
          {
              if (_characters.Count == 0)
                  return "";

              // Build the string straight from the char buffer. string.Join over chars calls
              // ToString() per element on older target frameworks.
              var str = new string(_characters.ToArray());
              _characters.Clear();
              return str;
          }

          /// <summary>
          /// Set the decoder back to its initial state, discarding any partially decoded bytes
          /// and any characters that have not been read yet.
          /// </summary>
          public void Reset()
          {
              _decoder.Reset();
              _characters.Clear();
          }
      }
  }