|
|
|
@@ -2,6 +2,7 @@ |
|
|
|
using System.Buffers; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Diagnostics; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
using System.Text; |
|
|
|
using LLama.Exceptions; |
|
|
|
using LLama.Extensions; |
|
|
|
@@ -118,66 +119,6 @@ namespace LLama.Native |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// Convert a single llama token into a string
/// </summary>
/// <param name="llama_token">Token to decode</param>
/// <param name="encoding">Encoding to use to decode the bytes into a string</param>
/// <returns>
/// The decoded text. An empty string is returned when the size probe reports that no
/// buffer is required, or when the native call reports that nothing was written.
/// </returns>
public string TokenToString(int llama_token, Encoding encoding)
{
    unsafe
    {
        // Probe with an empty buffer. The buffer size below is taken from the negated result,
        // so only a negative result describes a usable size; anything else means there is
        // nothing to allocate (and negating it would give an invalid stackalloc size).
        var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
        if (length >= 0)
            return "";

        Span<byte> bytes = stackalloc byte[-length];

        fixed (byte* bytePtr = bytes)
        {
            var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
            Debug.Assert(written == bytes.Length);

            // Debug.Assert is compiled out of release builds, so do not rely on it:
            // decode only the bytes that were actually reported as written.
            if (written <= 0)
                return "";

            return encoding.GetString(bytePtr, Math.Min(written, bytes.Length));
        }
    }
}
|
|
|
|
|
|
|
/// <summary>
/// Append a single llama token to a string builder
/// </summary>
/// <param name="llama_token">Token to decode</param>
/// <param name="encoding">Encoding to use to decode the token bytes into characters</param>
/// <param name="dest">string builder to append the result to</param>
public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest)
{
    unsafe
    {
        // Probe with an empty buffer. The buffer size below is taken from the negated result,
        // so only a negative result describes a usable size; anything else means there is
        // nothing to decode (and negating it would give an invalid stackalloc size).
        var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
        if (length >= 0)
            return;

        Span<byte> bytes = stackalloc byte[-length];
        fixed (byte* bytePtr = bytes)
        {
            // Decode into bytes
            var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
            Debug.Assert(written == bytes.Length);

            // Debug.Assert is compiled out of release builds, so do not rely on it:
            // only ever look at the bytes that were actually reported as written.
            if (written <= 0)
                return;
            var byteCount = Math.Min(written, bytes.Length);

            // Decode into chars
            var charCount = encoding.GetCharCount(bytePtr, byteCount);

            // Pinning an empty span yields a null pointer, which GetChars rejects,
            // so stop here when there is nothing to append.
            if (charCount == 0)
                return;

            Span<char> chars = stackalloc char[charCount];
            fixed (char* charPtr = chars)
            {
                var decoded = encoding.GetChars(bytePtr, byteCount, charPtr, chars.Length);

                // Write it to the output in one call rather than one char at a time
                dest.Append(charPtr, decoded);
            }
        }
    }
}
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Convert a sequence of tokens into characters. |
|
|
|
/// </summary> |
|
|
|
@@ -192,42 +133,52 @@ namespace LLama.Native |
|
|
|
{ |
|
|
|
// Rent an array to detokenize into |
|
|
|
var tokenBytesArr = ArrayPool<byte>.Shared.Rent(16); |
|
|
|
var tokenCharsArr = ArrayPool<char>.Shared.Rent(16); |
|
|
|
try |
|
|
|
|
|
|
|
// Convert all of the tokens into bytes |
|
|
|
var bytes = new List<byte>(); |
|
|
|
foreach (var token in tokens) |
|
|
|
{ |
|
|
|
var totalCharacters = 0; |
|
|
|
var unused = dest; |
|
|
|
var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this); |
|
|
|
foreach (var tokenByte in tokenBytes) |
|
|
|
bytes.Add(tokenByte); |
|
|
|
} |
|
|
|
|
|
|
|
for (var i = tokens.Count - 1; i >= 0; i--) |
|
|
|
// Extract a span from the list |
|
|
|
var bytesSpan = |
|
|
|
#if NETSTANDARD2_0 |
|
|
|
bytes.ToArray().AsSpan(); |
|
|
|
#else |
|
|
|
CollectionsMarshal.AsSpan(bytes); |
|
|
|
#endif |
|
|
|
|
|
|
|
// Check how many characters these bytes represent. If there's not enough space in the |
|
|
|
// output array we need to handle that. |
|
|
|
var characterCount = encoding.GetCharCount(bytesSpan); |
|
|
|
if (characterCount > dest.Length) |
|
|
|
{ |
|
|
|
var bigChars = ArrayPool<char>.Shared.Rent(characterCount); |
|
|
|
try |
|
|
|
{ |
|
|
|
var token = tokens[i]; |
|
|
|
|
|
|
|
// Get bytes for this token |
|
|
|
var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this); |
|
|
|
|
|
|
|
// Get chars for this token |
|
|
|
var tokenChars = BytesToChars(ref tokenCharsArr, tokenBytes, encoding); |
|
|
|
encoding.GetChars(bytesSpan, bigChars); |
|
|
|
var charSlice = bigChars |
|
|
|
.AsSpan(0, characterCount) |
|
|
|
.Slice(characterCount - dest.Length); |
|
|
|
|
|
|
|
// Trim down number of characters if there are too many |
|
|
|
if (tokenChars.Length > unused.Length) |
|
|
|
tokenChars = tokenChars.Slice(tokenChars.Length - unused.Length, unused.Length); |
|
|
|
|
|
|
|
// Copy characters |
|
|
|
tokenChars.CopyTo(unused.Slice(unused.Length - tokenChars.Length, tokenChars.Length)); |
|
|
|
unused = unused.Slice(0, unused.Length - tokenChars.Length); |
|
|
|
totalCharacters += tokenChars.Length; |
|
|
|
|
|
|
|
// Break out if we've run out of space |
|
|
|
if (unused.Length == 0) |
|
|
|
break; |
|
|
|
charSlice.CopyTo(dest); |
|
|
|
return dest; |
|
|
|
} |
|
|
|
finally |
|
|
|
{ |
|
|
|
ArrayPool<char>.Shared.Return(bigChars); |
|
|
|
} |
|
|
|
|
|
|
|
return dest.Slice(dest.Length - totalCharacters, totalCharacters); |
|
|
|
//todo: handle dest span too small |
|
|
|
throw new NotImplementedException(); |
|
|
|
} |
|
|
|
finally |
|
|
|
else |
|
|
|
{ |
|
|
|
ArrayPool<byte>.Shared.Return(tokenBytesArr); |
|
|
|
ArrayPool<char>.Shared.Return(tokenCharsArr); |
|
|
|
var charCount = encoding.GetChars(bytes.ToArray(), dest); |
|
|
|
return dest.Slice(0, charCount); |
|
|
|
} |
|
|
|
|
|
|
|
// vvv Local Functions vvv |
|
|
|
@@ -250,19 +201,6 @@ namespace LLama.Native |
|
|
|
Debug.Assert(l >= 0); |
|
|
|
return new Span<byte>(bytes, 0, l); |
|
|
|
} |
|
|
|
|
|
|
|
// Decode `bytes` into the rented `chars` buffer and return the decoded portion of it.
// If the buffer is too small it is handed back to the shared pool and replaced (through
// the ref parameter) with a rental of twice the required size.
static Span<char> BytesToChars(ref char[] chars, ReadOnlySpan<byte> bytes, Encoding encoding)
{
    // Number of chars the decoded bytes will occupy
    var needed = encoding.GetCharCount(bytes);

    // Swap the rented buffer for a bigger one when it cannot hold the result
    if (chars.Length < needed)
    {
        var pool = ArrayPool<char>.Shared;
        pool.Return(chars);
        chars = pool.Rent(needed * 2);
    }

    encoding.GetChars(bytes, chars);

    // Expose only the part of the buffer that was just filled
    return new Span<char>(chars, 0, needed);
}
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
@@ -304,7 +242,7 @@ namespace LLama.Native |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
#endregion |
|
|
|
#endregion |
|
|
|
|
|
|
|
#region context |
|
|
|
/// <summary> |
|
|
|
|