diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 24b6ee80..a74f11ee 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -55,7 +55,7 @@ namespace LLama
                 text = text.Insert(0, " ");
             }
 
-            var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();
+            var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));
 
             // TODO(Rinne): deal with log of prompt
 
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index e055c147..ae7035c9 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -30,8 +30,8 @@ namespace LLama
         public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n",
             string instructionSuffix = "\n\n### Response:\n\n") : base(model)
         {
-            _inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray();
-            _inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray();
+            _inp_pfx = _model.Tokenize(instructionPrefix, true);
+            _inp_sfx = _model.Tokenize(instructionSuffix, false);
             _instructionPrefix = instructionPrefix;
         }
 
@@ -133,7 +133,7 @@ namespace LLama
                 _embed_inps.AddRange(_inp_sfx);
 
-                args.RemainedTokens -= line_inp.Count();
+                args.RemainedTokens -= line_inp.Length;
             }
         }
         /// <inheritdoc />
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index f5c1583e..3b0b13be 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -25,7 +25,7 @@ namespace LLama
         /// </summary>
         public InteractiveExecutor(LLamaModel model) : base(model)
         {
-            _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray();
+            _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding);
         }
 
         /// <inheritdoc />
@@ -114,7 +114,7 @@ namespace LLama
                 }
                 var line_inp = _model.Tokenize(text, false);
                 _embed_inps.AddRange(line_inp);
-                args.RemainedTokens -= line_inp.Count();
+                args.RemainedTokens -= line_inp.Length;
             }
         }
 
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199..4bc18c1e 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -64,10 +64,9 @@ namespace LLama
         /// <param name="text"></param>
         /// <param name="addBos">Whether to add a bos to the text.</param>
         /// <returns></returns>
-        public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
+        public llama_token[] Tokenize(string text, bool addBos = true)
         {
-            // TODO: reconsider whether to convert to array here.
-            return Utils.Tokenize(_ctx, text, addBos, _encoding);
+            return _ctx.Tokenize(text, addBos, _encoding);
         }
 
         /// <summary>
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 4e0ac2a2..5857b590 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -218,6 +218,7 @@ namespace LLama.Native
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="text"></param>
+        /// <param name="encoding"></param>
         /// <param name="tokens"></param>
         /// <param name="n_max_tokens"></param>
         /// <param name="add_bos"></param>
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index ab102228..9e81de69 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,4 +1,6 @@
 using System;
+using System.Buffers;
+using System.Text;
 using LLama.Exceptions;
 
 namespace LLama.Native
@@ -57,5 +59,43 @@
 
             return new(ctx_ptr, model);
         }
+
+        /// <summary>
+        /// Convert the given text into tokens
+        /// </summary>
+        /// <param name="text">The text to tokenize</param>
+        /// <param name="add_bos">Whether the "BOS" token should be added</param>
+        /// <param name="encoding">Encoding to use for the text</param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+        {
+            // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
+            // possibly be more than this.
+            var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
+
+            // "Rent" an array to write results into (avoiding an allocation of a large array)
+            var temporaryArray = ArrayPool<int>.Shared.Rent(count);
+            try
+            {
+                // Do the actual conversion
+                var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
+                if (n < 0)
+                {
+                    throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
+                                           "specify the encoding.");
+                }
+
+                // Copy the results from the rented into an array which is exactly the right size
+                var result = new int[n];
+                Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
+
+                return result;
+            }
+            finally
+            {
+                ArrayPool<int>.Shared.Return(temporaryArray);
+            }
+        }
     }
 }
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 391a5cc1..7a1f5f42 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -27,17 +27,10 @@ namespace LLama
             }
         }
 
+        [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
         public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
         {
-            var cnt = encoding.GetByteCount(text);
-            llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
-            int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
-            if (n < 0)
-            {
-                throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
-                    "specify the encoding.");
-            }
-            return res.Take(n);
+            return ctx.Tokenize(text, add_bos, encoding);
         }
 
         public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
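For reference, a minimal usage sketch of the API after this change (not part of the patch): both LLamaModel.Tokenize and the new SafeLLamaContextHandle.Tokenize return a materialized token array rather than a lazy IEnumerable, so callers can read Length directly, just as the executor changes above do. The class name, prompt string, and the choice of UTF-8 below are illustrative assumptions, not code from this repository.

using System;
using System.Text;
using LLama;

static class TokenizeUsageSketch
{
    // "model" is assumed to be an already-loaded LLamaModel instance.
    public static void PrintTokenCounts(LLamaModel model)
    {
        // High-level wrapper changed in LLamaModel.cs: returns llama_token[]
        // (adds BOS by default, using the model's configured encoding).
        var tokens = model.Tokenize("Hello, world!", addBos: true);
        Console.WriteLine($"With BOS: {tokens.Length} tokens");

        // Equivalent call directly on the native context handle, added in
        // SafeLLamaContextHandle.cs, with the text encoding given explicitly.
        var raw = model.NativeHandle.Tokenize("Hello, world!", false, Encoding.UTF8);
        Console.WriteLine($"Without BOS: {raw.Length} tokens");
    }
}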