using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Extensions;
using LLama.Native;
namespace LLama
{
/// <summary>
/// Decodes a stream of tokens into a stream of characters
/// </summary>
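/// <remarks>
/// A minimal usage sketch, assuming a <c>context</c> and a stream of sampled <c>tokens</c>
/// are already available (both are placeholders, not members of this class):
/// <code>
/// var decoder = new StreamingTokenDecoder(context);
/// foreach (var token in tokens)
/// {
///     decoder.Add(token);
///     Console.Write(decoder.Read());
/// }
/// </code>
/// Partial multi-byte sequences are buffered internally until the remaining bytes arrive,
/// so text split across tokens is decoded correctly.
/// </remarks>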
public sealed class StreamingTokenDecoder
{
private readonly SafeLlamaModelHandle _weights;
private readonly Decoder _decoder;
private readonly List<char> _characters = new();
/// <summary>
/// The number of decoded characters waiting to be read
/// </summary>
public int AvailableCharacters => _characters.Count;
#region constructors
/// <summary>
/// Create a new decoder
/// </summary>
/// <param name="encoding">Text encoding to use</param>
/// <param name="weights">Model weights</param>
public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
: this(encoding, weights.NativeHandle)
{
}
/// <summary>
/// Create a new decoder
/// </summary>
/// <param name="context">Context to retrieve encoding and model weights from</param>
public StreamingTokenDecoder(LLamaContext context)
: this(context.Encoding, context.NativeHandle)
{
}
/// <summary>
/// Create a new decoder
/// </summary>
/// <param name="encoding">Text encoding to use</param>
/// <param name="context">Context to retrieve model weights from</param>
public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
: this(encoding, context.ModelHandle)
{
}
/// <summary>
/// Create a new decoder
/// </summary>
/// <param name="encoding">Text encoding to use</param>
/// <param name="weights">Model weights to use</param>
public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
{
_weights = weights;
_decoder = encoding.GetDecoder();
}
#endregion
/// <summary>
/// Add a single token to the decoder
/// </summary>
/// <param name="token">Token to add</param>
public void Add(LLamaToken token)
{
var charsArr = ArrayPool<char>.Shared.Rent(16);
var bytesArr = ArrayPool<byte>.Shared.Rent(16);
try
{
// Convert this token into bytes
var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;
// Convert those bytes into characters
var bytesOffset = 0;
var completed = false;
while (!completed)
{
// Decode some of the bytes into the temp char buffer. Keep doing this
// until all bytes have been consumed
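// Note: flush is false, so an incomplete multi-byte sequence at the end of this
// token stays buffered inside the Decoder and is completed by a later token.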
_decoder.Convert(
bytesArr, bytesOffset, bytesAvailable,
charsArr, 0, charsArr.Length,
false,
out var bytesUsed, out var charsUsed, out completed
);
bytesOffset += bytesUsed;
bytesAvailable -= bytesUsed;
// Add the decoded characters to the output buffer
_characters.AddSpan(charsArr.AsSpan(0, charsUsed));
}
}
finally
{
ArrayPool<char>.Shared.Return(charsArr);
ArrayPool<byte>.Shared.Return(bytesArr);
}
return;
// Converts a single token into bytes, using the `bytes` array as temporary storage.
// If the `bytes` array is too small it will get a larger one from the ArrayPool.
static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model)
{
// Try to get bytes
var l = model.TokenToSpan(token, bytes);
// Check if the length was larger than the buffer. If so expand the buffer and try again
if (l > bytes.Length)
{
// Return the old array to the pool and get a new one
ArrayPool<byte>.Shared.Return(bytes);
bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));
// Get bytes, this time it can't fail
l = model.TokenToSpan(token, bytes);
}
Debug.Assert(l <= bytes.Length);
return new Span<byte>(bytes, 0, (int)l);
}
}
/// <summary>
/// Add a single token to the decoder
/// </summary>
/// <param name="token">Token to add</param>
public void Add(int token)
{
Add((LLamaToken)token);
}
/// <summary>
/// Add all tokens in the given enumerable
/// </summary>
/// <param name="tokens">Tokens to add</param>
public void AddRange<T>(T tokens)
where T : IEnumerable<LLamaToken>
{
foreach (var item in tokens)
Add(item);
}
/// <summary>
/// Add all tokens in the given span
/// </summary>
/// <param name="tokens">Tokens to add</param>
public void AddRange(ReadOnlySpan<LLamaToken> tokens)
{
foreach (var item in tokens)
Add(item);
}
/// <summary>
/// Read all decoded characters and clear the buffer
/// </summary>
/// <param name="dest">List to add the characters to</param>
public void Read(List<char> dest)
{
dest.AddRange(_characters);
_characters.Clear();
}
/// <summary>
/// Read all decoded characters as a string and clear the buffer
/// </summary>
/// <returns>The decoded string</returns>
public string Read()
{
if (_characters.Count == 0)
return "";
var str = string.Join("", _characters);
_characters.Clear();
return str;
}
/// <summary>
/// Set the decoder back to its initial state
/// </summary>
public void Reset()
{
_decoder.Reset();
_characters.Clear();
}
}
}