Browse Source

fix: errors when input is not English or too long.

tags/untagged-54cee5cf55a360e770ad
Yaohui Liu 2 years ago
parent
commit
afedd3c949
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
5 changed files with 71 additions and 71 deletions
  1. +6
    -6
      LLama/ChatSession.cs
  2. +2
    -2
      LLama/IChatModel.cs
  3. +2
    -2
      LLama/LLamaEmbedder.cs
  4. +53
    -58
      LLama/LLamaModel.cs
  5. +8
    -3
      LLama/Utils.cs

+ 6
- 6
LLama/ChatSession.cs View File

@@ -16,11 +16,11 @@ namespace LLama
_model = model; _model = model;
} }


public IEnumerable<string> Chat(string text, string? prompt = null)
public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
{ {
History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Human, text), DateTime.Now)); History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Human, text), DateTime.Now));
string totalResponse = ""; string totalResponse = "";
foreach(var response in _model.Chat(text, prompt))
foreach(var response in _model.Chat(text, prompt, encoding))
{ {
totalResponse += response; totalResponse += response;
yield return response; yield return response;
@@ -28,15 +28,15 @@ namespace LLama
History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Assistant, totalResponse), DateTime.Now)); History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Assistant, totalResponse), DateTime.Now));
} }


public ChatSession<T> WithPrompt(string prompt)
public ChatSession<T> WithPrompt(string prompt, string encoding = "UTF-8")
{ {
_model.InitChatPrompt(prompt);
_model.InitChatPrompt(prompt, encoding);
return this; return this;
} }


public ChatSession<T> WithPromptFile(string promptFilename)
public ChatSession<T> WithPromptFile(string promptFilename, string encoding = "UTF-8")
{ {
return WithPrompt(File.ReadAllText(promptFilename));
return WithPrompt(File.ReadAllText(promptFilename), encoding);
} }


/// <summary> /// <summary>


+ 2
- 2
LLama/IChatModel.cs View File

@@ -7,12 +7,12 @@ namespace LLama
public interface IChatModel public interface IChatModel
{ {
string Name { get; } string Name { get; }
IEnumerable<string> Chat(string text, string? prompt = null);
IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8");
/// <summary> /// <summary>
/// Init a prompt for chat and automatically produce the next prompt during the chat. /// Init a prompt for chat and automatically produce the next prompt during the chat.
/// </summary> /// </summary>
/// <param name="prompt"></param> /// <param name="prompt"></param>
void InitChatPrompt(string prompt);
void InitChatPrompt(string prompt, string encoding = "UTF-8");
void InitChatAntiprompt(string[] antiprompt); void InitChatAntiprompt(string[] antiprompt);
} }
} }

+ 2
- 2
LLama/LLamaEmbedder.cs View File

@@ -25,7 +25,7 @@ namespace LLama
_ctx = Utils.llama_init_from_gpt_params(ref @params); _ctx = Utils.llama_init_from_gpt_params(ref @params);
} }


public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true)
public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true, string encoding = "UTF-8")
{ {
if(n_thread == -1) if(n_thread == -1)
{ {
@@ -36,7 +36,7 @@ namespace LLama
{ {
text = text.Insert(0, " "); text = text.Insert(0, " ");
} }
var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos);
var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos, encoding);


// TODO(Rinne): deal with log of prompt // TODO(Rinne): deal with log of prompt




+ 53
- 58
LLama/LLamaModel.cs View File

@@ -64,14 +64,13 @@ namespace LLama


} }


public unsafe LLamaModel(LLamaParams @params, string name = "", bool echo_input = false, bool verbose = false)
public unsafe LLamaModel(LLamaParams @params, string name = "", bool echo_input = false, bool verbose = false, string encoding = "UTF-8")
{ {
Name = name; Name = name;
_params = @params; _params = @params;
_ctx = Utils.llama_init_from_gpt_params(ref _params); _ctx = Utils.llama_init_from_gpt_params(ref _params);


// Add a space in front of the first character to match OG llama tokenizer behavior // Add a space in front of the first character to match OG llama tokenizer behavior
_params.prompt = _params.prompt.Insert(0, " ");
_session_tokens = new List<llama_token>(); _session_tokens = new List<llama_token>();


_path_session = @params.path_session; _path_session = @params.path_session;
@@ -100,50 +99,13 @@ namespace LLama
} }
} }


_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
_n_ctx = NativeApi.llama_n_ctx(_ctx); _n_ctx = NativeApi.llama_n_ctx(_ctx);


if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}

ulong n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
{
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count && verbose)
{
Logger.Default.Info("Session file has exact match for prompt!");
}
else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
{
Logger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else if(verbose)
{
Logger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
}
}

// number of tokens to keep when resetting context
if (_params.n_keep < 0 || _params.n_keep > (int)_embed_inp.Count || _params.instruct)
{
_params.n_keep = _embed_inp.Count;
}
WithPrompt(_params.prompt);


// prefix & suffix for instruct mode // prefix & suffix for instruct mode
_inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true);
_inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false);
_inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true, encoding);
_inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false, encoding);


// in instruct mode, we inject a prefix and a suffix to each input by the user // in instruct mode, we inject a prefix and a suffix to each input by the user
if (_params.instruct) if (_params.instruct)
@@ -159,7 +121,7 @@ namespace LLama
} }


// determine newline token // determine newline token
_llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false);
_llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false, encoding);


if (_params.verbose_prompt) if (_params.verbose_prompt)
{ {
@@ -211,7 +173,6 @@ namespace LLama


_is_antiprompt = false; _is_antiprompt = false;
_input_echo = echo_input; _input_echo = echo_input;
_need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
_n_past = 0; _n_past = 0;
_n_remain = _params.n_predict; _n_remain = _params.n_predict;
_n_consumed = 0; _n_consumed = 0;
@@ -219,18 +180,52 @@ namespace LLama
_embed = new List<llama_token>(); _embed = new List<llama_token>();
} }


public LLamaModel WithPrompt(string prompt)
public LLamaModel WithPrompt(string prompt, string encoding = "UTF-8")
{ {
_params.prompt = prompt;
if (!_params.prompt.EndsWith(" "))
_params.prompt = _params.prompt.Insert(0, " ");
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true, encoding);

if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}

ulong n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
{
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
{
Logger.Default.Info("Session file has exact match for prompt!");
}
else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
{
Logger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else
{
Logger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
}
}
// number of tokens to keep when resetting context
if (_params.n_keep < 0 || _params.n_keep > (int)_embed_inp.Count || _params.instruct)
{ {
_params.prompt = _params.prompt.Insert(0, " ");
_params.n_keep = _embed_inp.Count;
} }
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
if (_embed_inp.Count > _n_ctx - 4) if (_embed_inp.Count > _n_ctx - 4)
{ {
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})"); throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
} }
_need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
return this; return this;
} }


@@ -239,7 +234,7 @@ namespace LLama
return WithPrompt(File.ReadAllText(promptFileName)); return WithPrompt(File.ReadAllText(promptFileName));
} }


private string ProcessTextBeforeInfer(string text)
private string ProcessTextBeforeInfer(string text, string encoding)
{ {
if (!string.IsNullOrEmpty(_params.input_prefix)) if (!string.IsNullOrEmpty(_params.input_prefix))
{ {
@@ -265,7 +260,7 @@ namespace LLama
_embed_inp.AddRange(_inp_pfx); _embed_inp.AddRange(_inp_pfx);
} }


var line_inp = Utils.llama_tokenize(_ctx, text, false);
var line_inp = Utils.llama_tokenize(_ctx, text, false, encoding);
_embed_inp.AddRange(line_inp); _embed_inp.AddRange(line_inp);


// instruct mode: insert response suffix // instruct mode: insert response suffix
@@ -279,7 +274,7 @@ namespace LLama
return text; return text;
} }


public void InitChatPrompt(string prompt)
public void InitChatPrompt(string prompt, string encoding = "UTF-8")
{ {
WithPrompt(prompt); WithPrompt(prompt);
} }
@@ -289,7 +284,7 @@ namespace LLama
_params.antiprompt = antiprompt.ToList(); _params.antiprompt = antiprompt.ToList();
} }


public IEnumerable<string> Chat(string text, string? prompt = null)
public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
{ {
_params.interactive = true; _params.interactive = true;
_input_echo = false; _input_echo = false;
@@ -297,13 +292,13 @@ namespace LLama
{ {
WithPrompt(prompt); WithPrompt(prompt);
} }
return Call(text);
return Call(text, encoding);
} }


public IEnumerable<string> Call(string text)
public IEnumerable<string> Call(string text, string encoding = "UTF-8")
{ {
_is_interacting = _is_antiprompt = false; _is_interacting = _is_antiprompt = false;
ProcessTextBeforeInfer(text);
ProcessTextBeforeInfer(text, encoding);
while ((_n_remain != 0 || _params.interactive) && !_is_interacting) while ((_n_remain != 0 || _params.interactive) && !_is_interacting)
{ {
@@ -320,7 +315,7 @@ namespace LLama
_n_past = Math.Max(1, _params.n_keep); _n_past = Math.Max(1, _params.n_keep);


// insert n_left/2 tokens at the start of embed from last_n_tokens // insert n_left/2 tokens at the start of embed from last_n_tokens
_embed.InsertRange(0, _last_n_tokens.GetRange(_n_ctx - n_left / 2 - _embed.Count, _embed.Count));
_embed.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embed.Count).Skip(_n_ctx - n_left / 2 - _embed.Count));


// stop saving session if we run out of context // stop saving session if we run out of context
_path_session = ""; _path_session = "";
@@ -494,7 +489,7 @@ namespace LLama
if (_params.antiprompt.Count != 0) if (_params.antiprompt.Count != 0)
{ {
// tokenize and inject first reverse prompt // tokenize and inject first reverse prompt
var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false);
var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false, encoding);
_embed_inp.AddRange(first_antiprompt); _embed_inp.AddRange(first_antiprompt);
} }
} }


+ 8
- 3
LLama/Utils.cs View File

@@ -47,11 +47,16 @@ namespace LLama
return ctx; return ctx;
} }


public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos)
public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encoding)
{ {
llama_token[] res = new llama_token[text.Length + (add_bos ? 1 : 0)];
var cnt = Encoding.GetEncoding(encoding).GetByteCount(text);
llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos); int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos);
Debug.Assert(n >= 0);
if(n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
"specify the encoding.");
}
return res.Take(n).ToList(); return res.Take(n).ToList();
} }




Loading…
Cancel
Save