
- Removed all `TokenToString` methods (it is never correct to use them, because a single character may be represented by multiple tokens).

- Built a new (hacky) `DeTokenize` method which handles this.
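
To illustrate the problem the message describes, here is a minimal, self-contained C# sketch (not part of the commit). The byte split is hypothetical; real token boundaries depend on the model's vocabulary.

using System;
using System.Linq;
using System.Text;

// "😀" (U+1F600) is four UTF-8 bytes: F0 9F 98 80. A BPE tokenizer may split
// those bytes across two tokens; the split below is illustrative only.
var part1 = new byte[] { 0xF0, 0x9F };
var part2 = new byte[] { 0x98, 0x80 };

// Per-token decoding (the old TokenToString approach) yields replacement characters:
Console.WriteLine(Encoding.UTF8.GetString(part1) + Encoding.UTF8.GetString(part2));

// Decoding the concatenated bytes (what DeTokenize does) recovers the emoji:
Console.WriteLine(Encoding.UTF8.GetString(part1.Concat(part2).ToArray()));
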
Martin Evans · commit efdf3d630c · tag v0.7.0
6 changed files with 73 additions and 156 deletions

  1. LLama.Unittest/TokenTests.cs (+1 -1)
  2. LLama/LLamaContext.cs (+2 -26)
  3. LLama/LLamaExecutorBase.cs (+1 -4)
  4. LLama/LLamaStatelessExecutor.cs (+1 -1)
  5. LLama/Native/SafeLLamaContextHandle.cs (+28 -22)
  6. LLama/Native/SafeLlamaModelHandle.cs (+40 -102)

LLama.Unittest/TokenTests.cs (+1 -1)

@@ -79,7 +79,7 @@ public sealed class TokenTests
var strings = new[]
{
"Hello world",
"철수라는",
"철수",
"😀 😃 😄 😁 😆 😅 😂 😊 😇 🙂 ",
};



LLama/LLamaContext.cs (+2 -26)

@@ -102,13 +102,9 @@ namespace LLama
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
public string DeTokenize(IEnumerable<llama_token> tokens)
public string DeTokenize(IReadOnlyList<llama_token> tokens)
{
var sb = new StringBuilder();
foreach (var token in tokens)
NativeHandle.TokenToString(token, Encoding, sb);

return sb.ToString();
return NativeHandle.DeTokenize(tokens, Encoding);
}

/// <summary>
@@ -418,26 +414,6 @@ namespace LLama
}
#endregion

/// <summary>
/// Convert a token into a string
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
public string TokenToString(llama_token token)
{
return NativeHandle.TokenToString(token, Encoding);
}

/// <summary>
/// Append a single token to a string builder
/// </summary>
/// <param name="token">Token to decode</param>
/// <param name="dest">string builder to append the result to</param>
public void TokenToString(llama_token token, StringBuilder dest)
{
NativeHandle.TokenToString(token, Encoding, dest);
}

/// <inheritdoc />
public void Dispose()
{


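A hedged usage sketch of the context-level API after this change. The Tokenize call and its defaults are assumed from the surrounding library, not shown in this diff; only DeTokenize(IReadOnlyList<llama_token>) comes from the hunk above.

using LLama;

internal static class ContextDeTokenizeExample
{
    // Round-trips text through the tokenizer; the whole token list is handed to
    // DeTokenize in one call, so multi-token characters are reassembled correctly.
    public static string RoundTrip(LLamaContext context, string text)
    {
        var tokens = context.Tokenize(text); // llama_token[] (int[]), signature assumed
        return context.DeTokenize(tokens);   // int[] satisfies IReadOnlyList<llama_token>
    }
}
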
LLama/LLamaExecutorBase.cs (+1 -4)

@@ -294,10 +294,7 @@ namespace LLama
await InferInternal(inferenceParams, args);

if (args.ReturnValue)
{
foreach (var id in _embeds)
yield return Context.TokenToString(id);
}
yield return Context.DeTokenize(_embeds);

var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })


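For callers, the effect of this executor change is that each yielded chunk now covers every token produced in one inference step rather than one string per token. A consuming-side sketch, assuming the usual ILLamaExecutor.InferAsync shape (that interface and its signature are not part of this diff):

using System;
using System.Threading.Tasks;
using LLama.Abstractions;

internal static class ExecutorStreamingExample
{
    // Characters assembled from several tokens within a step now arrive intact
    // inside a single chunk instead of being split across yields.
    public static async Task StreamAsync(ILLamaExecutor executor, string prompt, IInferenceParams inferenceParams)
    {
        await foreach (var chunk in executor.InferAsync(prompt, inferenceParams))
            Console.Write(chunk);
    }
}
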
LLama/LLamaStatelessExecutor.cs (+1 -1)

@@ -95,7 +95,7 @@ namespace LLama
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);

lastTokens.Add(id);
yield return Context.TokenToString(id);
yield return Context.DeTokenize(new [] { id }); //todo: not correct to return tokens one by one like this!

tokens.Clear();
tokens.Add(id);


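The todo comment above marks the remaining gap: the stateless executor still yields one token at a time, so a character split across tokens can still surface as replacement characters. One possible direction (not part of this commit, and assuming the span overload of Decoder.GetChars available on modern targets) is a stateful decoder that withholds incomplete byte sequences between tokens:

using System;
using System.Text;

internal static class StreamingDecodeSketch
{
    // A stateful Decoder buffers incomplete UTF-8 sequences between calls, so a
    // character is only emitted once all of its bytes have been seen.
    private static readonly Decoder Utf8Decoder = Encoding.UTF8.GetDecoder();

    public static string DecodePiece(ReadOnlySpan<byte> tokenBytes)
    {
        // At most 3 bytes can be carried over from earlier calls, so this buffer
        // is always large enough for the characters produced by this call.
        Span<char> chars = stackalloc char[tokenBytes.Length + 4];
        var written = Utf8Decoder.GetChars(tokenBytes, chars, flush: false);
        return new string(chars.Slice(0, written));
    }
}

// Feeding the two halves of "😀" (F0 9F, then 98 80) returns "" and then "😀".
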
LLama/Native/SafeLLamaContextHandle.cs (+28 -22)

@@ -1,5 +1,6 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;

@@ -159,38 +160,43 @@ namespace LLama.Native
}

/// <summary>
/// Convert a token into a string
/// Convert a single llama token into bytes
/// </summary>
/// <param name="token">Token to decode into a string</param>
/// <param name="encoding"></param>
/// <returns></returns>
public string TokenToString(int token, Encoding encoding)
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(int token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToString(token, encoding);
return ThrowIfDisposed().TokenToSpan(token, dest);
}

/// <summary>
/// Append a single llama token to a string builder
/// Convert a set of tokens into a string
/// </summary>
/// <param name="token">Token to decode</param>
/// <param name="tokens"></param>
/// <param name="encoding"></param>
/// <param name="dest">string builder to append the result to</param>
public void TokenToString(int token, Encoding encoding, StringBuilder dest)
/// <returns></returns>
public string DeTokenize(IReadOnlyList<int> tokens, Encoding encoding)
{
ThrowIfDisposed().TokenToString(token, encoding, dest);
}
var chars = ArrayPool<char>.Shared.Rent(tokens.Count * 2);
try
{
var span = ThrowIfDisposed().TokensToSpan(tokens, chars.AsSpan(), encoding);
if (span.Length == 0)
return "";

/// <summary>
/// Convert a single llama token into bytes
/// </summary>
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(int token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
unsafe
{
fixed (char* ptr = &span[0])
return new string(ptr, 0, span.Length);
}
}
finally
{
ArrayPool<char>.Shared.Return(chars);
}
}
#endregion
#endregion

/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.


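The new handle-level DeTokenize rents roughly two chars per token and relies on TokensToSpan to cope when the decoded text is longer than that guess. A hedged usage sketch of the call added above; the token ids are placeholders, not real vocabulary entries for any particular model.

using System.Text;
using LLama.Native;

internal static class HandleDeTokenizeExample
{
    public static string Decode(SafeLLamaContextHandle handle)
    {
        int[] tokens = { 1, 15043, 3186 }; // hypothetical token ids
        return handle.DeTokenize(tokens, Encoding.UTF8);
    }
}
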
LLama/Native/SafeLlamaModelHandle.cs (+40 -102)

@@ -2,6 +2,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
using LLama.Extensions;
@@ -118,66 +119,6 @@ namespace LLama.Native
}
}

/// <summary>
/// Convert a single llama token into a string
/// </summary>
/// <param name="llama_token"></param>
/// <param name="encoding">Encoding to use to decode the bytes into a string</param>
/// <returns></returns>
public string TokenToString(int llama_token, Encoding encoding)
{
unsafe
{
var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
if (length == 0)
return "";

Span<byte> bytes = stackalloc byte[-length];

fixed (byte* bytePtr = bytes)
{
var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length);

return encoding.GetString(bytePtr, bytes.Length);
}
}
}

/// <summary>
/// Append a single llama token to a string builder
/// </summary>
/// <param name="llama_token">Token to decode</param>
/// <param name="encoding"></param>
/// <param name="dest">string builder to append the result to</param>
public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest)
{
unsafe
{
var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
if (length == 0)
return;

Span<byte> bytes = stackalloc byte[-length];
fixed (byte* bytePtr = bytes)
{
// Decode into bytes
var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length);

// Decode into chars
var charCount = encoding.GetCharCount(bytePtr, bytes.Length);
Span<char> chars = stackalloc char[charCount];
fixed (char* charPtr = chars)
encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length);

// Write it to the output
for (var i = 0; i < chars.Length; i++)
dest.Append(chars[i]);
}
}
}

/// <summary>
/// Convert a sequence of tokens into characters.
/// </summary>
@@ -192,42 +133,52 @@ namespace LLama.Native
{
// Rent an array to detokenize into
var tokenBytesArr = ArrayPool<byte>.Shared.Rent(16);
var tokenCharsArr = ArrayPool<char>.Shared.Rent(16);
try

// Convert all of the tokens into bytes
var bytes = new List<byte>();
foreach (var token in tokens)
{
var totalCharacters = 0;
var unused = dest;
var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this);
foreach (var tokenByte in tokenBytes)
bytes.Add(tokenByte);
}

for (var i = tokens.Count - 1; i >= 0; i--)
// Extract a span from the list
var bytesSpan =
#if NETSTANDARD2_0
bytes.ToArray().AsSpan();
#else
CollectionsMarshal.AsSpan(bytes);
#endif

// Check how many characters these bytes represent. If there's not enough space in the
// output array we need to handle that.
var characterCount = encoding.GetCharCount(bytesSpan);
if (characterCount > dest.Length)
{
var bigChars = ArrayPool<char>.Shared.Rent(characterCount);
try
{
var token = tokens[i];

// Get bytes for this token
var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this);

// Get chars for this token
var tokenChars = BytesToChars(ref tokenCharsArr, tokenBytes, encoding);
encoding.GetChars(bytesSpan, bigChars);
var charSlice = bigChars
.AsSpan(0, characterCount)
.Slice(characterCount - dest.Length);

// Trim down number of characters if there are too many
if (tokenChars.Length > unused.Length)
tokenChars = tokenChars.Slice(tokenChars.Length - unused.Length, unused.Length);

// Copy characters
tokenChars.CopyTo(unused.Slice(unused.Length - tokenChars.Length, tokenChars.Length));
unused = unused.Slice(0, unused.Length - tokenChars.Length);
totalCharacters += tokenChars.Length;

// Break out if we've run out of space
if (unused.Length == 0)
break;
charSlice.CopyTo(dest);
return dest;
}
finally
{
ArrayPool<char>.Shared.Return(bigChars);
}

return dest.Slice(dest.Length - totalCharacters, totalCharacters);
//todo: handle dest span too small
throw new NotImplementedException();
}
finally
else
{
ArrayPool<byte>.Shared.Return(tokenBytesArr);
ArrayPool<char>.Shared.Return(tokenCharsArr);
var charCount = encoding.GetChars(bytes.ToArray(), dest);
return dest.Slice(0, charCount);
}
// vvv Local Functions vvv
@@ -250,19 +201,6 @@ namespace LLama.Native
Debug.Assert(l >= 0);
return new Span<byte>(bytes, 0, l);
}

static Span<char> BytesToChars(ref char[] chars, ReadOnlySpan<byte> bytes, Encoding encoding)
{
var count = encoding.GetCharCount(bytes);
if (count > chars.Length)
{
ArrayPool<char>.Shared.Return(chars);
chars = ArrayPool<char>.Shared.Rent(count * 2);
}

encoding.GetChars(bytes, chars);
return chars.AsSpan(0, count);
}
}

/// <summary>
@@ -304,7 +242,7 @@ namespace LLama.Native
}
}
}
#endregion
#endregion

#region context
/// <summary>


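The removed code above shows the native convention this file builds on: llama_token_to_piece returns the number of bytes written, and a negative value (minus the required size) when the destination buffer is too small, including the zero-length probe call. A sketch of that two-step pattern; GetTokenBytes is a hypothetical helper name, while the NativeApi call is taken from the removed code.

using System;
using LLama.Native;

internal static class TokenToPieceSketch
{
    public static unsafe byte[] GetTokenBytes(SafeLlamaModelHandle model, int token)
    {
        // Probe with an empty buffer; a negative result is minus the required size.
        var required = -NativeApi.llama_token_to_piece(model, token, null, 0);
        if (required <= 0)
            return Array.Empty<byte>();

        var buffer = new byte[required];
        fixed (byte* ptr = buffer)
        {
            var written = NativeApi.llama_token_to_piece(model, token, ptr, buffer.Length);
            return buffer.AsSpan(0, written).ToArray();
        }
    }
}
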