Browse Source

fix: errors when input is not English or too long.

tags/untagged-54cee5cf55a360e770ad
Yaohui Liu 2 years ago
parent
commit
afedd3c949
No known key found for this signature in database. GPG Key ID: E86D01E1809BD23E
5 changed files with 71 additions and 71 deletions
  1. +6
    -6
      LLama/ChatSession.cs
  2. +2
    -2
      LLama/IChatModel.cs
  3. +2
    -2
      LLama/LLamaEmbedder.cs
  4. +53
    -58
      LLama/LLamaModel.cs
  5. +8
    -3
      LLama/Utils.cs

+ 6
- 6
LLama/ChatSession.cs View File

@@ -16,11 +16,11 @@ namespace LLama
_model = model;
}

public IEnumerable<string> Chat(string text, string? prompt = null)
public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
{
History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Human, text), DateTime.Now));
string totalResponse = "";
foreach(var response in _model.Chat(text, prompt))
foreach(var response in _model.Chat(text, prompt, encoding))
{
totalResponse += response;
yield return response;
@@ -28,15 +28,15 @@ namespace LLama
History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Assistant, totalResponse), DateTime.Now));
}

public ChatSession<T> WithPrompt(string prompt)
public ChatSession<T> WithPrompt(string prompt, string encoding = "UTF-8")
{
_model.InitChatPrompt(prompt);
_model.InitChatPrompt(prompt, encoding);
return this;
}

public ChatSession<T> WithPromptFile(string promptFilename)
public ChatSession<T> WithPromptFile(string promptFilename, string encoding = "UTF-8")
{
return WithPrompt(File.ReadAllText(promptFilename));
return WithPrompt(File.ReadAllText(promptFilename), encoding);
}

/// <summary>


+ 2
- 2
LLama/IChatModel.cs View File

@@ -7,12 +7,12 @@ namespace LLama
public interface IChatModel
{
string Name { get; }
IEnumerable<string> Chat(string text, string? prompt = null);
IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8");
/// <summary>
/// Init a prompt for chat and automatically produce the next prompt during the chat.
/// </summary>
/// <param name="prompt"></param>
void InitChatPrompt(string prompt);
void InitChatPrompt(string prompt, string encoding = "UTF-8");
void InitChatAntiprompt(string[] antiprompt);
}
}

+ 2
- 2
LLama/LLamaEmbedder.cs View File

@@ -25,7 +25,7 @@ namespace LLama
_ctx = Utils.llama_init_from_gpt_params(ref @params);
}

public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true)
public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true, string encoding = "UTF-8")
{
if(n_thread == -1)
{
@@ -36,7 +36,7 @@ namespace LLama
{
text = text.Insert(0, " ");
}
var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos);
var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos, encoding);

// TODO(Rinne): deal with log of prompt



+ 53
- 58
LLama/LLamaModel.cs View File

@@ -64,14 +64,13 @@ namespace LLama

}

public unsafe LLamaModel(LLamaParams @params, string name = "", bool echo_input = false, bool verbose = false)
public unsafe LLamaModel(LLamaParams @params, string name = "", bool echo_input = false, bool verbose = false, string encoding = "UTF-8")
{
Name = name;
_params = @params;
_ctx = Utils.llama_init_from_gpt_params(ref _params);

// Add a space in front of the first character to match OG llama tokenizer behavior
_params.prompt = _params.prompt.Insert(0, " ");
_session_tokens = new List<llama_token>();

_path_session = @params.path_session;
@@ -100,50 +99,13 @@ namespace LLama
}
}

_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
_n_ctx = NativeApi.llama_n_ctx(_ctx);

if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}

ulong n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
{
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count && verbose)
{
Logger.Default.Info("Session file has exact match for prompt!");
}
else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
{
Logger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else if(verbose)
{
Logger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
}
}

// number of tokens to keep when resetting context
if (_params.n_keep < 0 || _params.n_keep > (int)_embed_inp.Count || _params.instruct)
{
_params.n_keep = _embed_inp.Count;
}
WithPrompt(_params.prompt);

// prefix & suffix for instruct mode
_inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true);
_inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false);
_inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true, encoding);
_inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false, encoding);

// in instruct mode, we inject a prefix and a suffix to each input by the user
if (_params.instruct)
@@ -159,7 +121,7 @@ namespace LLama
}

// determine newline token
_llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false);
_llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false, encoding);

if (_params.verbose_prompt)
{
@@ -211,7 +173,6 @@ namespace LLama

_is_antiprompt = false;
_input_echo = echo_input;
_need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
_n_past = 0;
_n_remain = _params.n_predict;
_n_consumed = 0;
@@ -219,18 +180,52 @@ namespace LLama
_embed = new List<llama_token>();
}

public LLamaModel WithPrompt(string prompt)
public LLamaModel WithPrompt(string prompt, string encoding = "UTF-8")
{
_params.prompt = prompt;
if (!_params.prompt.EndsWith(" "))
_params.prompt = _params.prompt.Insert(0, " ");
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true, encoding);

if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}

ulong n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
{
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
{
Logger.Default.Info("Session file has exact match for prompt!");
}
else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
{
Logger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else
{
Logger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
}
}
// number of tokens to keep when resetting context
if (_params.n_keep < 0 || _params.n_keep > (int)_embed_inp.Count || _params.instruct)
{
_params.prompt = _params.prompt.Insert(0, " ");
_params.n_keep = _embed_inp.Count;
}
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}
_need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
return this;
}

@@ -239,7 +234,7 @@ namespace LLama
return WithPrompt(File.ReadAllText(promptFileName));
}

private string ProcessTextBeforeInfer(string text)
private string ProcessTextBeforeInfer(string text, string encoding)
{
if (!string.IsNullOrEmpty(_params.input_prefix))
{
@@ -265,7 +260,7 @@ namespace LLama
_embed_inp.AddRange(_inp_pfx);
}

var line_inp = Utils.llama_tokenize(_ctx, text, false);
var line_inp = Utils.llama_tokenize(_ctx, text, false, encoding);
_embed_inp.AddRange(line_inp);

// instruct mode: insert response suffix
@@ -279,7 +274,7 @@ namespace LLama
return text;
}

public void InitChatPrompt(string prompt)
public void InitChatPrompt(string prompt, string encoding = "UTF-8")
{
WithPrompt(prompt);
}
@@ -289,7 +284,7 @@ namespace LLama
_params.antiprompt = antiprompt.ToList();
}

public IEnumerable<string> Chat(string text, string? prompt = null)
public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
{
_params.interactive = true;
_input_echo = false;
@@ -297,13 +292,13 @@ namespace LLama
{
WithPrompt(prompt);
}
return Call(text);
return Call(text, encoding);
}

public IEnumerable<string> Call(string text)
public IEnumerable<string> Call(string text, string encoding = "UTF-8")
{
_is_interacting = _is_antiprompt = false;
ProcessTextBeforeInfer(text);
ProcessTextBeforeInfer(text, encoding);
while ((_n_remain != 0 || _params.interactive) && !_is_interacting)
{
@@ -320,7 +315,7 @@ namespace LLama
_n_past = Math.Max(1, _params.n_keep);

// insert n_left/2 tokens at the start of embed from last_n_tokens
_embed.InsertRange(0, _last_n_tokens.GetRange(_n_ctx - n_left / 2 - _embed.Count, _embed.Count));
_embed.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embed.Count).Skip(_n_ctx - n_left / 2 - _embed.Count));

// stop saving session if we run out of context
_path_session = "";
@@ -494,7 +489,7 @@ namespace LLama
if (_params.antiprompt.Count != 0)
{
// tokenize and inject first reverse prompt
var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false);
var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false, encoding);
_embed_inp.AddRange(first_antiprompt);
}
}


+ 8
- 3
LLama/Utils.cs View File

@@ -47,11 +47,16 @@ namespace LLama
return ctx;
}

public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos)
public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encoding)
{
llama_token[] res = new llama_token[text.Length + (add_bos ? 1 : 0)];
var cnt = Encoding.GetEncoding(encoding).GetByteCount(text);
llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos);
Debug.Assert(n >= 0);
if(n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
"specify the encoding.");
}
return res.Take(n).ToList();
}



Loading…
Cancel
Save