- Also refactored it to return an `int[]` instead of an `IEnumerable<int>`, solving the "multiple enumeration" problems at the source! (Released as tag `v0.5.1`.)
| @@ -55,7 +55,7 @@ namespace LLama | |||||
| text = text.Insert(0, " "); | text = text.Insert(0, " "); | ||||
| } | } | ||||
| var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray(); | |||||
| var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding)); | |||||
| // TODO(Rinne): deal with log of prompt | // TODO(Rinne): deal with log of prompt | ||||
| @@ -30,8 +30,8 @@ namespace LLama | |||||
/// <summary>
/// Create an executor that wraps each input with the given instruction prefix/suffix
/// (Alpaca-style "### Instruction:" / "### Response:" markers).
/// </summary>
/// <param name="model">The model to run inference with.</param>
/// <param name="instructionPrefix">Text prepended before each instruction.</param>
/// <param name="instructionSuffix">Text appended after each instruction.</param>
public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n",
    string instructionSuffix = "\n\n### Response:\n\n") : base(model)
{
    // Tokenize now returns a right-sized int[] directly, so no ToArray() is needed.
    _inp_pfx = _model.Tokenize(instructionPrefix, true);
    _inp_sfx = _model.Tokenize(instructionSuffix, false);
    _instructionPrefix = instructionPrefix;
}
| @@ -133,7 +133,7 @@ namespace LLama | |||||
| _embed_inps.AddRange(_inp_sfx); | _embed_inps.AddRange(_inp_sfx); | ||||
| args.RemainedTokens -= line_inp.Count(); | |||||
| args.RemainedTokens -= line_inp.Length; | |||||
| } | } | ||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -25,7 +25,7 @@ namespace LLama | |||||
| /// <param name="model"></param> | /// <param name="model"></param> | ||||
/// <summary>
/// Create an executor for interactive (chat-style) inference.
/// </summary>
/// <param name="model">The model to run inference with.</param>
public InteractiveExecutor(LLamaModel model) : base(model)
{
    // Cache the token sequence for a newline once up front; the handle's
    // Tokenize now returns a right-sized int[] directly (no ToArray() needed).
    _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding);
}
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| @@ -114,7 +114,7 @@ namespace LLama | |||||
| } | } | ||||
| var line_inp = _model.Tokenize(text, false); | var line_inp = _model.Tokenize(text, false); | ||||
| _embed_inps.AddRange(line_inp); | _embed_inps.AddRange(line_inp); | ||||
| args.RemainedTokens -= line_inp.Count(); | |||||
| args.RemainedTokens -= line_inp.Length; | |||||
| } | } | ||||
| } | } | ||||
| @@ -64,10 +64,9 @@ namespace LLama | |||||
/// <summary>
/// Convert the given text into tokens using the model's context and configured encoding.
/// </summary>
/// <param name="text">The text to tokenize.</param>
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <returns>An array containing exactly the produced tokens.</returns>
public llama_token[] Tokenize(string text, bool addBos = true)
{
    // Delegates to the context handle, which returns a right-sized array,
    // avoiding the old IEnumerable multiple-enumeration pitfall.
    return _ctx.Tokenize(text, addBos, _encoding);
}
| /// <summary> | /// <summary> | ||||
| @@ -218,6 +218,7 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <param name="text"></param> | /// <param name="text"></param> | ||||
| /// <param name="encoding"></param> | |||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| /// <param name="n_max_tokens"></param> | /// <param name="n_max_tokens"></param> | ||||
| /// <param name="add_bos"></param> | /// <param name="add_bos"></param> | ||||
| @@ -1,4 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| namespace LLama.Native | namespace LLama.Native | ||||
| @@ -57,5 +59,43 @@ namespace LLama.Native | |||||
| return new(ctx_ptr, model); | return new(ctx_ptr, model); | ||||
| } | } | ||||
/// <summary>
/// Convert the given text into tokens
/// </summary>
/// <param name="text">The text to tokenize</param>
/// <param name="add_bos">Whether the "BOS" token should be added</param>
/// <param name="encoding">Encoding to use for the text</param>
/// <returns>An array containing exactly the tokens produced, sized to the token count</returns>
/// <exception cref="RuntimeError">Thrown when the native tokenizer reports a failure
/// (e.g. the text cannot be represented in the given encoding)</exception>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
    // The encoded byte count is a pessimistic upper bound on the token count:
    // every token consumes at least one byte, so it can't possibly be more than
    // this (plus one slot for the optional BOS token).
    var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);

    // "Rent" a scratch buffer to avoid allocating a potentially large temporary array.
    var temporaryArray = ArrayPool<int>.Shared.Rent(count);
    try
    {
        // Do the actual conversion; `n` is the number of tokens written, or < 0 on error.
        var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
        if (n < 0)
        {
            throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                                   "specify the encoding.");
        }

        // Copy the results from the rented buffer into an array that is exactly the right size.
        var result = new int[n];
        Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
        return result;
    }
    finally
    {
        // Always return the rented buffer, even when tokenization throws.
        ArrayPool<int>.Shared.Return(temporaryArray);
    }
}
| } | } | ||||
| } | } | ||||
| @@ -27,17 +27,10 @@ namespace LLama | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Convert text into tokens. Kept only for backward compatibility; new code should
/// call <see cref="SafeLLamaContextHandle"/>'s own Tokenize method directly.
/// </summary>
/// <param name="ctx">Context handle used for tokenization.</param>
/// <param name="text">The text to tokenize.</param>
/// <param name="add_bos">Whether the "BOS" token should be added.</param>
/// <param name="encoding">Encoding to use for the text.</param>
/// <returns>The produced tokens (a right-sized array, safe to enumerate repeatedly).</returns>
[Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
{
    // Delegates to the handle; the old hand-rolled buffer + Take(n) logic lives there now.
    return ctx.Tokenize(text, add_bos, encoding);
}
| public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length) | public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length) | ||||