Browse Source

Added the `special` parameter to the tokenizer (introduced in https://github.com/ggerganov/llama.cpp/pull/3538)

tags/v0.6.0
Martin Evans 2 years ago
parent
commit
1f8c94e386
7 changed files with 22 additions and 17 deletions
  1. +2
    -2
      LLama.Unittest/LLamaContextTests.cs
  2. +1
    -1
      LLama.Unittest/StatelessExecutorTest.cs
  3. +4
    -4
      LLama.Unittest/TokenTests.cs
  4. +3
    -2
      LLama/LLamaContext.cs
  5. +5
    -3
      LLama/Native/NativeApi.cs
  6. +3
    -2
      LLama/Native/SafeLLamaContextHandle.cs
  7. +4
    -3
      LLama/Native/SafeLlamaModelHandle.cs

+ 2
- 2
LLama.Unittest/LLamaContextTests.cs View File

@@ -37,7 +37,7 @@ namespace LLama.Unittest
         {
             var tokens = _context.Tokenize("The quick brown fox", true);
 
-            Assert.Equal(new[] { 1, 1576, 4996, 17354, 1701, 29916 }, tokens);
+            Assert.Equal(new[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
         }
 
         [Fact]
@@ -45,7 +45,7 @@ namespace LLama.Unittest
         {
             var tokens = _context.Tokenize("The quick brown fox", false);
 
-            Assert.Equal(new[] { 1576, 4996, 17354, 1701, 29916 }, tokens);
+            Assert.Equal(new[] { 450, 4996, 17354, 1701, 29916 }, tokens);
         }
 
         [Fact]


+ 1
- 1
LLama.Unittest/StatelessExecutorTest.cs View File

@@ -54,7 +54,7 @@ namespace LLama.Unittest
             // with a modified context
             var @params = new InferenceParams()
             {
-                MaxTokens = 70,
+                MaxTokens = 65,
                 TokensKeep = question.Length,
             };




+ 4
- 4
LLama.Unittest/TokenTests.cs View File

@@ -27,7 +27,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensEndWith()
     {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
 
         var result = tokens.TokensEndsWithAnyString(new[]
         {
@@ -41,7 +41,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensEndSubstring()
     {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
 
         var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
         {
@@ -53,7 +53,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensNotEndWith()
    {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
 
         var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
         {
@@ -67,7 +67,7 @@ public sealed class TokenTests
     [Fact]
     public void TokensNotEndWithNothing()
     {
-        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+        var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
 
         var result = tokens.TokensEndsWithAnyString((IList<string>)Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
         Assert.False(result);


+ 3
- 2
LLama/LLamaContext.cs View File

@@ -92,10 +92,11 @@ namespace LLama
         /// </summary>
         /// <param name="text"></param>
         /// <param name="addBos">Whether to add a bos to the text.</param>
+        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
         /// <returns></returns>
-        public llama_token[] Tokenize(string text, bool addBos = true)
+        public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
         {
-            return _ctx.Tokenize(text, addBos, _encoding);
+            return _ctx.Tokenize(text, addBos, special, _encoding);
         }
 
         /// <summary>


+ 5
- 3
LLama/Native/NativeApi.cs View File

@@ -284,10 +284,11 @@ namespace LLama.Native
        /// <param name="tokens"></param>
        /// <param name="n_max_tokens"></param>
        /// <param name="add_bos"></param>
+       /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
        /// <returns>Returns the number of tokens on success, no more than n_max_tokens.
        /// Returns a negative number on failure - the number of tokens that would have been returned
        /// </returns>
-       public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos)
+       public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special)
        {
            // Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
            var byteCount = encoding.GetByteCount(text);
@@ -307,7 +308,7 @@ namespace LLama.Native
                // Do the actual tokenization
                fixed (byte* arrayPtr = array)
                fixed (llama_token* tokensPtr = tokens)
-                   return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos);
+                   return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
            }
            finally
            {
@@ -454,11 +455,12 @@ namespace LLama.Native
        /// <param name="tokens"></param>
        /// <param name="n_max_tokens"></param>
        /// <param name="add_bos"></param>
+       /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
        /// <returns>Returns the number of tokens on success, no more than n_max_tokens.
        /// Returns a negative number on failure - the number of tokens that would have been returned
        /// </returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-       public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos);
+       public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
 
        /// <summary>
        /// Register a callback to receive llama log messages


+ 3
- 2
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -122,9 +122,10 @@ namespace LLama.Native
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether the "BOS" token should be added</param>
        /// <param name="encoding">Encoding to use for the text</param>
+       /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <returns></returns>
        /// <exception cref="RuntimeError"></exception>
-       public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+       public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            ThrowIfDisposed();

@@ -140,7 +141,7 @@ namespace LLama.Native
            try
            {
                // Do the actual conversion
-               var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
+               var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special);
                if (n < 0)
                {
                    throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +


+ 4
- 3
LLama/Native/SafeLlamaModelHandle.cs View File

@@ -271,8 +271,9 @@ namespace LLama.Native
        /// <param name="text"></param>
        /// <param name="add_bos"></param>
        /// <param name="encoding"></param>
+       /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <returns></returns>
-       public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+       public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            // Convert string to bytes, adding one extra byte to the end (null terminator)
            var bytesCount = encoding.GetByteCount(text);
@@ -291,13 +292,13 @@ namespace LLama.Native
            fixed (byte* bytesPtr = &bytes[0])
            {
                // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
-               var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos);
+               var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special);
 
                // Tokenize again, this time outputting into an array of exactly the right size
                var tokens = new int[count];
                fixed (int* tokensPtr = &tokens[0])
                {
-                   NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos);
+                   NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
                    return tokens;
                }
            }


Loading…
Cancel
Save