using System.Buffers;
using System.Diagnostics;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Extensions;
using LLama.Native;

namespace LLama
{
    /// <summary>
    /// Decodes a stream of tokens into a stream of characters
    /// </summary>
    public sealed class StreamingTokenDecoder
    {
        private readonly SafeLlamaModelHandle _weights;
        private readonly Decoder _decoder;

        private readonly List<char> _characters = new();

        /// <summary>
        /// The number of decoded characters waiting to be read
        /// </summary>
        public int AvailableCharacters => _characters.Count;

        #region constructors
        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="weights">Model weights</param>
        public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
            : this(encoding, weights.NativeHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="context">Context to retrieve encoding and model weights from</param>
        public StreamingTokenDecoder(LLamaContext context)
            : this(context.Encoding, context.NativeHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="context">Context to retrieve model weights from</param>
        public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
            : this(encoding, context.ModelHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="weights">Model weights to use</param>
        public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
        {
            _weights = weights;
            _decoder = encoding.GetDecoder();
        }
        #endregion

        /// <summary>
        /// Add a single token to the decoder
        /// </summary>
        /// <param name="token">Token to decode</param>
        public void Add(LLamaToken token)
        {
            var charsArr = ArrayPool<char>.Shared.Rent(16);
            var bytesArr = ArrayPool<byte>.Shared.Rent(16);
            try
            {
                // Convert this token into bytes
                var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;

                // Convert those bytes into characters
                var bytesOffset = 0;
                var completed = false;
                while (!completed)
                {
                    // Decode some of the bytes into the temp char buffer. Keep doing this
                    // until all bytes have been consumed
                    _decoder.Convert(
                        bytesArr, bytesOffset, bytesAvailable,
                        charsArr, 0, charsArr.Length,
                        false,
                        out var bytesUsed, out var charsUsed, out completed
                    );
                    bytesOffset += bytesUsed;
                    bytesAvailable -= bytesUsed;

                    // Add the decoded characters to the output buffer
                    _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
                }
            }
            finally
            {
                ArrayPool<char>.Shared.Return(charsArr);
                ArrayPool<byte>.Shared.Return(bytesArr);
            }

            return;

            // Converts a single token into bytes, using the `bytes` array as temporary storage.
            // If the `bytes` array is too small it will get a larger one from the ArrayPool.
            static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model)
            {
                // Try to get bytes
                var l = model.TokenToSpan(token, bytes);

                // Check if the length was larger than the buffer. If so, expand the buffer and try again
                if (l > bytes.Length)
                {
                    // Return the old array to the pool and get a new one
                    ArrayPool<byte>.Shared.Return(bytes);
                    bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));

                    // Get bytes, this time it can't fail
                    l = model.TokenToSpan(token, bytes);
                }

                Debug.Assert(l <= bytes.Length);
                return new Span<byte>(bytes, 0, (int)l);
            }
        }

        /// <summary>
        /// Add a single token to the decoder
        /// </summary>
        /// <param name="token">Token to decode</param>
        public void Add(int token)
        {
            Add((LLamaToken)token);
        }

        /// <summary>
        /// Add all tokens in the given enumerable
        /// </summary>
        /// <param name="tokens">Tokens to decode</param>
        public void AddRange<T>(T tokens)
            where T : IEnumerable<LLamaToken>
        {
            foreach (var item in tokens)
                Add((int)item);
        }

        /// <summary>
        /// Add all tokens in the given span
        /// </summary>
        /// <param name="tokens">Tokens to decode</param>
        public void AddRange(ReadOnlySpan<LLamaToken> tokens)
        {
            foreach (var item in tokens)
                Add(item);
        }

        /// <summary>
        /// Read all decoded characters and clear the buffer
        /// </summary>
        /// <param name="dest">List to add the decoded characters to</param>
        public void Read(List<char> dest)
        {
            dest.AddRange(_characters);
            _characters.Clear();
        }

        /// <summary>
        /// Read all decoded characters as a string and clear the buffer
        /// </summary>
        /// <returns>The decoded characters as a string</returns>
        public string Read()
        {
            if (_characters.Count == 0)
                return "";

            var str = string.Join("", _characters);
            _characters.Clear();
            return str;
        }

        /// <summary>
        /// Set the decoder back to its initial state
        /// </summary>
        public void Reset()
        {
            _decoder.Reset();
            _characters.Clear();
        }
    }
}
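
// Example usage (an illustrative sketch only, not part of this file; assumes `weights` is an
// already-loaded LLamaWeights instance and `tokens` is a sequence of LLamaToken values produced
// by sampling elsewhere). Characters are streamed out as soon as enough bytes have arrived to
// complete them, so multi-byte UTF-8 sequences split across tokens are handled correctly:
//
//     var decoder = new StreamingTokenDecoder(Encoding.UTF8, weights);
//     foreach (var token in tokens)
//     {
//         decoder.Add(token);
//         Console.Write(decoder.Read());
//     }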