| @@ -0,0 +1,53 @@ | |||||
| using System.Text; | |||||
| using LLama.Common; | |||||
| using Xunit.Abstractions; | |||||
| namespace LLama.Unittest; | |||||
| public class StreamingTextDecoderTests | |||||
| : IDisposable | |||||
| { | |||||
| private readonly LLamaWeights _model; | |||||
| private readonly ITestOutputHelper _testOutputHelper; | |||||
| private readonly ModelParams _params; | |||||
| public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper) | |||||
| { | |||||
| _testOutputHelper = testOutputHelper; | |||||
| _params = new ModelParams(Constants.ModelPath); | |||||
| _model = LLamaWeights.LoadFromFile(_params); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| _model.Dispose(); | |||||
| } | |||||
| [Fact] | |||||
| public void DecodesSimpleText() | |||||
| { | |||||
| var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); | |||||
| const string text = "The cat sat on the mat"; | |||||
| var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8); | |||||
| foreach (var lLamaToken in tokens) | |||||
| decoder.Add(lLamaToken); | |||||
| Assert.Equal(text, decoder.Read().Trim()); | |||||
| } | |||||
| [Fact] | |||||
| public void DecodesComplexText() | |||||
| { | |||||
| var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); | |||||
| const string text = "猫坐在垫子上 😀🤨🤐😏"; | |||||
| var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8); | |||||
| foreach (var lLamaToken in tokens) | |||||
| decoder.Add(lLamaToken); | |||||
| Assert.Equal(text, decoder.Read().Trim()); | |||||
| } | |||||
| } | |||||
| @@ -194,7 +194,7 @@ namespace LLama.Native | |||||
| /// <param name="token">Token to decode</param> | /// <param name="token">Token to decode</param> | ||||
| /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | ||||
| /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | ||||
| public int TokenToSpan(LLamaToken token, Span<byte> dest) | |||||
| public uint TokenToSpan(LLamaToken token, Span<byte> dest) | |||||
| { | { | ||||
| return ThrowIfDisposed().TokenToSpan(token, dest); | return ThrowIfDisposed().TokenToSpan(token, dest); | ||||
| } | } | ||||
| @@ -126,10 +126,10 @@ namespace LLama.Native | |||||
| /// <param name="token">Token to decode</param> | /// <param name="token">Token to decode</param> | ||||
| /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> | ||||
| /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> | ||||
| public int TokenToSpan(LLamaToken token, Span<byte> dest) | |||||
| public uint TokenToSpan(LLamaToken token, Span<byte> dest) | |||||
| { | { | ||||
| var length = NativeApi.llama_token_to_piece(this, token, dest); | var length = NativeApi.llama_token_to_piece(this, token, dest); | ||||
| return Math.Abs(length); | |||||
| return (uint)Math.Abs(length); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -113,19 +113,19 @@ namespace LLama | |||||
| // Try to get bytes | // Try to get bytes | ||||
| var l = model.TokenToSpan(token, bytes); | var l = model.TokenToSpan(token, bytes); | ||||
| // Negative length indicates that the output was too small. Expand it to twice that size and try again. | |||||
| if (l < 0) | |||||
| // Check if the length was larger than the buffer. If so expand the buffer and try again | |||||
| if (l > bytes.Length) | |||||
| { | { | ||||
| // Return the old array to the pool and get a new one | // Return the old array to the pool and get a new one | ||||
| ArrayPool<byte>.Shared.Return(bytes); | ArrayPool<byte>.Shared.Return(bytes); | ||||
| bytes = ArrayPool<byte>.Shared.Rent(-l * 2); | |||||
| bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2)); | |||||
| // Get bytes, this time it can't fail | // Get bytes, this time it can't fail | ||||
| l = model.TokenToSpan(token, bytes); | l = model.TokenToSpan(token, bytes); | ||||
| } | } | ||||
| Debug.Assert(l >= 0); | |||||
| return new Span<byte>(bytes, 0, l); | |||||
| Debug.Assert(l <= bytes.Length); | |||||
| return new Span<byte>(bytes, 0, (int)l); | |||||
| } | } | ||||
| } | } | ||||