You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

StreamingTokenDecoder.cs 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. using System.Buffers;
  2. using System.Diagnostics;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using LLama.Extensions;
  7. using LLama.Native;
  8. namespace LLama;
  9. /// <summary>
  10. /// Decodes a stream of tokens into a stream of characters
  11. /// </summary>
  12. public sealed class StreamingTokenDecoder
  13. {
  14. private readonly SafeLlamaModelHandle _weights;
  15. private readonly Decoder _decoder;
  16. private readonly List<char> _characters = new();
  17. /// <summary>
  18. /// The number of decoded characters waiting to be read
  19. /// </summary>
  20. public int AvailableCharacters => _characters.Count;
  21. #region constructors
  22. /// <summary>
  23. /// Create a new decoder
  24. /// </summary>
  25. /// <param name="encoding">Text encoding to use</param>
  26. /// <param name="weights">Model weights</param>
  27. public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
  28. : this(encoding, weights.NativeHandle)
  29. {
  30. }
  31. /// <summary>
  32. /// Create a new decoder
  33. /// </summary>
  34. /// <param name="context">Context to retrieve encoding and model weights from</param>
  35. public StreamingTokenDecoder(LLamaContext context)
  36. : this(context.Encoding, context.NativeHandle)
  37. {
  38. }
  39. /// <summary>
  40. /// Create a new decoder
  41. /// </summary>
  42. /// <param name="encoding">Text encoding to use</param>
  43. /// <param name="context">Context to retrieve model weights from</param>
  44. public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
  45. : this(encoding, context.ModelHandle)
  46. {
  47. }
  48. /// <summary>
  49. /// Create a new decoder
  50. /// </summary>
  51. /// <param name="encoding">Text encoding to use</param>
  52. /// <param name="weights">Models weights to use</param>
  53. public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
  54. {
  55. _weights = weights;
  56. _decoder = encoding.GetDecoder();
  57. }
  58. #endregion
  59. /// <summary>
  60. /// Add a single token to the decoder
  61. /// </summary>
  62. /// <param name="token"></param>
  63. public void Add(int token)
  64. {
  65. var charsArr = ArrayPool<char>.Shared.Rent(16);
  66. var bytesArr = ArrayPool<byte>.Shared.Rent(16);
  67. try
  68. {
  69. // Convert this token into bytes
  70. var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
  71. // Convert those bytes into characters
  72. var bytesOffset = 0;
  73. var completed = false;
  74. while (!completed)
  75. {
  76. // Decode some of the bytes into the temp char buffer. Keep doing this
  77. // until all bytes have been consumed
  78. _decoder.Convert(
  79. bytesArr, bytesOffset, bytesAvailable,
  80. charsArr, 0, charsArr.Length,
  81. false,
  82. out var bytesUsed, out var charsUsed, out completed
  83. );
  84. bytesOffset += bytesUsed;
  85. bytesAvailable -= bytesUsed;
  86. // Add the decoded characters to the output buffer
  87. _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
  88. }
  89. }
  90. finally
  91. {
  92. ArrayPool<char>.Shared.Return(charsArr);
  93. ArrayPool<byte>.Shared.Return(bytesArr);
  94. }
  95. return;
  96. // Converts a single token into bytes, using the `bytes` array as temporary storage.
  97. // If the `bytes` array is too small it will get a larger one from the ArrayPool.
  98. static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
  99. {
  100. // Try to get bytes
  101. var l = model.TokenToSpan(token, bytes);
  102. // Negative length indicates that the output was too small. Expand it to twice that size and try again.
  103. if (l < 0)
  104. {
  105. // Return the old array to the pool and get a new one
  106. ArrayPool<byte>.Shared.Return(bytes);
  107. bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
  108. // Get bytes, this time it can't fail
  109. l = model.TokenToSpan(token, bytes);
  110. }
  111. Debug.Assert(l >= 0);
  112. return new Span<byte>(bytes, 0, l);
  113. }
  114. }
  115. /// <summary>
  116. /// Add all tokens in the given enumerable
  117. /// </summary>
  118. /// <param name="tokens"></param>
  119. public void AddRange(IEnumerable<int> tokens)
  120. {
  121. foreach (var item in tokens)
  122. Add(item);
  123. }
  124. /// <summary>
  125. /// Read all decoded characters and clear the buffer
  126. /// </summary>
  127. /// <param name="dest"></param>
  128. public void Read(List<char> dest)
  129. {
  130. dest.AddRange(_characters);
  131. _characters.Clear();
  132. }
  133. /// <summary>
  134. /// Read all decoded characters as a string and clear the buffer
  135. /// </summary>
  136. /// <returns></returns>
  137. public string Read()
  138. {
  139. if (_characters.Count == 0)
  140. return "";
  141. var str = string.Join("", _characters);
  142. _characters.Clear();
  143. return str;
  144. }
  145. /// <summary>
  146. /// Set the decoder back to its initial state
  147. /// </summary>
  148. public void Reset()
  149. {
  150. _decoder.Reset();
  151. _characters.Clear();
  152. }
  153. }