
feat: add ChatSession.

tags/v0.2.1
Yaohui Liu 2 years ago
commit fce10f3c4f
8 changed files with 224 additions and 191 deletions

  1. LLama/ChatSession.cs   +46  -0
  2. LLama/GptModel.cs      +81  -46
  3. LLama/GptParams.cs     +3   -5
  4. LLama/IChatModel.cs    +18  -0
  5. LLama/LLamaCache.cs    +1   -1
  6. LLama/LLamaModel.cs    +42  -137
  7. LLama/LLamaTypes.cs    +8   -1
  8. LLama/Utils.cs         +25  -1

LLama/ChatSession.cs  +46 -0

@@ -0,0 +1,46 @@
using LLama.Types;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace LLama
{
public class ChatSession<T> where T: IChatModel
{
IChatModel _model;
List<ChatMessageRecord> History { get; } = new List<ChatMessageRecord>();
public ChatSession(T model)
{
_model = model;
}

public IEnumerable<string> Chat(string text, string? prompt = null)
{
return _model.Chat(text, prompt);
}

public ChatSession<T> WithPrompt(string prompt)
{
_model.InitChatPrompt(prompt);
return this;
}

public ChatSession<T> WithPromptFile(string promptFilename)
{
return WithPrompt(File.ReadAllText(promptFilename));
}

/// <summary>
/// Set the antiprompt keywords used to detect where the model's reply should stop
/// (typically the human name prefix, e.g. "User:").
/// </summary>
/// <param name="antiprompt"></param>
/// <returns></returns>
public ChatSession<T> WithAntiprompt(string[] antiprompt)
{
_model.InitChatAntiprompt(antiprompt);
return this;
}
}
}
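
For orientation, here is a minimal usage sketch of the new ChatSession wrapper added in this commit. It is not part of the diff: the model path, prompt file, and antiprompt text are placeholders, and the LLamaParams/LLamaModel constructors are the ones introduced below in GptModel.cs.

using System;
using LLama;

// Hypothetical usage sketch; the model path, prompt file and "User:" antiprompt are placeholders.
var model = new LLamaModel(
    new LLamaParams(model_path: "models/llama-7B/ggml-model.bin", interactive: true),
    echo_input: false, verbose: false);

var session = new ChatSession<LLamaModel>(model)
    .WithPromptFile("prompts/chat-with-bob.txt")   // seed prompt read from disk
    .WithAntiprompt(new[] { "User:" });            // stop streaming when the user's turn begins

foreach (var piece in session.Chat("Hello, Bob."))
{
    Console.Write(piece);
}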

LLama/Model.cs → LLama/GptModel.cs  +81 -46

@@ -12,9 +12,9 @@ using System.Diagnostics;
namespace LLama
{
using llama_token = Int32;
public class GptModel
public class LLamaModel: IChatModel
{
GptParams _params;
LLamaParams _params;
SafeLLamaContextHandle _ctx;
string _path_session;
List<llama_token> _session_tokens;
@@ -38,7 +38,12 @@ namespace LLama
int _n_session_consumed;
List<llama_token> _embed;

public GptModel(string model_path = "models/lamma-7B/ggml-model.bin", int seed = 0, int n_threads = -1, int n_predict = -1,
// params related to chat API only
bool _first_time_chat = true;

public string Name { get; set; }

public LLamaModel(string model_path, string model_name, bool echo_input = false, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1,
int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0,
Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
@@ -49,33 +54,18 @@ namespace LLama
bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
bool embedding = false, bool interactive_first = false, bool instruct = false, bool penalize_nl = true,
bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
bool verbose_prompt = false)
bool verbose_prompt = false) : this(new LLamaParams(seed, n_threads, n_predict, n_parts, n_ctx, n_batch,
n_keep, logit_bias, top_k, top_p, tfs_z, typical_p, temp, repeat_penalty, repeat_last_n, frequency_penalty,
presence_penalty, mirostat, mirostat_tau, mirostat_eta, model_path, prompt, path_session, input_prefix,
input_suffix, antiprompt, lora_adapter, lora_base, memory_f16, random_prompt, use_color, interactive, embedding,
interactive_first, instruct, penalize_nl, perplexity, use_mmap, use_mlock, mem_test, verbose_prompt), model_name, echo_input, verbose)
{

}

public GptModel WithPrompt(string prompt)
{
_params.prompt = prompt;
if(!_params.prompt.EndsWith(" "))
{
_params.prompt.Insert(0, " ");
}
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}
return this;
}

public GptModel WithPromptFile(string promptFileName)
{
return WithPrompt(File.ReadAllText(promptFileName));
}

public unsafe GptModel(GptParams @params)
public unsafe LLamaModel(LLamaParams @params, string name = "", bool echo_input = false, bool verbose = false)
{
Name = name;
_params = @params;
_ctx = Utils.llama_init_from_gpt_params(ref _params);

@@ -86,7 +76,10 @@ namespace LLama
_path_session = @params.path_session;
if (!string.IsNullOrEmpty(_path_session))
{
Logger.Default.Info($"Attempting to load saved session from '{_path_session}'");
if (verbose)
{
Logger.Default.Info($"Attempting to load saved session from '{_path_session}'");
}

if (!File.Exists(_path_session))
{
@@ -100,7 +93,10 @@ namespace LLama
throw new RuntimeError($"Failed to load session file {_path_session}");
}
_session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
Logger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
if (verbose)
{
Logger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
}
}

_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
@@ -122,7 +118,7 @@ namespace LLama
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
if (n_matching_session_tokens >= (ulong)_embed_inp.Count && verbose)
{
Logger.Default.Info("Session file has exact match for prompt!");
}
@@ -131,7 +127,7 @@ namespace LLama
Logger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else
else if(verbose)
{
Logger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
@@ -185,29 +181,35 @@ namespace LLama
Logger.Default.Info("\n");
}

if (_params.interactive)
if (_params.interactive && verbose)
{
Logger.Default.Info("interactive mode on.");
}
Logger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
if (verbose)
{
Logger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
$"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
$"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}," +
$" top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}," +
$" mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
Logger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
$"n_keep = {_params.n_keep}");
Logger.Default.Info("\n");
Logger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
$"n_keep = {_params.n_keep}");
Logger.Default.Info("\n");
}

_last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();

if (_params.interactive)
{
Logger.Default.Info("== Running in interactive mode. ==");
if (verbose)
{
Logger.Default.Info("== Running in interactive mode. ==");
}
_is_interacting = _params.interactive_first;
}

_is_antiprompt = false;
_input_echo = true;
_input_echo = echo_input;
_need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
_n_past = 0;
_n_remain = _params.n_predict;
@@ -216,6 +218,26 @@ namespace LLama
_embed = new List<llama_token>();
}

public LLamaModel WithPrompt(string prompt)
{
_params.prompt = prompt;
if (!_params.prompt.EndsWith(" "))
{
_params.prompt.Insert(0, " ");
}
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}
return this;
}

public LLamaModel WithPromptFile(string promptFileName)
{
return WithPrompt(File.ReadAllText(promptFileName));
}

private string ProcessTextBeforeInfer(string text)
{
if (!string.IsNullOrEmpty(_params.input_prefix))
@@ -256,6 +278,27 @@ namespace LLama
return text;
}

public void InitChatPrompt(string prompt)
{
WithPrompt(prompt);
}

public void InitChatAntiprompt(string[] antiprompt)
{
_params.antiprompt = antiprompt.ToList();
}

public IEnumerable<string> Chat(string text, string? prompt = null)
{
_params.interactive = true;
_input_echo = false;
if (!string.IsNullOrEmpty(prompt))
{
WithPrompt(prompt);
}
return Call(text);
}

public IEnumerable<string> Call(string text)
{
_is_interacting = _is_antiprompt = false;
@@ -486,11 +529,7 @@ namespace LLama
{
foreach (var id in _embed)
{
#if NET6_0_OR_GREATER
yield return Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
#else
yield return Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, id));
#endif
yield return Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
}
}

@@ -501,11 +540,7 @@ namespace LLama
string last_output = "";
foreach (var id in _last_n_tokens)
{
#if NET6_0_OR_GREATER
last_output += Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
#else
last_output += Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, id));
#endif
last_output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
}

_is_antiprompt = false;
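
LLamaModel itself now implements IChatModel, so it can also be driven directly without the ChatSession wrapper. A sketch under the same placeholder assumptions (the second constructor argument is the model's display name exposed through Name):

// Hypothetical usage sketch; the path, name and prompt text are placeholders.
var bob = new LLamaModel("models/llama-7B/ggml-model.bin", "Bob", echo_input: false, verbose: false);
bob.InitChatAntiprompt(new[] { "User:" });   // split the streamed reply at the user's turn
foreach (var piece in bob.Chat("What is an antiprompt?", prompt: "A chat between a curious user and Bob."))
{
    Console.Write(piece);
}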

LLama/GptParams.cs  +3 -5

@@ -1,12 +1,10 @@
using System;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Text;

namespace LLama
{
using llama_token = Int32;
public struct GptParams
public struct LLamaParams
{
public int seed; // RNG seed
public int n_threads = Math.Max(Environment.ProcessorCount / 2, 1); // number of threads (-1 = autodetect)
@@ -57,7 +55,7 @@ namespace LLama
public bool mem_test = false; // compute maximum memory usage
public bool verbose_prompt = false; // print prompt tokens before generation

public GptParams(int seed = 0, int n_threads = -1, int n_predict = -1,
public LLamaParams(int seed = 0, int n_threads = -1, int n_predict = -1,
int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0,
Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
@@ -72,7 +70,7 @@ namespace LLama
bool verbose_prompt = false)
{
this.seed = seed;
if(n_threads != -1)
if (n_threads != -1)
{
this.n_threads = n_threads;
}


LLama/IChatModel.cs  +18 -0

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama
{
public interface IChatModel
{
string Name { get; }
IEnumerable<string> Chat(string text, string? prompt = null);
/// <summary>
/// Initialize the chat prompt; subsequent prompts are produced automatically as the chat proceeds.
/// </summary>
/// <param name="prompt"></param>
void InitChatPrompt(string prompt);
void InitChatAntiprompt(string[] antiprompt);
}
}
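
Any type satisfying this contract can be plugged into ChatSession&lt;T&gt;. A toy sketch of an alternative implementation (EchoChatModel is purely illustrative and not part of the commit):

using System.Collections.Generic;

namespace LLama
{
    // Illustrative only: a trivial IChatModel that ignores the prompt machinery
    // and streams the user's text back one word at a time.
    public class EchoChatModel : IChatModel
    {
        public string Name => "echo";

        public void InitChatPrompt(string prompt) { /* no-op in this sketch */ }

        public void InitChatAntiprompt(string[] antiprompt) { /* no-op in this sketch */ }

        public IEnumerable<string> Chat(string text, string? prompt = null)
        {
            foreach (var word in text.Split(' '))
            {
                yield return word + " ";
            }
        }
    }
}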

LLama/LLamaCache.cs  +1 -1

@@ -70,7 +70,7 @@ namespace LLama
{
int minLen = 0;
llama_token[]? minKey = null;
var keys = _cacheState.Keys.Select(k => (k, LLamaModel.LongestTokenPrefix(k, key)));
var keys = _cacheState.Keys.Select(k => (k, LLamaModelV1.LongestTokenPrefix(k, key)));
foreach(var (k, prefixLen) in keys)
{
if(prefixLen > minLen)


LLama/LLamaModel.cs  +42 -137

@@ -11,6 +11,7 @@ using System.Text;
using LLama.Types;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Collections;

namespace LLama
{
@@ -18,7 +19,8 @@ namespace LLama
/// <summary>
/// High-level Wrapper of a llama.cpp model for inference.
/// </summary>
public class LLamaModel
[Obsolete]
public class LLamaModelV1
{
private string _model_path;
LLamaContextParams _params;
@@ -59,7 +61,7 @@ namespace LLama
/// <param name="lora_base">Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
/// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
/// <param name="verbose">Print verbose output to stderr.</param>
public LLamaModel(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
public LLamaModelV1(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
@@ -115,13 +117,11 @@ namespace LLama

if (_verbose)
{
#if NET6_0_OR_GREATER
Logger.Default.Info(Marshal.PtrToStringUTF8(NativeApi.llama_print_system_info()));
#endif
Logger.Default.Info(Utils.PtrToStringUTF8(NativeApi.llama_print_system_info()));
}
}

public LLamaModel(LLamaModel other)
public LLamaModelV1(LLamaModelV1 other)
{
_ctx = other._ctx;
_model_path = other._model_path;
@@ -166,15 +166,17 @@ namespace LLama
string output = "";
foreach(var token in tokens)
{
#if NET6_0_OR_GREATER
output += Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
#else
output += Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, token));
#endif
output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
}
return output;
}

public string DeTokenize(llama_token token)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
return Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)) ?? "";
}

/// <summary>
/// Set the cache.
/// </summary>
@@ -436,12 +438,11 @@ namespace LLama
/// <param name="presence_penalty"></param>
/// <param name="repeat_penalty"></param>
/// <param name="top_k"></param>
/// <param name="stream"></param>
/// <returns>IEnumerable of Completion and CompletionChunk</returns>
/// <exception cref="ArgumentException"></exception>
private IEnumerable<object> CreateCompletionInternal(string prompt, string?suffix = null, int max_tokens = 16, float temperature = 0.8f,
private IEnumerable<CompletionChunk> CreateCompletionInternal(string prompt, string?suffix = null, int max_tokens = 16, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40, bool stream = false)
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
string completionId = $"cmpl-{Guid.NewGuid()}";
@@ -539,7 +540,8 @@ namespace LLama
Reset();
}
//foreach (var token in Generate(promptTokens, top_k, top_p, temperature, frequency_penalty, presence_penalty, repeat_penalty))
while(true)
string allText = "";
while (true)
{
Eval(tokens);
var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
@@ -554,7 +556,7 @@ namespace LLama

completionTokens.Add(token);

string allText = DeTokenize(completionTokens);
allText = DeTokenize(completionTokens);

int cut = Math.Min(3, allText.Length);
for(int i = allText.Length - cut; i < allText.Length; i++)
@@ -585,34 +587,31 @@ namespace LLama
break;
}

if (stream)
var start = returnedCharacters;
int longest = 0;
foreach (var s in stop)
{
var start = returnedCharacters;
int longest = 0;
foreach(var s in stop)
for (int i = s.Length; i > 0; i--)
{
for(int i = s.Length; i > 0; i--)
if (allText.EndsWith(s.Substring(0, i)))
{
if(allText.EndsWith(s.Substring(0, i)))
if (i > longest)
{
if(i > longest)
{
longest = i;
}
break;
longest = i;
}
break;
}
}
text = allText.Substring(0, allText.Length - longest);
returnedCharacters += text.Skip(start).Count();
yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
{
new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
});
}
text = allText.Substring(0, allText.Length - longest);
returnedCharacters += text.Skip(start).Count();
yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
{
new CompletionChoice(text.Substring(start), 0, null, finishReason)
});
}

if(_cache is not null)
if (_cache is not null)
{
if (_verbose)
{
@@ -621,14 +620,6 @@ namespace LLama
_cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
}

if (stream)
{
yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
{
new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
});
}

string textStr = text;
if (echo)
{
@@ -673,11 +664,6 @@ namespace LLama
{
NativeApi.llama_print_timings(_ctx);
}

yield return new Completion(completionId, "text_completion", created, _model_path, new CompletionChoice[]
{
new CompletionChoice(text, 0, logProbs, finishReason)
}, new CompletionUsage(promptTokens.Count, completionTokens.Count, promptTokens.Count + completionTokens.Count));
}

/// <summary>
@@ -696,60 +682,11 @@ namespace LLama
/// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
/// <param name="top_k">The top-k value to use for sampling.</param>
/// <returns></returns>
public IEnumerable<CompletionChunk> CreateCompletionStream(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
yield return (CompletionChunk)CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
frequency_penalty, presence_penalty, repeat_penalty, top_k, true);
}

/// <summary>
/// Generate text from a prompt.
/// </summary>
/// <param name="prompt">The prompt to generate text from.</param>
/// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
/// <param name="max_tokens">The maximum number of tokens to generate.</param>
/// <param name="temperature">The temperature to use for sampling.</param>
/// <param name="top_p">The top-p value to use for sampling.</param>
/// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
/// <param name="echo">Whether to echo the prompt.</param>
/// <param name="stop">A list of strings to stop generation when encountered.</param>
/// <param name="frequency_penalty"></param>
/// <param name="presence_penalty"></param>
/// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
/// <param name="top_k">The top-k value to use for sampling.</param>
/// <returns></returns>
public Completion CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
var completion = CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
frequency_penalty, presence_penalty, repeat_penalty, top_k, false).First();
return (Completion)completion;
}

/// <summary>
/// Generate text from a prompt.
/// </summary>
/// <param name="prompt">The prompt to generate text from.</param>
/// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
/// <param name="max_tokens">The maximum number of tokens to generate.</param>
/// <param name="temperature">The temperature to use for sampling.</param>
/// <param name="top_p">The top-p value to use for sampling.</param>
/// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
/// <param name="echo">Whether to echo the prompt.</param>
/// <param name="stop">A list of strings to stop generation when encountered.</param>
/// <param name="frequency_penalty"></param>
/// <param name="presence_penalty"></param>
/// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
/// <param name="top_k">The top-k value to use for sampling.</param>
/// <returns></returns>
public Completion Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
public IEnumerable<CompletionChunk> CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
return CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
frequency_penalty, presence_penalty, repeat_penalty, top_k);
}

@@ -769,18 +706,18 @@ namespace LLama
/// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
/// <param name="top_k">The top-k value to use for sampling.</param>
/// <returns></returns>
public IEnumerable<CompletionChunk> StreamCall(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
public IEnumerable<CompletionChunk> Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
{
return CreateCompletionStream(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
frequency_penalty, presence_penalty, repeat_penalty, top_k);
}

private ChatCompletion ConvertTextCompletionToChat(Completion completion)
{
return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
new[] { new ChatCompletionChoice(0, new ChatCompletionMessage("assistant", completion.Choices[0].Text, null),
new[] { new ChatCompletionChoice(0, new ChatCompletionMessage(ChatRole.Assistant, completion.Choices[0].Text),
completion.Choices[0].FinishReason) }, completion.Usage);
}

@@ -801,39 +738,6 @@ namespace LLama
}
}

/// <summary>
/// Generate a chat completion from a list of messages.
/// </summary>
/// <param name="messages">A list of messages to generate a response for.</param>
/// <param name="temperature">The temperature to use for sampling.</param>
/// <param name="top_p">The top-p value to use for sampling.</param>
/// <param name="top_k">The top-k value to use for sampling.</param>
/// <param name="stop">A list of strings to stop generation when encountered.</param>
/// <param name="max_tokens">The maximum number of tokens to generate.</param>
/// <param name="presence_penalty"></param>
/// <param name="frequency_penalty"></param>
/// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
/// <returns></returns>
public ChatCompletion CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
float repeat_penalty = 1.1f)
{
if(stop is null)
{
stop = new string[0];
}
string GetRole(ChatCompletionMessage message)
{
return message.Role == "user" ? "Human" : "Assistant";
}
string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
var prompt = chatHistory + "### Assistant:";
var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
return ConvertTextCompletionToChat(completion);
}

/// <summary>
/// Generate a chat completion from a list of messages and yield return the result.
/// </summary>
@@ -847,7 +751,7 @@ namespace LLama
/// <param name="frequency_penalty"></param>
/// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
/// <returns></returns>
public IEnumerable<ChatCompletionChunk> CreateChatCompletionStream(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
public IEnumerable<ChatCompletionChunk> CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
float repeat_penalty = 1.1f)
{
@@ -857,12 +761,13 @@ namespace LLama
}
string GetRole(ChatCompletionMessage message)
{
return message.Role == "user" ? "Human" : "Assistant";
return message.Role == ChatRole.Human ? "Human" : "Assistant";
}
string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
var prompt = chatHistory + "### Assistant:";
prompt = prompt.Substring(Math.Max(0, prompt.Length - max_tokens));
var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
var completion = StreamCall(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
return ConvertTextCompletionChunksToChat(completion);
}


LLama/LLamaTypes.cs  +8 -1

@@ -4,6 +4,11 @@ using System.Text;

namespace LLama.Types
{
public enum ChatRole
{
Human,
Assistant
}
public record EmbeddingUsage(int PromptTokens, int TotalTokens);

public record EmbeddingData(int Index, string Object, float[] Embedding);
@@ -20,7 +25,7 @@ namespace LLama.Types

public record Completion(string Id, string Object, int Created, string Model, CompletionChoice[] Choices, CompletionUsage Usage);

public record ChatCompletionMessage(string Role, string Content, string? User);
public record ChatCompletionMessage(ChatRole Role, string Content, string? Name = null);

public record ChatCompletionChoice(int Index, ChatCompletionMessage Message, string? FinishReason);

@@ -31,4 +36,6 @@ namespace LLama.Types
public record ChatCompletionChunkChoice(int Index, ChatCompletionChunkDelta Delta, string? FinishReason);

public record ChatCompletionChunk(string Id, string Model, string Object, int Created, ChatCompletionChunkChoice[] Choices);

public record ChatMessageRecord(ChatCompletionMessage Message, DateTime Time);
}
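
With this change a chat message carries a ChatRole value instead of a raw role string, and the optional third field is a name rather than a user id. A sketch of building a history and streaming a reply through LLamaModelV1 (the model variable stands in for a previously constructed LLamaModelV1):

using System.Collections.Generic;
using LLama.Types;

// Hypothetical usage sketch; `model` is a previously constructed LLamaModelV1.
var messages = new List<ChatCompletionMessage>
{
    new ChatCompletionMessage(ChatRole.Human, "Hello!"),
    new ChatCompletionMessage(ChatRole.Assistant, "Hi, how can I help?"),
    new ChatCompletionMessage(ChatRole.Human, "Explain antiprompts in one sentence.")
};

// CreateChatCompletion now streams ChatCompletionChunk items instead of returning a single ChatCompletion.
foreach (ChatCompletionChunk chunk in model.CreateChatCompletion(messages, max_tokens: 64))
{
    // each chunk wraps one incremental ChatCompletionChunkChoice
}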

LLama/Utils.cs  +25 -1

@@ -5,13 +5,14 @@ using System.Text;
using LLama.Exceptions;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;

namespace LLama
{
using llama_token = Int32;
internal static class Utils
{
public static SafeLLamaContextHandle llama_init_from_gpt_params(ref GptParams @params)
public static SafeLLamaContextHandle llama_init_from_gpt_params(ref LLamaParams @params)
{
var lparams = NativeApi.llama_context_default_params();

@@ -58,5 +59,28 @@ namespace LLama
var logits = NativeApi.llama_get_logits(ctx);
return new Span<float>(logits, length);
}

public static unsafe string PtrToStringUTF8(IntPtr ptr)
{
#if NET6_0_OR_GREATER
return Marshal.PtrToStringUTF8(ptr);
#else
byte* tp = (byte*)ptr.ToPointer();
List<byte> bytes = new();
while (true)
{
byte c = *tp++;
if(c == '\0')
{
break;
}
else
{
bytes.Add(c);
}
}
return Encoding.UTF8.GetString(bytes.ToArray());
#endif
}
}
}
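
The new PtrToStringUTF8 helper replaces the scattered NET6_0_OR_GREATER / PtrToStringAnsi blocks above: on older target frameworks Marshal.PtrToStringUTF8 is not available and PtrToStringAnsi would mangle multi-byte characters, so the fallback walks the native buffer to the terminating NUL and decodes it as UTF-8. A sketch of the call pattern it supports inside the assembly (Utils is internal), mirroring its use in LLamaModelV1:

// inside the LLama assembly; llama_print_system_info returns a NUL-terminated UTF-8 string
IntPtr info = NativeApi.llama_print_system_info();
Logger.Default.Info(Utils.PtrToStringUTF8(info));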
