diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 24b6ee80..a74f11ee 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -55,7 +55,7 @@ namespace LLama
text = text.Insert(0, " ");
}
- var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();
+ var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));
// TODO(Rinne): deal with log of prompt
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index e055c147..ae7035c9 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -30,8 +30,8 @@ namespace LLama
public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n",
string instructionSuffix = "\n\n### Response:\n\n") : base(model)
{
- _inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray();
- _inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray();
+ _inp_pfx = _model.Tokenize(instructionPrefix, true);
+ _inp_sfx = _model.Tokenize(instructionSuffix, false);
_instructionPrefix = instructionPrefix;
}
@@ -133,7 +133,7 @@ namespace LLama
_embed_inps.AddRange(_inp_sfx);
- args.RemainedTokens -= line_inp.Count();
+ args.RemainedTokens -= line_inp.Length;
}
}
///
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index f5c1583e..3b0b13be 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -25,7 +25,7 @@ namespace LLama
///
public InteractiveExecutor(LLamaModel model) : base(model)
{
- _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray();
+ _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding);
}
///
@@ -114,7 +114,7 @@ namespace LLama
}
var line_inp = _model.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
- args.RemainedTokens -= line_inp.Count();
+ args.RemainedTokens -= line_inp.Length;
}
}
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199..4bc18c1e 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -64,10 +64,9 @@ namespace LLama
/// <param name="text"></param>
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <returns></returns>
- public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
+ public llama_token[] Tokenize(string text, bool addBos = true)
{
- // TODO: reconsider whether to convert to array here.
- return Utils.Tokenize(_ctx, text, addBos, _encoding);
+ return _ctx.Tokenize(text, addBos, _encoding);
}
///
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 4e0ac2a2..5857b590 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -218,6 +218,7 @@ namespace LLama.Native
/// </summary>
/// <param name="ctx"></param>
/// <param name="text"></param>
+ /// <param name="encoding"></param>
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index ab102228..9e81de69 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,4 +1,6 @@
using System;
+using System.Buffers;
+using System.Text;
using LLama.Exceptions;
namespace LLama.Native
@@ -57,5 +59,43 @@ namespace LLama.Native
return new(ctx_ptr, model);
}
+
+ /// <summary>
+ /// Convert the given text into tokens
+ /// </summary>
+ /// <param name="text">The text to tokenize</param>
+ /// <param name="add_bos">Whether the "BOS" token should be added</param>
+ /// <param name="encoding">Encoding to use for the text</param>
+ /// <returns></returns>
+ /// <exception cref="RuntimeError"></exception>
+ public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+ {
+ // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
+ // possibly be more than this.
+ var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
+
+ // "Rent" an array to write results into (avoiding an allocation of a large array)
+ var temporaryArray = ArrayPool<int>.Shared.Rent(count);
+ try
+ {
+ // Do the actual conversion
+ var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
+ if (n < 0)
+ {
+ throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
+ "specify the encoding.");
+ }
+
+ // Copy the results from the rented buffer into an array which is exactly the right size
+ var result = new int[n];
+ Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
+
+ return result;
+ }
+ finally
+ {
+ ArrayPool<int>.Shared.Return(temporaryArray);
+ }
+ }
}
}
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 391a5cc1..7a1f5f42 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -27,17 +27,10 @@ namespace LLama
}
}
+ [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
{
- var cnt = encoding.GetByteCount(text);
- llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
- int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
- if (n < 0)
- {
- throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
- "specify the encoding.");
- }
- return res.Take(n);
+ return ctx.Tokenize(text, add_bos, encoding);
}
public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
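
Usage note (not part of the patch): a minimal sketch of how callers might exercise the refactored tokenization API. The Tokenize signatures and the NativeHandle/Encoding properties are taken from the diff above; the TokenizationExample class, its method name, and the sample strings are hypothetical.

using System;
using LLama;

// Minimal sketch, not part of the patch: exercising the refactored
// tokenization API. The class and method names here are hypothetical;
// the Tokenize signatures come from the diff above.
public static class TokenizationExample
{
    public static void TokenCounts(LLamaModel model)
    {
        // High-level API: LLamaModel.Tokenize now returns llama_token[] (int[]),
        // so Length replaces LINQ's Count().
        var tokens = model.Tokenize("Hello, world!", addBos: true);
        Console.WriteLine($"prompt tokens: {tokens.Length}");

        // Lower-level API: tokenize through the context handle,
        // passing the text encoding explicitly.
        var newline = model.NativeHandle.Tokenize("\n", false, model.Encoding);
        Console.WriteLine($"newline tokens: {string.Join(", ", newline)}");
    }
}

Returning an array instead of a lazy IEnumerable lets callers read Length directly, which is what the executor changes above (Count() to Length) rely on.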