using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Transform
{
    /// <summary>
    /// Decodes a stream of tokens into a stream of characters
    /// </summary>
    public sealed class StreamingTokenDecoder
    {
        private readonly SafeLlamaModelHandle? _weights;
        private readonly Decoder _decoder;

        private readonly List<char> _characters = new();

        /// <summary>
        /// The number of decoded characters waiting to be read
        /// </summary>
        public int AvailableCharacters => _characters.Count;

        #region constructors
        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="weights">Model weights</param>
        public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
            : this(encoding, weights?.NativeHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="context">Context to retrieve encoding and model weights from</param>
        public StreamingTokenDecoder(LLamaContext context)
            : this(context.Encoding, context.NativeHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="context">Context to retrieve model weights from</param>
        public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
            : this(encoding, context.ModelHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="weights">Models weights to use</param>
        public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
        {
            _weights = weights;
            _decoder = encoding.GetDecoder();
        }
        #endregion

        /// <summary>
        /// Add a single token to the decoder
        /// </summary>
        /// <param name="token">Token to decode</param>
        /// <param name="weights">Model weights to use for this token, overriding the one supplied to the constructor</param>
        /// <exception cref="InvalidOperationException">
        /// Thrown if no weights were supplied either here or in the constructor.
        /// (Was a hand-thrown <see cref="NullReferenceException"/>; NRE should never be thrown explicitly.)
        /// </exception>
        public void Add(int token, SafeLlamaModelHandle? weights = null)
        {
            weights ??= _weights;
            if (weights is null)
            {
                throw new InvalidOperationException("No weights provided for StreamingTokenDecoder.");
            }

            var charsArr = ArrayPool<char>.Shared.Rent(16);
            var bytesArr = ArrayPool<byte>.Shared.Rent(16);
            try
            {
                // Convert this token into bytes
                var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;

                // Convert those bytes into characters
                var bytesOffset = 0;
                var completed = false;
                while (!completed)
                {
                    // Decode some of the bytes into the temp char buffer. Keep doing this
                    // until all bytes have been consumed
                    _decoder.Convert(
                        bytesArr, bytesOffset, bytesAvailable,
                        charsArr, 0, charsArr.Length,
                        false,
                        out var bytesUsed, out var charsUsed, out completed
                    );
                    bytesOffset += bytesUsed;
                    bytesAvailable -= bytesUsed;

                    // Add the decoded characters to the output buffer
                    _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
                }
            }
            finally
            {
                ArrayPool<char>.Shared.Return(charsArr);
                ArrayPool<byte>.Shared.Return(bytesArr);
            }

            return;

            // Converts a single token into bytes, using the `bytes` array as temporary storage.
            // If the `bytes` array is too small it will get a larger one from the ArrayPool.
            static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
            {
                // Try to get bytes
                var l = model.TokenToSpan(token, bytes);

                // Negative length indicates that the output was too small. Expand it to twice that size and try again.
                if (l < 0)
                {
                    // Return the old array to the pool and get a new one
                    ArrayPool<byte>.Shared.Return(bytes);
                    bytes = ArrayPool<byte>.Shared.Rent(-l * 2);

                    // Get bytes, this time it can't fail
                    l = model.TokenToSpan(token, bytes);
                }

                Debug.Assert(l >= 0);
                return new Span<byte>(bytes, 0, l);
            }
        }

        /// <summary>
        /// Add all tokens in the given enumerable
        /// </summary>
        /// <param name="tokens">Tokens to decode, in order</param>
        /// <param name="weights">Model weights to use, overriding the one supplied to the constructor</param>
        public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
        {
            foreach (var item in tokens)
                Add(item, weights);
        }

        /// <summary>
        /// Read all decoded characters and clear the buffer
        /// </summary>
        /// <param name="dest">List to append the decoded characters to</param>
        public void Read(List<char> dest)
        {
            dest.AddRange(_characters);
            _characters.Clear();
        }

        /// <summary>
        /// Read all decoded characters as a string and clear the buffer
        /// </summary>
        /// <returns>All characters decoded so far, or an empty string if none are pending</returns>
        public string Read()
        {
            if (_characters.Count == 0)
                return "";

            var str = string.Join("", _characters);
            _characters.Clear();
            return str;
        }

        /// <summary>
        /// Set the decoder back to its initial state
        /// </summary>
        public void Reset()
        {
            _decoder.Reset();
            _characters.Clear();
        }
    }
}