You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

StreamingTokenDecoder.cs 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. using System.Buffers;
  2. using System.Diagnostics;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using LLama.Extensions;
  7. using LLama.Native;
  8. namespace LLama.Transform
  9. {
  10. /// <summary>
  11. /// Decodes a stream of tokens into a stream of characters
  12. /// </summary>
  13. public sealed class StreamingTokenDecoder
  14. {
  15. private readonly SafeLlamaModelHandle? _weights;
  16. private readonly Decoder _decoder;
  17. private readonly List<char> _characters = new();
  18. /// <summary>
  19. /// The number of decoded characters waiting to be read
  20. /// </summary>
  21. public int AvailableCharacters => _characters.Count;
  22. #region constructors
  23. /// <summary>
  24. /// Create a new decoder
  25. /// </summary>
  26. /// <param name="encoding">Text encoding to use</param>
  27. /// <param name="weights">Model weights</param>
  28. public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
  29. : this(encoding, weights?.NativeHandle)
  30. {
  31. }
  32. /// <summary>
  33. /// Create a new decoder
  34. /// </summary>
  35. /// <param name="context">Context to retrieve encoding and model weights from</param>
  36. public StreamingTokenDecoder(LLamaContext context)
  37. : this(context.Encoding, context.NativeHandle)
  38. {
  39. }
  40. /// <summary>
  41. /// Create a new decoder
  42. /// </summary>
  43. /// <param name="encoding">Text encoding to use</param>
  44. /// <param name="context">Context to retrieve model weights from</param>
  45. public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
  46. : this(encoding, context.ModelHandle)
  47. {
  48. }
  49. /// <summary>
  50. /// Create a new decoder
  51. /// </summary>
  52. /// <param name="encoding">Text encoding to use</param>
  53. /// <param name="weights">Models weights to use</param>
  54. public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
  55. {
  56. _weights = weights;
  57. _decoder = encoding.GetDecoder();
  58. }
  59. #endregion
  60. /// <summary>
  61. /// Add a single token to the decoder
  62. /// </summary>
  63. /// <param name="token"></param>
  64. public void Add(int token, SafeLlamaModelHandle? weights = null)
  65. {
  66. weights ??= _weights;
  67. if(weights is null)
  68. {
  69. throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
  70. }
  71. var charsArr = ArrayPool<char>.Shared.Rent(16);
  72. var bytesArr = ArrayPool<byte>.Shared.Rent(16);
  73. try
  74. {
  75. // Convert this token into bytes
  76. var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;
  77. // Convert those bytes into characters
  78. var bytesOffset = 0;
  79. var completed = false;
  80. while (!completed)
  81. {
  82. // Decode some of the bytes into the temp char buffer. Keep doing this
  83. // until all bytes have been consumed
  84. _decoder.Convert(
  85. bytesArr, bytesOffset, bytesAvailable,
  86. charsArr, 0, charsArr.Length,
  87. false,
  88. out var bytesUsed, out var charsUsed, out completed
  89. );
  90. bytesOffset += bytesUsed;
  91. bytesAvailable -= bytesUsed;
  92. // Add the decoded characters to the output buffer
  93. _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
  94. }
  95. }
  96. finally
  97. {
  98. ArrayPool<char>.Shared.Return(charsArr);
  99. ArrayPool<byte>.Shared.Return(bytesArr);
  100. }
  101. return;
  102. // Converts a single token into bytes, using the `bytes` array as temporary storage.
  103. // If the `bytes` array is too small it will get a larger one from the ArrayPool.
  104. static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
  105. {
  106. // Try to get bytes
  107. var l = model.TokenToSpan(token, bytes);
  108. // Negative length indicates that the output was too small. Expand it to twice that size and try again.
  109. if (l < 0)
  110. {
  111. // Return the old array to the pool and get a new one
  112. ArrayPool<byte>.Shared.Return(bytes);
  113. bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
  114. // Get bytes, this time it can't fail
  115. l = model.TokenToSpan(token, bytes);
  116. }
  117. Debug.Assert(l >= 0);
  118. return new Span<byte>(bytes, 0, l);
  119. }
  120. }
  121. /// <summary>
  122. /// Add all tokens in the given enumerable
  123. /// </summary>
  124. /// <param name="tokens"></param>
  125. public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
  126. {
  127. foreach (var item in tokens)
  128. Add(item, weights);
  129. }
  130. /// <summary>
  131. /// Read all decoded characters and clear the buffer
  132. /// </summary>
  133. /// <param name="dest"></param>
  134. public void Read(List<char> dest)
  135. {
  136. dest.AddRange(_characters);
  137. _characters.Clear();
  138. }
  139. /// <summary>
  140. /// Read all decoded characters as a string and clear the buffer
  141. /// </summary>
  142. /// <returns></returns>
  143. public string Read()
  144. {
  145. if (_characters.Count == 0)
  146. return "";
  147. var str = string.Join("", _characters);
  148. _characters.Clear();
  149. return str;
  150. }
  151. /// <summary>
  152. /// Set the decoder back to its initial state
  153. /// </summary>
  154. public void Reset()
  155. {
  156. _decoder.Reset();
  157. _characters.Clear();
  158. }
  159. }
  160. }