Browse Source

Fixed decoding of large tokens (over 16 bytes) in streaming text decoder

tags/v0.10.0
Martin Evans 1 year ago
parent
commit
98635a0d5a
4 changed files with 61 additions and 8 deletions
  1. +53
    -0
      LLama.Unittest/StreamingTextDecoderTests.cs
  2. +1
    -1
      LLama/Native/SafeLLamaContextHandle.cs
  3. +2
    -2
      LLama/Native/SafeLlamaModelHandle.cs
  4. +5
    -5
      LLama/StreamingTokenDecoder.cs

+ 53
- 0
LLama.Unittest/StreamingTextDecoderTests.cs View File

@@ -0,0 +1,53 @@
using System.Text;
using LLama.Common;
using Xunit.Abstractions;

namespace LLama.Unittest;

public class StreamingTextDecoderTests
: IDisposable
{
private readonly LLamaWeights _model;
private readonly ITestOutputHelper _testOutputHelper;
private readonly ModelParams _params;

public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
_params = new ModelParams(Constants.ModelPath);
_model = LLamaWeights.LoadFromFile(_params);
}

public void Dispose()
{
_model.Dispose();
}

[Fact]
public void DecodesSimpleText()
{
var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

const string text = "The cat sat on the mat";
var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);

foreach (var lLamaToken in tokens)
decoder.Add(lLamaToken);

Assert.Equal(text, decoder.Read().Trim());
}

[Fact]
public void DecodesComplexText()
{
var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

const string text = "猫坐在垫子上 😀🤨🤐😏";
var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);

foreach (var lLamaToken in tokens)
decoder.Add(lLamaToken);

Assert.Equal(text, decoder.Read().Trim());
}
}

+ 1
- 1
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -194,7 +194,7 @@ namespace LLama.Native
/// <param name="token">Token to decode</param> /// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{ {
return ThrowIfDisposed().TokenToSpan(token, dest); return ThrowIfDisposed().TokenToSpan(token, dest);
} }


+ 2
- 2
LLama/Native/SafeLlamaModelHandle.cs View File

@@ -126,10 +126,10 @@ namespace LLama.Native
/// <param name="token">Token to decode</param> /// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param> /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns> /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{ {
var length = NativeApi.llama_token_to_piece(this, token, dest); var length = NativeApi.llama_token_to_piece(this, token, dest);
return Math.Abs(length);
return (uint)Math.Abs(length);
} }


/// <summary> /// <summary>


+ 5
- 5
LLama/StreamingTokenDecoder.cs View File

@@ -113,19 +113,19 @@ namespace LLama
// Try to get bytes // Try to get bytes
var l = model.TokenToSpan(token, bytes); var l = model.TokenToSpan(token, bytes);


// Negative length indicates that the output was too small. Expand it to twice that size and try again.
if (l < 0)
// Check if the length was larger than the buffer. If so expand the buffer and try again
if (l > bytes.Length)
{ {
// Return the old array to the pool and get a new one // Return the old array to the pool and get a new one
ArrayPool<byte>.Shared.Return(bytes); ArrayPool<byte>.Shared.Return(bytes);
bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));


// Get bytes, this time it can't fail // Get bytes, this time it can't fail
l = model.TokenToSpan(token, bytes); l = model.TokenToSpan(token, bytes);
} }


Debug.Assert(l >= 0);
return new Span<byte>(bytes, 0, l);
Debug.Assert(l <= bytes.Length);
return new Span<byte>(bytes, 0, (int)l);
} }
} }




Loading…
Cancel
Save