You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

StreamingTokenDecoder.cs 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. using System.Buffers;
  2. using System.Diagnostics;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using LLama.Extensions;
  7. using LLama.Native;
  namespace LLama
  {
      /// <summary>
      /// Decodes a stream of tokens into a stream of characters.
      /// Bytes that form an incomplete character at the end of one token are retained by the
      /// underlying <see cref="Decoder"/> and completed by the bytes of subsequent tokens.
      /// </summary>
      public sealed class StreamingTokenDecoder
      {
          private readonly SafeLlamaModelHandle _weights;
          private readonly Decoder _decoder;

          // Characters decoded so far that have not yet been consumed by a Read call
          private readonly List<char> _characters = new();

          /// <summary>
          /// The number of decoded characters waiting to be read
          /// </summary>
          public int AvailableCharacters => _characters.Count;

          #region constructors
          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="weights">Model weights</param>
          public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
              : this(encoding, weights.NativeHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="context">Context to retrieve encoding and model weights from</param>
          public StreamingTokenDecoder(LLamaContext context)
              : this(context.Encoding, context.NativeHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="context">Context to retrieve model weights from</param>
          public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
              : this(encoding, context.ModelHandle)
          {
          }

          /// <summary>
          /// Create a new decoder
          /// </summary>
          /// <param name="encoding">Text encoding to use</param>
          /// <param name="weights">Models weights to use</param>
          /// <exception cref="ArgumentNullException"><paramref name="encoding"/> or <paramref name="weights"/> is null</exception>
          public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
          {
              if (encoding == null)
                  throw new ArgumentNullException(nameof(encoding));

              _weights = weights ?? throw new ArgumentNullException(nameof(weights));
              _decoder = encoding.GetDecoder();
          }
          #endregion

          /// <summary>
          /// Add a single token to the decoder. The token is converted to bytes using the model,
          /// and those bytes are decoded into characters that become available through
          /// <see cref="Read()"/> / <see cref="Read(List{char})"/>.
          /// </summary>
          /// <param name="token">Token id to convert into text</param>
          public void Add(int token)
          {
              var charsArr = ArrayPool<char>.Shared.Rent(16);
              var bytesArr = ArrayPool<byte>.Shared.Rent(16);
              try
              {
                  // Convert this token into bytes (may swap bytesArr for a larger pooled array)
                  var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;

                  // Convert those bytes into characters
                  var bytesOffset = 0;
                  var completed = false;
                  while (!completed)
                  {
                      // Decode some of the bytes into the temp char buffer. Keep doing this
                      // until all bytes have been consumed. flush is false so that a trailing
                      // partial multi-byte sequence stays inside the decoder, to be completed
                      // by the bytes of the next token.
                      _decoder.Convert(
                          bytesArr, bytesOffset, bytesAvailable,
                          charsArr, 0, charsArr.Length,
                          false,
                          out var bytesUsed, out var charsUsed, out completed
                      );
                      bytesOffset += bytesUsed;
                      bytesAvailable -= bytesUsed;

                      // Add the decoded characters to the output buffer
                      _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
                  }
              }
              finally
              {
                  ArrayPool<char>.Shared.Return(charsArr);
                  ArrayPool<byte>.Shared.Return(bytesArr);
              }

              return;

              // Converts a single token into bytes, using the `bytes` array as temporary storage.
              // If the `bytes` array is too small it is exchanged for a larger one from the ArrayPool;
              // the caller is responsible for returning whichever array `bytes` refers to afterwards.
              static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
              {
                  // Try to get bytes
                  var l = model.TokenToSpan(token, bytes);

                  // Negative length indicates that the output was too small; its magnitude is
                  // treated as the required length. Expand to twice that size and try again.
                  if (l < 0)
                  {
                      // Rent the replacement BEFORE returning the old array: if Rent throws,
                      // `bytes` still refers to a live rental, so the caller's finally block
                      // returns it exactly once instead of double-returning it to the pool.
                      var larger = ArrayPool<byte>.Shared.Rent(-l * 2);
                      ArrayPool<byte>.Shared.Return(bytes);
                      bytes = larger;

                      // Get bytes, this time the buffer is large enough
                      l = model.TokenToSpan(token, bytes);
                  }

                  Debug.Assert(l >= 0);
                  return new Span<byte>(bytes, 0, l);
              }
          }

          /// <summary>
          /// Add all tokens in the given enumerable, in order
          /// </summary>
          /// <param name="tokens">Token ids to convert into text</param>
          public void AddRange(IEnumerable<int> tokens)
          {
              foreach (var item in tokens)
                  Add(item);
          }

          /// <summary>
          /// Read all decoded characters and clear the buffer
          /// </summary>
          /// <param name="dest">List to which all pending characters are appended</param>
          public void Read(List<char> dest)
          {
              dest.AddRange(_characters);
              _characters.Clear();
          }

          /// <summary>
          /// Read all decoded characters as a string and clear the buffer
          /// </summary>
          /// <returns>All pending characters as a string, or an empty string if none are available</returns>
          public string Read()
          {
              if (_characters.Count == 0)
                  return "";

              // Build the string directly from the char buffer; string.Join over a List<char>
              // would call char.ToString() once per element.
              var str = new string(_characters.ToArray());
              _characters.Clear();
              return str;
          }

          /// <summary>
          /// Set the decoder back to its initial state, discarding any partially decoded bytes
          /// held by the underlying decoder and any characters not yet read.
          /// </summary>
          public void Reset()
          {
              _decoder.Reset();
              _characters.Clear();
          }
      }
  }