Browse Source

- Moved tokenization from `Utils.Tokenize` into `SafeLLamaContextHandle.Tokenize`, one less thing in `Utils`.

- Also refactored it to return an `int[]` instead of an `IEnumerable<int>`, solving the "multiple enumeration" problems at the source!
tags/v0.5.1
Martin Evans 2 years ago
parent
commit
cd3cf2b77d
7 changed files with 51 additions and 18 deletions
  1. +1
    -1
      LLama/LLamaEmbedder.cs
  2. +3
    -3
      LLama/LLamaInstructExecutor.cs
  3. +2
    -2
      LLama/LLamaInteractExecutor.cs
  4. +2
    -3
      LLama/LLamaModel.cs
  5. +1
    -0
      LLama/Native/NativeApi.cs
  6. +40
    -0
      LLama/Native/SafeLLamaContextHandle.cs
  7. +2
    -9
      LLama/Utils.cs

+ 1
- 1
LLama/LLamaEmbedder.cs View File

@@ -55,7 +55,7 @@ namespace LLama
text = text.Insert(0, " ");
}

var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();
var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));

// TODO(Rinne): deal with log of prompt



+ 3
- 3
LLama/LLamaInstructExecutor.cs View File

@@ -30,8 +30,8 @@ namespace LLama
public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n",
string instructionSuffix = "\n\n### Response:\n\n") : base(model)
{
_inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray();
_inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray();
_inp_pfx = _model.Tokenize(instructionPrefix, true);
_inp_sfx = _model.Tokenize(instructionSuffix, false);
_instructionPrefix = instructionPrefix;
}

@@ -133,7 +133,7 @@ namespace LLama

_embed_inps.AddRange(_inp_sfx);

args.RemainedTokens -= line_inp.Count();
args.RemainedTokens -= line_inp.Length;
}
}
/// <inheritdoc />


+ 2
- 2
LLama/LLamaInteractExecutor.cs View File

@@ -25,7 +25,7 @@ namespace LLama
/// <param name="model"></param>
public InteractiveExecutor(LLamaModel model) : base(model)
{
_llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray();
_llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding);
}

/// <inheritdoc />
@@ -114,7 +114,7 @@ namespace LLama
}
var line_inp = _model.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Count();
args.RemainedTokens -= line_inp.Length;
}
}



+ 2
- 3
LLama/LLamaModel.cs View File

@@ -64,10 +64,9 @@ namespace LLama
/// <param name="text"></param>
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <returns></returns>
public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
public llama_token[] Tokenize(string text, bool addBos = true)
{
// TODO: reconsider whether to convert to array here.
return Utils.Tokenize(_ctx, text, addBos, _encoding);
return _ctx.Tokenize(text, addBos, _encoding);
}

/// <summary>


+ 1
- 0
LLama/Native/NativeApi.cs View File

@@ -218,6 +218,7 @@ namespace LLama.Native
/// </summary>
/// <param name="ctx"></param>
/// <param name="text"></param>
/// <param name="encoding"></param>
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>


+ 40
- 0
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -1,4 +1,6 @@
using System;
using System.Buffers;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
@@ -57,5 +59,43 @@ namespace LLama.Native

return new(ctx_ptr, model);
}

/// <summary>
/// Convert the given text into tokens
/// </summary>
/// <param name="text">The text to tokenize</param>
/// <param name="add_bos">Whether the "BOS" token should be added</param>
/// <param name="encoding">Encoding to use for the text</param>
/// <returns>An array containing exactly the tokens produced by the native tokenizer</returns>
/// <exception cref="RuntimeError">Thrown when the native tokenizer reports a failure</exception>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
    // The encoded byte count is a pessimistic upper bound on the number of
    // tokens: tokenization can never yield more tokens than input bytes
    // (plus one extra slot when a BOS token is prepended).
    var maxTokens = encoding.GetByteCount(text) + (add_bos ? 1 : 0);

    // Borrow a scratch buffer from the shared pool rather than allocating
    // a potentially large throwaway array.
    var buffer = ArrayPool<int>.Shared.Rent(maxTokens);
    try
    {
        var written = NativeApi.llama_tokenize(this, text, encoding, buffer, maxTokens, add_bos);

        if (written < 0)
        {
            throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                                   "specify the encoding.");
        }

        // Trim the scratch buffer down to exactly the tokens produced.
        return buffer.AsSpan(0, written).ToArray();
    }
    finally
    {
        ArrayPool<int>.Shared.Return(buffer);
    }
}
}
}

+ 2
- 9
LLama/Utils.cs View File

@@ -27,17 +27,10 @@ namespace LLama
}
}

[Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
{
var cnt = encoding.GetByteCount(text);
llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
if (n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
"specify the encoding.");
}
return res.Take(n);
return ctx.Tokenize(text, add_bos, encoding);
}

public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)


Loading…
Cancel
Save