using System.Buffers;
using System.Diagnostics;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Extensions;
using LLama.Native;
namespace LLama.Transform
{
/// <summary>
/// Decodes a stream of tokens into a stream of characters, buffering partial
/// multi-byte sequences in the <see cref="Decoder"/> until enough bytes arrive
/// to form complete characters.
/// </summary>
public sealed class StreamingTokenDecoder
{
    // Default model weights used to convert tokens to bytes (may be overridden per call to Add)
    private readonly SafeLlamaModelHandle? _weights;

    // Stateful text decoder; carries partial multi-byte sequences across calls to Add
    private readonly Decoder _decoder;

    // Decoded characters waiting to be consumed by Read()
    private readonly List<char> _characters = new();

    /// <summary>
    /// The number of decoded characters waiting to be read
    /// </summary>
    public int AvailableCharacters => _characters.Count;

    #region constructors
    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="encoding">Text encoding to use</param>
    /// <param name="weights">Model weights</param>
    public StreamingTokenDecoder(Encoding encoding, LLamaWeights? weights = null)
        : this(encoding, weights?.NativeHandle)
    {
    }

    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="context">Context to retrieve encoding and model weights from</param>
    public StreamingTokenDecoder(LLamaContext context)
        : this(context.Encoding, context.NativeHandle)
    {
    }

    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="encoding">Text encoding to use</param>
    /// <param name="context">Context to retrieve model weights from</param>
    public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
        : this(encoding, context.ModelHandle)
    {
    }

    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="encoding">Text encoding to use</param>
    /// <param name="weights">Models weights to use</param>
    public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle? weights)
    {
        // Nullable: the other constructors may chain here with a null handle
        // (e.g. `weights?.NativeHandle`); Add() checks for null before use.
        _weights = weights;
        _decoder = encoding.GetDecoder();
    }
    #endregion

    /// <summary>
    /// Add a single token to the decoder
    /// </summary>
    /// <param name="token">Token to decode</param>
    /// <param name="weights">Model weights to use for this token, overriding the weights supplied at construction</param>
    /// <exception cref="NullReferenceException">Thrown if no weights were supplied here or at construction</exception>
    public void Add(int token, SafeLlamaModelHandle? weights = null)
    {
        // Prefer the per-call weights, falling back to the constructor-supplied ones
        weights ??= _weights;
        if (weights is null)
        {
            // NOTE(review): InvalidOperationException would be more idiomatic than manually
            // throwing NullReferenceException, but the type is kept so existing callers
            // catching it keep working.
            throw new NullReferenceException("No weights provided for StreamingTokenDecoder.");
        }

        var charsArr = ArrayPool<char>.Shared.Rent(16);
        var bytesArr = ArrayPool<byte>.Shared.Rent(16);
        try
        {
            // Convert this token into bytes
            var bytesAvailable = TokenToBytes(ref bytesArr, token, weights).Length;

            // Convert those bytes into characters
            var bytesOffset = 0;
            var completed = false;
            while (!completed)
            {
                // Decode some of the bytes into the temp char buffer. Keep doing this
                // until all bytes have been consumed
                _decoder.Convert(
                    bytesArr, bytesOffset, bytesAvailable,
                    charsArr, 0, charsArr.Length,
                    false,
                    out var bytesUsed, out var charsUsed, out completed
                );
                bytesOffset += bytesUsed;
                bytesAvailable -= bytesUsed;

                // Add the decoded characters to the output buffer
                _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
            }
        }
        finally
        {
            // Always return pooled buffers, even if native token conversion throws
            ArrayPool<char>.Shared.Return(charsArr);
            ArrayPool<byte>.Shared.Return(bytesArr);
        }

        return;

        // Converts a single token into bytes, using the `bytes` array as temporary storage.
        // If the `bytes` array is too small it will get a larger one from the ArrayPool.
        static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
        {
            // Try to get bytes
            var l = model.TokenToSpan(token, bytes);

            // Negative length indicates that the output was too small. Expand it to twice that size and try again.
            if (l < 0)
            {
                // Return the old array to the pool and get a new one
                ArrayPool<byte>.Shared.Return(bytes);
                bytes = ArrayPool<byte>.Shared.Rent(-l * 2);

                // Get bytes, this time it can't fail
                l = model.TokenToSpan(token, bytes);
            }

            Debug.Assert(l >= 0);
            return new Span<byte>(bytes, 0, l);
        }
    }

    /// <summary>
    /// Add all tokens in the given enumerable
    /// </summary>
    /// <param name="tokens">Tokens to decode</param>
    /// <param name="weights">Model weights to use for these tokens, overriding the weights supplied at construction</param>
    public void AddRange(IEnumerable<int> tokens, SafeLlamaModelHandle? weights = null)
    {
        foreach (var item in tokens)
            Add(item, weights);
    }

    /// <summary>
    /// Read all decoded characters and clear the buffer
    /// </summary>
    /// <param name="dest">List to append the decoded characters to</param>
    public void Read(List<char> dest)
    {
        dest.AddRange(_characters);
        _characters.Clear();
    }

    /// <summary>
    /// Read all decoded characters as a string and clear the buffer
    /// </summary>
    /// <returns>All characters decoded since the last read, or an empty string if there are none</returns>
    public string Read()
    {
        if (_characters.Count == 0)
            return "";

        var str = string.Join("", _characters);
        _characters.Clear();
        return str;
    }

    /// <summary>
    /// Set the decoder back to its initial state
    /// </summary>
    public void Reset()
    {
        // Discard any buffered partial byte sequences and any unread characters
        _decoder.Reset();
        _characters.Clear();
    }
}
}