| @@ -7,6 +7,7 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| @@ -463,12 +464,12 @@ namespace LLama | |||||
| /// <param name="text">The utf-8 encoded string to tokenize.</param> | /// <param name="text">The utf-8 encoded string to tokenize.</param> | ||||
| /// <returns>A list of tokens.</returns> | /// <returns>A list of tokens.</returns> | ||||
| /// <exception cref="RuntimeError">If the tokenization failed.</exception> | /// <exception cref="RuntimeError">If the tokenization failed.</exception> | ||||
| public List<llama_token> Tokenize(string text) | |||||
| public List<llama_token> Tokenize(string text, string encoding = "UTF-8") | |||||
| { | { | ||||
| Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero); | Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero); | ||||
| var n_ctx = NativeApi.llama_n_ctx(_ctx); | var n_ctx = NativeApi.llama_n_ctx(_ctx); | ||||
| var tokens = new llama_token[n_ctx]; | var tokens = new llama_token[n_ctx]; | ||||
| var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true); | |||||
| var n_tokens = NativeApi.llama_tokenize(_ctx, text, Encoding.GetEncoding(encoding), tokens, n_ctx, true); | |||||
| if (n_tokens < 0) | if (n_tokens < 0) | ||||
| { | { | ||||
| throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}"); | throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}"); | ||||
| @@ -176,8 +176,27 @@ namespace LLama.Native | |||||
| /// <param name="n_max_tokens"></param> | /// <param name="n_max_tokens"></param> | ||||
| /// <param name="add_bos"></param> | /// <param name="add_bos"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName)] | |||||
| public static extern int llama_tokenize(SafeLLamaContextHandle ctx, string text, llama_token[] tokens, int n_max_tokens, bool add_bos); | |||||
| public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) | |||||
| { | |||||
| var bytes = encoding.GetBytes(text); | |||||
| sbyte[] data = new sbyte[bytes.Length]; | |||||
| for(int i = 0; i < bytes.Length; i++) | |||||
| { | |||||
| data[i] = (sbyte)bytes[i]; | |||||
| //if (bytes[i] < 128) | |||||
| //{ | |||||
| // data[i] = (sbyte)bytes[i]; | |||||
| //} | |||||
| //else | |||||
| //{ | |||||
| // data[i] = (sbyte)(~((sbyte)(~bytes[i] + 1)) + 1); | |||||
| //} | |||||
| } | |||||
| return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos); | |||||
| } | |||||
| [DllImport(libraryName, EntryPoint = "llama_tokenize")] | |||||
| public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos); | |||||
| [DllImport(libraryName)] | [DllImport(libraryName)] | ||||
| public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); | public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); | ||||
| @@ -52,11 +52,12 @@ namespace LLama | |||||
| return ctx; | return ctx; | ||||
| } | } | ||||
| public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encoding) | |||||
| public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encodingName) | |||||
| { | { | ||||
| var cnt = Encoding.GetEncoding(encoding).GetByteCount(text); | |||||
| var encoding = Encoding.GetEncoding(encodingName); | |||||
| var cnt = encoding.GetByteCount(text); | |||||
| llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)]; | llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)]; | ||||
| int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos); | |||||
| int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos); | |||||
| if(n < 0) | if(n < 0) | ||||
| { | { | ||||
| throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + | throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + | ||||