
feat: add gpt model.

tags/v0.2.1
Yaohui Liu, 2 years ago
commit d6a7997e46
8 changed files with 783 additions and 17 deletions
  1. +21  -8    LLama.Console/Program.cs
  2. +134 -0    LLama/GptParams.cs
  3. +1   -1    LLama/LLamaModel.cs
  4. +547 -0    LLama/Model.cs
  5. +3   -3    LLama/Native/NativeApi.Sampling.cs
  6. +1   -1    LLama/Native/NativeApi.cs
  7. +14  -4    LLama/Native/SamplingApi.cs
  8. +62  -0    LLama/Utils.cs

+21 -8  LLama.Console/Program.cs

@@ -1,16 +1,29 @@
using LLama;
using LLama.Types;

//string modelPath = @"D:\development\llama\weights\LLaMA\7B\ggml-model-q4_0.bin";
//LLamaModel model = new(modelPath, logits_all: false, verbose: false, n_ctx: 2048);
//List<ChatCompletionMessage> chats = new List<ChatCompletionMessage>();
//chats.Add(new ChatCompletionMessage("user", "Hi, Alice, I'm Rinne.", null));
//chats.Add(new ChatCompletionMessage("assistant", "Hi, Rinne, I'm Alice. What can I do for you?", null));
//Console.Write("You: ");
//var question = "This is a text classification task, below are the category list:\r\n1. Air Handler\r\n2. Tub/Shower\r\n3. Fireplace\r\n4. Bathroom\r\n5. Kitchen\r\n6. Powerwash roof eves and soffits\r\n\r\nFor example:\r\n1. \"Clear drain clog at kitchen sink\": Kitchen\r\n2. \"Change blower motor speed\": Air Handler\r\n3. \"Clear drain clog at tub/shower\": Bathroom\r\n4. \"Clear drain clog at toilet\": Bathroom\r\n\r\nPlease classify this text \"toilet clogged\" in provided list. output in json format: {\"category\": \"\", \"confidence\":0.0}";
//chats.Add(new ChatCompletionMessage("user", question, null));
//var output = model.CreateChatCompletion(chats, max_tokens: 1024);
//Console.WriteLine($"LLama AI: {output.Choices[0].Message.Content}");

string modelPath = @"D:\development\llama\weights\LLaMA\7B\ggml-model-q4_0.bin";
LLamaModel model = new(modelPath, logits_all: false, verbose: false);
List<ChatCompletionMessage> chats = new List<ChatCompletionMessage>();
chats.Add(new ChatCompletionMessage("user", "Hi, Alice, I'm Rinne.", null));
chats.Add(new ChatCompletionMessage("assistant", "Hi, Rinne, I'm Alice. What can I do for you?", null));
GptModel model = new(new GptParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: new List<string>(){"User:"},
repeat_penalty: 1.0f));
model = model.WithPrompt("Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\r\n\r\nUser: Hello, Bob.\r\nBob: Hello. How may I help you today?\r\nUser: Please tell me the largest city in Europe.\r\nBob: Sure. The largest city in Europe is Moscow, the capital of Russia.\r\nUser:");
while (true)
{
Console.Write("You: ");
Console.ForegroundColor = ConsoleColor.Green;
var question = Console.ReadLine();
chats.Add(new ChatCompletionMessage("user", question, null));
var output = model.CreateChatCompletion(chats, max_tokens: 256);
Console.WriteLine($"LLama AI: {output.Choices[0].Message.Content}");
Console.ForegroundColor = ConsoleColor.White;
var outputs = model.Call(question);
foreach (var output in outputs)
{
Console.Write(output);
}
}

+134 -0  LLama/GptParams.cs

@@ -0,0 +1,134 @@
using System;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Text;

namespace LLama
{
using llama_token = Int32;
public struct GptParams
{
public int seed; // RNG seed
public int n_threads = Math.Max(Environment.ProcessorCount / 2, 1); // number of threads (-1 = autodetect)
public int n_predict = -1; // new tokens to predict
public int n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
public int n_ctx = 512; // context size
public int n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
public int n_keep = 0; // number of tokens to keep from initial prompt

// sampling parameters
public Dictionary<llama_token, float> logit_bias; // logit bias for specific tokens
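// (illustrative, not part of this commit) e.g. a bias of float.NegativeInfinity
// for NativeApi.llama_token_eos() would effectively ban the EOS token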
public int top_k = 40; // <= 0 to use vocab size
public float top_p = 0.95f; // 1.0 = disabled
public float tfs_z = 1.00f; // 1.0 = disabled
public float typical_p = 1.00f; // 1.0 = disabled
public float temp = 0.80f; // 1.0 = disabled
public float repeat_penalty = 1.10f; // 1.0 = disabled
public int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
public float frequency_penalty = 0.00f; // 0.0 = disabled
public float presence_penalty = 0.00f; // 0.0 = disabled
public int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
public float mirostat_tau = 5.00f; // target entropy
public float mirostat_eta = 0.10f; // learning rate

public string model = "models/lamma-7B/ggml-model.bin"; // model path
public string prompt = ""; // initial prompt (set to empty string for interactive mode)
public string path_session = ""; // path to file for saving/loading model eval state
public string input_prefix = ""; // string to prefix user inputs with
public string input_suffix = ""; // string to suffix user inputs with
public List<string> antiprompt; // string upon seeing which more user input is prompted

public string lora_adapter = ""; // lora adapter path
public string lora_base = ""; // base model path for the lora adapter

public bool memory_f16 = true; // use f16 instead of f32 for memory kv
public bool random_prompt = false; // randomize prompt if none provided
public bool use_color = false; // use color to distinguish generations and inputs
public bool interactive = false; // interactive mode

public bool embedding = false; // get only sentence embedding
public bool interactive_first = false; // wait for user input immediately

public bool instruct = false; // instruction mode (used for Alpaca models)
public bool penalize_nl = true; // consider newlines as a repeatable token
public bool perplexity = false; // compute perplexity over the prompt
public bool use_mmap = true; // use mmap for faster loads
public bool use_mlock = false; // use mlock to keep model in memory
public bool mem_test = false; // compute maximum memory usage
public bool verbose_prompt = false; // print prompt tokens before generation

public GptParams(int seed = 0, int n_threads = -1, int n_predict = -1,
int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0,
Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f,
string model = "models/lamma-7B/ggml-model.bin", string prompt = "",
string path_session = "", string input_prefix = "", string input_suffix = "",
List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
bool embedding = false, bool interactive_first = false, bool instruct = false, bool penalize_nl = true,
bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
bool verbose_prompt = false)
{
this.seed = seed;
if(n_threads != -1)
{
this.n_threads = n_threads;
}
this.n_predict = n_predict;
this.n_parts = n_parts;
this.n_ctx = n_ctx;
this.n_batch = n_batch;
this.n_keep = n_keep;

if (logit_bias == null)
{
logit_bias = new Dictionary<llama_token, float>();
}
this.logit_bias = logit_bias;

this.top_k = top_k;
this.top_p = top_p;
this.tfs_z = tfs_z;
this.typical_p = typical_p;
this.temp = temp;
this.repeat_penalty = repeat_penalty;
this.repeat_last_n = repeat_last_n;
this.frequency_penalty = frequency_penalty;
this.presence_penalty = presence_penalty;
this.mirostat = mirostat;
this.mirostat_tau = mirostat_tau;
this.mirostat_eta = mirostat_eta;

this.model = model;
this.prompt = prompt;
this.path_session = path_session;
this.input_prefix = input_prefix;
this.input_suffix = input_suffix;

if (antiprompt == null)
{
antiprompt = new List<string>();
}
this.antiprompt = antiprompt;

this.lora_adapter = lora_adapter;
this.lora_base = lora_base;

this.memory_f16 = memory_f16;
this.random_prompt = random_prompt;
this.use_color = use_color;
this.interactive = interactive;
this.embedding = embedding;
this.interactive_first = interactive_first;
this.instruct = instruct;
this.penalize_nl = penalize_nl;
this.perplexity = perplexity;
this.use_mmap = use_mmap;
this.use_mlock = use_mlock;
this.mem_test = mem_test;
this.verbose_prompt = verbose_prompt;
}
}
}
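
For reference, a minimal sketch of constructing these parameters for an interactive chat session, using only the fields shown above (the model path is a placeholder, not taken from this commit):

using System.Collections.Generic;
using LLama;

var @params = new GptParams(
    model: "models/llama-7B/ggml-model-q4_0.bin",   // placeholder path
    n_ctx: 512,
    n_predict: 256,
    interactive: true,
    antiprompt: new List<string> { "User:" },
    repeat_penalty: 1.0f);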

+1 -1  LLama/LLamaModel.cs

@@ -608,7 +608,7 @@ namespace LLama
yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
{
new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
});
});
}
}



+547 -0  LLama/Model.cs

@@ -0,0 +1,547 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using LLama.Exceptions;
using System.Linq;
using System.Text.RegularExpressions;
using System.Runtime.InteropServices;
using System.Diagnostics;

namespace LLama
{
using llama_token = Int32;
public class GptModel
{
GptParams _params;
SafeLLamaContextHandle _ctx;
string _path_session;
List<llama_token> _session_tokens;
List<llama_token> _embed_inp;
int _n_ctx;
List<llama_token> _inp_pfx;
List<llama_token> _inp_sfx;
List<llama_token> _llama_token_newline;
List<llama_token> _last_n_tokens;
bool _is_interacting;
bool _is_antiprompt;
bool _input_echo;

// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
// initial prompt so it doesn't need to be an exact match.
bool _need_to_save_session;
int _n_past;
int _n_remain;
int _n_consumed;
int _n_session_consumed;
List<llama_token> _embed;

public GptModel(string model_path = "models/lamma-7B/ggml-model.bin", int seed = 0, int n_threads = -1, int n_predict = -1,
int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0,
Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f, string prompt = "",
string path_session = "", string input_prefix = "", string input_suffix = "",
List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
bool embedding = false, bool interactive_first = false, bool instruct = false, bool penalize_nl = true,
bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
bool verbose_prompt = false)
{

}

public GptModel WithPrompt(string prompt)
{
_params.prompt = prompt;
if(!_params.prompt.EndsWith(" "))
{
_params.prompt.Insert(0, " ");
}
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}
return this;
}

public GptModel WithPromptFile(string promptFileName)
{
return WithPrompt(File.ReadAllText(promptFileName));
}

public unsafe GptModel(GptParams @params)
{
_params = @params;
_ctx = Utils.llama_init_from_gpt_params(ref _params);

// Add a space in front of the first character to match OG llama tokenizer behavior
_params.prompt.Insert(0, " ");
_session_tokens = new List<llama_token>();

_path_session = @params.path_session;
if (!string.IsNullOrEmpty(_path_session))
{
Logger.Default.Info($"Attempting to load saved session from '{_path_session}'");

if (!File.Exists(_path_session))
{
Logger.Default.Warn("Session file does not exist, will create.");
}

llama_token[] session_tokens = new llama_token[@params.n_ctx];
ulong n_token_count_out = 0;
if (!NativeApi.llama_load_session_file(_ctx, _path_session, session_tokens, (ulong)@params.n_ctx, &n_token_count_out))
{
throw new RuntimeError($"Failed to load session file {_path_session}");
}
_session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
Logger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
}

_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
_n_ctx = NativeApi.llama_n_ctx(_ctx);

if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}

ulong n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
{
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
{
Logger.Default.Info("Session file has exact match for prompt!");
}
else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
{
Logger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else
{
Logger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
}
}

// number of tokens to keep when resetting context
if (_params.n_keep < 0 || _params.n_keep > (int)_embed_inp.Count || _params.instruct)
{
_params.n_keep = _embed_inp.Count;
}

// prefix & suffix for instruct mode
_inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true);
_inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false);

// in instruct mode, we inject a prefix and a suffix to each input by the user
if (_params.instruct)
{
_params.interactive_first = true;
_params.antiprompt.Add("### Instruction:\n\n");
}

// enable interactive mode if reverse prompt or interactive start is specified
if (_params.antiprompt.Count != 0 || _params.interactive_first)
{
_params.interactive = true;
}

// determine newline token
_llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false);

if (_params.verbose_prompt)
{
Logger.Default.Info("\n");
Logger.Default.Info($"prompt: '{_params.prompt}'");
Logger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
for (int i = 0; i < _embed_inp.Count; i++)
{
Logger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
}
if (_params.n_keep > 0)
{
Logger.Default.Info($"static prompt based on n_keep: '");
for (int i = 0; i < _params.n_keep; i++)
{
Logger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
}
Logger.Default.Info("\n");
}
Logger.Default.Info("\n");
}

if (_params.interactive)
{
Logger.Default.Info("interactive mode on.");
}
Logger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
$"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
$"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}," +
$" top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}," +
$" mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
Logger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
$"n_keep = {_params.n_keep}");
Logger.Default.Info("\n");

_last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();

if (_params.interactive)
{
Logger.Default.Info("== Running in interactive mode. ==");
_is_interacting = _params.interactive_first;
}

_is_antiprompt = false;
_input_echo = true;
_need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
_n_past = 0;
_n_remain = _params.n_predict;
_n_consumed = 0;
_n_session_consumed = 0;
_embed = new List<llama_token>();
}

private string ProcessTextBeforeInfer(string text)
{
if (!string.IsNullOrEmpty(_params.input_prefix))
{
text = _params.input_prefix + text;
}
if (!text.EndsWith("\n"))
{
text += "\n";
}
if (text.Length > 1)
{
// append input suffix if any
if (!string.IsNullOrEmpty(_params.input_suffix))
{
text += _params.input_suffix;
Console.Write(_params.input_suffix);
}

// instruct mode: insert instruction prefix
if (_params.instruct && !_is_antiprompt)
{
_n_consumed = _embed_inp.Count;
_embed_inp.AddRange(_inp_pfx);
}

var line_inp = Utils.llama_tokenize(_ctx, text, false);
_embed_inp.AddRange(line_inp);

// instruct mode: insert response suffix
if (_params.instruct)
{
_embed_inp.AddRange(_inp_sfx);
}

_n_remain -= line_inp.Count;
}
return text;
}

public IEnumerable<string> Call(string text)
{
_is_interacting = _is_antiprompt = false;
ProcessTextBeforeInfer(text);
while ((_n_remain != 0 || _params.interactive) && !_is_interacting)
{
if (_embed.Count > 0)
{
// infinite text generation via context swapping
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
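// illustrative numbers (an assumption, not from this commit): with n_ctx = 512,
// n_keep = 64 and n_past = 508, n_left = 508 - 64 = 444, so n_past drops back to
// 64 and roughly the last n_left / 2 = 222 tokens of recent context are re-fed
// from _last_n_tokens and re-evaluated below in batches of n_batch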
if (_n_past + _embed.Count > _n_ctx)
{
int n_left = _n_past - _params.n_keep;

_n_past = _params.n_keep;

// insert n_left/2 tokens at the start of embed from last_n_tokens
_embed.InsertRange(0, _last_n_tokens.GetRange(_n_ctx - n_left / 2 - _embed.Count, _embed.Count));

// stop saving session if we run out of context
_path_session = "";

// Console.WriteLine("\n---\n");
// Console.Write("resetting: '");
// for (int i = 0; i < embed.Count; i++) {
// Console.Write(llama_token_to_str(ctx, embed[i]));
// }
// Console.WriteLine("'\n");
// Console.WriteLine("\n---\n");
}

// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (_n_session_consumed < _session_tokens.Count)
{
int i = 0;
for (; i < _embed.Count; i++)
{
if (!_embed[i].Equals(_session_tokens[_n_session_consumed]))
{
_session_tokens.RemoveRange(_n_session_consumed, _session_tokens.Count - _n_session_consumed);
break;
}

_n_past++;
_n_session_consumed++;

if (_n_session_consumed >= _session_tokens.Count)
{
i++;
break;
}
}

if (i > 0)
{
_embed.RemoveRange(0, i);
}
}

// evaluate tokens in batches
// embed is typically prepared beforehand to fit within a batch, but not always
for (int i = 0; i < _embed.Count; i += _params.n_batch)
{
int n_eval = _embed.Count - i;

if (n_eval > _params.n_batch)
{
n_eval = _params.n_batch;
}

var array = _embed.GetRange(i, n_eval).ToArray();
if (NativeApi.llama_eval(_ctx, array, n_eval, _n_past, _params.n_threads) != 0)
{
Logger.Default.Error($"Failed to eval");
throw new RuntimeError("Failed to eval");
}

_n_past += n_eval;
}

if (_embed.Count > 0 && !string.IsNullOrEmpty(_path_session))
{
_session_tokens.AddRange(_embed);
_n_session_consumed = _session_tokens.Count;
}
}

_embed.Clear();

if (_embed_inp.Count <= _n_consumed && !_is_interacting)
{
var temp = _params.temp;
var top_k = _params.top_k <= 0 ? NativeApi.llama_n_vocab(_ctx) : _params.top_k;
var top_p = _params.top_p;
var tfs_z = _params.tfs_z;
var typical_p = _params.typical_p;
var repeat_last_n = _params.repeat_last_n < 0 ? _n_ctx : _params.repeat_last_n;
var repeat_penalty = _params.repeat_penalty;
var alpha_presence = _params.presence_penalty;
var alpha_frequency = _params.frequency_penalty;
var mirostat = _params.mirostat;
var mirostat_tau = _params.mirostat_tau;
var mirostat_eta = _params.mirostat_eta;
var penalize_nl = _params.penalize_nl;

// optionally save the session on first sample (for faster prompt loading next time)
if (!string.IsNullOrEmpty(_path_session) && _need_to_save_session)
{
_need_to_save_session = false;
NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count);
}

llama_token id = 0;

{
var n_vocab = NativeApi.llama_n_vocab(_ctx);
var logits = Utils.llama_get_logits(_ctx, n_vocab);

// Apply params.logit_bias map
foreach (KeyValuePair<int, float> it in _params.logit_bias)
{
logits[it.Key] += it.Value;
}

var candidates = new List<LLamaTokenData>();
candidates.Capacity = n_vocab;
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f));
}

LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false);

// Apply penalties
float nl_logit = logits[NativeApi.llama_token_nl()];
var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
_last_n_tokens.GetRange(_last_n_tokens.Count - last_n_repeat, last_n_repeat).ToArray(),
(ulong)last_n_repeat, repeat_penalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
_last_n_tokens.GetRange(_last_n_tokens.Count - last_n_repeat, last_n_repeat).ToArray(),
(ulong)last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[NativeApi.llama_token_nl()] = nl_logit;
}

if (temp <= 0)
{
// Greedy sampling
id = SamplingApi.llama_sample_token_greedy(_ctx, candidates_p);
}
else
{
if (mirostat == 1)
{
float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
}
else if (mirostat == 2)
{
float mirostat_mu = 2.0f * mirostat_tau;
SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates_p, mirostat_tau, mirostat_eta, mirostat_mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(_ctx, candidates_p, top_k, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates_p, tfs_z, 1);
SamplingApi.llama_sample_typical(_ctx, candidates_p, typical_p, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates_p, top_p, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
id = SamplingApi.llama_sample_token(_ctx, candidates_p);
}
}

_last_n_tokens.RemoveAt(0);
_last_n_tokens.Add(id);
}

// replace end of text token with newline token when in interactive mode
if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct)
{
id = _llama_token_newline[0];
if (_params.antiprompt.Count != 0)
{
// tokenize and inject first reverse prompt
var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false);
_embed_inp.AddRange(first_antiprompt);
}
}

// add it to the context
_embed.Add(id);

// echo this to console
_input_echo = true;

// decrement remaining sampling budget
_n_remain--;
}
else
{
// prompt / user-input tokens have not all been consumed yet: move them into
// the pending batch (up to n_batch) before any sampling happens
while (_embed_inp.Count > _n_consumed)
{
_embed.Add(_embed_inp[_n_consumed]);
_last_n_tokens.RemoveAt(0);
_last_n_tokens.Add(_embed_inp[_n_consumed]);
_n_consumed++;
if (_embed.Count >= _params.n_batch)
{
break;
}
}
}

if (_input_echo)
{
foreach (var id in _embed)
{
#if NET6_0_OR_GREATER
yield return Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
#else
yield return Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, id));
#endif
}
}

if (_params.interactive && _embed_inp.Count <= _n_consumed)
{
if (_params.antiprompt.Count > 0)
{
string last_output = "";
foreach (var id in _last_n_tokens)
{
#if NET6_0_OR_GREATER
last_output += Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
#else
last_output += Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, id));
#endif
}

_is_antiprompt = false;
foreach (var antiprompt in _params.antiprompt)
{
if (last_output.EndsWith(antiprompt))
{
_is_interacting = true;
_is_antiprompt = true;
break;
}
}
}

if(_n_past > 0 && _is_interacting)
{
_input_echo = false;
break;
}

if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos())
{
if (_params.instruct) {
_is_interacting = true;
} else
{
Logger.Default.Info(" [end of text]");
}
}

if (_params.interactive && _n_remain <= 0 && _params.n_predict != -1) {
_n_remain = _params.n_predict;
_is_interacting = true;
}
}
}
}
}
}
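
A minimal usage sketch for this class, mirroring the console example at the top of this commit (the model path and prompt file name are placeholders, not part of this commit):

using System;
using System.Collections.Generic;
using LLama;

var model = new GptModel(new GptParams(
        model: "models/llama-7B/ggml-model-q4_0.bin",   // placeholder path
        n_ctx: 512,
        interactive: true,
        antiprompt: new List<string> { "User:" }))
    .WithPromptFile("chat-with-bob.txt");               // hypothetical prompt file

foreach (var piece in model.Call("Please tell me the largest city in Europe.\n"))
{
    Console.Write(piece);   // pieces stream back until an antiprompt or the prediction budget is hit
}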

+3 -3  LLama/Native/NativeApi.Sampling.cs

@@ -6,7 +6,7 @@ using System.Text;
namespace LLama.Native
{
using llama_token = Int32;
internal partial class NativeApi
internal unsafe partial class NativeApi
{
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -93,7 +93,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[DllImport(libraryName)]
public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float[] mu);
public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float* mu);

/// <summary>
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -105,7 +105,7 @@ namespace LLama.Native
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[DllImport(libraryName)]
public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float[] mu);
public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float* mu);

/// <summary>
/// Selects the token with the highest probability.


+1 -1  LLama/Native/NativeApi.cs

@@ -120,7 +120,7 @@ namespace LLama.Native
/// <param name="n_token_count_out"></param>
/// <returns></returns>
[DllImport(libraryName)]
public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong[] n_token_count_out);
public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);

/// <summary>
/// Save session file


+14 -4  LLama/Native/SamplingApi.cs

@@ -148,14 +148,19 @@ namespace LLama.Native
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, float[] mu)
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, in float mu)
{
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
return NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, mu);
llama_token res;
fixed(float* pmu = &mu)
{
res = NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, pmu);
}
return res;
}

/// <summary>
@@ -167,14 +172,19 @@ namespace LLama.Native
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, float[] mu)
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, in float mu)
{
var handle = candidates.data.Pin();
var st = new LLamaTokenDataArrayNative();
st.data = new IntPtr(handle.Pointer);
st.size = candidates.size;
st.sorted = candidates.sorted;
return NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, mu);
llama_token res;
fixed (float* pmu = &mu)
{
res = NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, pmu);
}
return res;
}

/// <summary>
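
For reference, a hedged sketch of driving the updated mirostat wrapper (the helper class below is an assumption for illustration, not part of this commit): the caller initializes mu to 2 * tau, as the parameter docs describe, and keeps passing the same variable so the value the native sampler writes through the pinned pointer carries over between tokens.

using LLama.Native;

internal static class MirostatSamplingExample
{
    public static int NextToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates,
        float tau, float eta, ref float mu)
    {
        const int m = 100;   // same m as GptModel.Call uses above
        // mu is updated in native code via the pointer pinned inside the wrapper
        return SamplingApi.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu);
    }
}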


+62 -0  LLama/Utils.cs

@@ -0,0 +1,62 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using System.Diagnostics;
using System.Linq;

namespace LLama
{
using llama_token = Int32;
internal static class Utils
{
public static SafeLLamaContextHandle llama_init_from_gpt_params(ref GptParams @params)
{
var lparams = NativeApi.llama_context_default_params();

lparams.n_ctx = @params.n_ctx;
lparams.n_parts = @params.n_parts;
lparams.seed = @params.seed;
lparams.f16_kv = @params.memory_f16;
lparams.use_mmap = @params.use_mmap;
lparams.use_mlock = @params.use_mlock;
lparams.logits_all = @params.perplexity;
lparams.embedding = @params.embedding;

var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);

if(ctx_ptr == IntPtr.Zero )
{
throw new RuntimeError($"Failed to load model {@params.model}.");
}

SafeLLamaContextHandle ctx = new(ctx_ptr);

if (!string.IsNullOrEmpty(@params.lora_adapter))
{
int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
if(err != 0)
{
throw new RuntimeError("Failed to apply lora adapter.");
}
}
return ctx;
}

public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos)
{
llama_token[] res = new llama_token[text.Length + (add_bos ? 1 : 0)];
int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos);
Debug.Assert(n >= 0);
return res.Take(n).ToList();
}

public unsafe static Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
{
var logits = NativeApi.llama_get_logits(ctx);
return new Span<float>(logits, length);
}
}
}
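
These helpers are internal to the LLama assembly; a hedged sketch of how they compose (the wrapper below is hypothetical and would live inside the LLama project):

using System;
using LLama.Exceptions;
using LLama.Native;

namespace LLama
{
    internal static class UtilsExample
    {
        public static Span<float> EvalPrompt(SafeLLamaContextHandle ctx, string prompt, int n_threads)
        {
            // tokenize with a BOS token, then evaluate the whole prompt from position 0
            var tokens = Utils.llama_tokenize(ctx, prompt, true);
            if (NativeApi.llama_eval(ctx, tokens.ToArray(), tokens.Count, 0, n_threads) != 0)
            {
                throw new RuntimeError("Failed to eval");
            }
            // one logit per vocabulary id, for the last evaluated token
            return Utils.llama_get_logits(ctx, NativeApi.llama_n_vocab(ctx));
        }
    }
}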
