@@ -11,6 +11,7 @@ using System.Text;
 using LLama.Types;
 using System.Runtime.InteropServices;
 using System.Text.RegularExpressions;
+using System.Collections;
 
 namespace LLama
 {
@@ -18,7 +19,8 @@ namespace LLama
     /// <summary>
     /// High-level Wrapper of a llama.cpp model for inference.
     /// </summary>
-    public class LLamaModel
+    [Obsolete]
+    public class LLamaModelV1
     {
         private string _model_path;
         LLamaContextParams _params;
@@ -59,7 +61,7 @@ namespace LLama
        /// <param name="lora_base">Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
        /// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
        /// <param name="verbose">Print verbose output to stderr.</param>
-        public LLamaModel(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
+        public LLamaModelV1(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
            bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
            bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
            int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
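A hedged usage sketch for the renamed constructor; the model path and parameter values below are placeholders, not part of this change:

```csharp
// Hypothetical call site after the rename; path and settings are illustrative.
var model = new LLamaModelV1(
    model_path: "./models/ggml-model-q4_0.bin", // placeholder path
    n_ctx: 2048,
    seed: 42,
    verbose: true);
```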
@@ -115,13 +117,11 @@ namespace LLama

            if (_verbose)
            {
-#if NET6_0_OR_GREATER
-                Logger.Default.Info(Marshal.PtrToStringUTF8(NativeApi.llama_print_system_info()));
-#endif
+                Logger.Default.Info(Utils.PtrToStringUTF8(NativeApi.llama_print_system_info()));
            }
        }

-        public LLamaModel(LLamaModel other)
+        public LLamaModelV1(LLamaModelV1 other)
        {
            _ctx = other._ctx;
            _model_path = other._model_path;
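The call sites above now route through `Utils.PtrToStringUTF8` instead of inlining the `#if NET6_0_OR_GREATER` split. A minimal sketch of such a shim, assuming it simply centralizes the fallback that the removed lines spelled out (the real `Utils` is not shown in this diff):

```csharp
using System;
using System.Runtime.InteropServices;

internal static class Utils
{
    // Assumed helper: UTF-8 decode on .NET 6+, ANSI fallback elsewhere,
    // mirroring the removed #if/#else branches.
    public static string? PtrToStringUTF8(IntPtr ptr)
    {
#if NET6_0_OR_GREATER
        return Marshal.PtrToStringUTF8(ptr);
#else
        return Marshal.PtrToStringAnsi(ptr);
#endif
    }
}
```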
@@ -166,15 +166,17 @@ namespace LLama
            string output = "";
            foreach(var token in tokens)
            {
-#if NET6_0_OR_GREATER
-                output += Marshal.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
-#else
-                output += Marshal.PtrToStringAnsi(NativeApi.llama_token_to_str(_ctx, token));
-#endif
+                output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
            }
            return output;
        }

+        public string DeTokenize(llama_token token)
+        {
+            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
+            return Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)) ?? "";
+        }
+
        /// <summary>
        /// Set the cache.
        /// </summary>
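The new single-token `DeTokenize(llama_token)` overload avoids building a list for one token. A hedged round-trip sketch (assuming a `Tokenize` method exists on this class, as in the Python-style API it mirrors):

```csharp
// Hypothetical round trip: encode a prompt, then rebuild it token by token.
var tokens = model.Tokenize("Hello, llama!");
var sb = new StringBuilder();
foreach (var t in tokens)
{
    sb.Append(model.DeTokenize(t)); // new per-token overload
}
Console.WriteLine(sb.ToString()); // prints (approximately) the original prompt
```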
@@ -436,12 +438,11 @@ namespace LLama
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty"></param>
        /// <param name="top_k"></param>
-        /// <param name="stream"></param>
        /// <returns>IEnumerable of Completion and CompletionChunk</returns>
        /// <exception cref="ArgumentException"></exception>
-        private IEnumerable<object> CreateCompletionInternal(string prompt, string? suffix = null, int max_tokens = 16, float temperature = 0.8f,
+        private IEnumerable<CompletionChunk> CreateCompletionInternal(string prompt, string? suffix = null, int max_tokens = 16, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
-            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40, bool stream = false)
+            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            string completionId = $"cmpl-{Guid.NewGuid()}";
@@ -539,7 +540,8 @@ namespace LLama
                Reset();
            }
            //foreach (var token in Generate(promptTokens, top_k, top_p, temperature, frequency_penalty, presence_penalty, repeat_penalty))
-            while(true)
+            string allText = "";
+            while (true)
            {
                Eval(tokens);
                var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
@@ -554,7 +556,7 @@ namespace LLama

                completionTokens.Add(token);

-                string allText = DeTokenize(completionTokens);
+                allText = DeTokenize(completionTokens);

                int cut = Math.Min(3, allText.Length);
                for(int i = allText.Length - cut; i < allText.Length; i++)
@@ -585,34 +587,31 @@ namespace LLama
                    break;
                }

-                if (stream)
-                {
-                    var start = returnedCharacters;
-                    int longest = 0;
-                    foreach(var s in stop)
-                    {
-                        for(int i = s.Length; i > 0; i--)
-                        {
-                            if(allText.EndsWith(s.Substring(0, i)))
-                            {
-                                if(i > longest)
-                                {
-                                    longest = i;
-                                }
-                                break;
-                            }
-                        }
-                    }
-                    text = allText.Substring(0, allText.Length - longest);
-                    returnedCharacters += text.Skip(start).Count();
-                    yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
-                    {
-                        new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
-                    });
-                }
+                var start = returnedCharacters;
+                int longest = 0;
+                foreach (var s in stop)
+                {
+                    for (int i = s.Length; i > 0; i--)
+                    {
+                        if (allText.EndsWith(s.Substring(0, i)))
+                        {
+                            if (i > longest)
+                            {
+                                longest = i;
+                            }
+                            break;
+                        }
+                    }
+                }
+                text = allText.Substring(0, allText.Length - longest);
+                returnedCharacters += text.Skip(start).Count();
+                yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
+                {
+                    new CompletionChoice(text.Substring(start), 0, null, finishReason)
+                });
            }

-            if(_cache is not null)
+            if (_cache is not null)
            {
                if (_verbose)
                {
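Streaming is now unconditional in the hunk above: every iteration withholds any tail of `allText` that could be the start of a stop sequence, so a partially generated stop string is never emitted to the caller. The same rule as a standalone sketch (hypothetical helper name, same algorithm as the loop above):

```csharp
// Hypothetical helper mirroring the loop above: trim the longest suffix of
// allText that is a prefix of any stop string.
static string TrimPossibleStopPrefix(string allText, string[] stop)
{
    int longest = 0;
    foreach (var s in stop)
    {
        for (int i = s.Length; i > 0; i--)
        {
            if (allText.EndsWith(s.Substring(0, i)))
            {
                if (i > longest) longest = i;
                break;
            }
        }
    }
    return allText.Substring(0, allText.Length - longest);
}

// TrimPossibleStopPrefix("Hello ##", new[] { "### Human:" }) => "Hello "
```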
@@ -621,14 +620,6 @@ namespace LLama
                _cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
            }

-            if (stream)
-            {
-                yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
-                {
-                    new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
-                });
-            }
-
            string textStr = text;
            if (echo)
            {
@@ -673,11 +664,6 @@ namespace LLama
            {
                NativeApi.llama_print_timings(_ctx);
            }
-
-            yield return new Completion(completionId, "text_completion", created, _model_path, new CompletionChoice[]
-            {
-                new CompletionChoice(text, 0, logProbs, finishReason)
-            }, new CompletionUsage(promptTokens.Count, completionTokens.Count, promptTokens.Count + completionTokens.Count));
        }

        /// <summary>
@@ -696,60 +682,11 @@ namespace LLama
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
-        public IEnumerable<CompletionChunk> CreateCompletionStream(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
-            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
-            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
-        {
-            yield return (CompletionChunk)CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
-                frequency_penalty, presence_penalty, repeat_penalty, top_k, true);
-        }
-
-        /// <summary>
-        /// Generate text from a prompt.
-        /// </summary>
-        /// <param name="prompt">The prompt to generate text from.</param>
-        /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
-        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
-        /// <param name="temperature">The temperature to use for sampling.</param>
-        /// <param name="top_p">The top-p value to use for sampling.</param>
-        /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
-        /// <param name="echo">Whether to echo the prompt.</param>
-        /// <param name="stop">A list of strings to stop generation when encountered.</param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
-        /// <param name="top_k">The top-k value to use for sampling.</param>
-        /// <returns></returns>
-        public Completion CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
-            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
-            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
-        {
-            var completion = CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
-                frequency_penalty, presence_penalty, repeat_penalty, top_k, false).First();
-            return (Completion)completion;
-        }
-
-        /// <summary>
-        /// Generate text from a prompt.
-        /// </summary>
-        /// <param name="prompt">The prompt to generate text from.</param>
-        /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
-        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
-        /// <param name="temperature">The temperature to use for sampling.</param>
-        /// <param name="top_p">The top-p value to use for sampling.</param>
-        /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
-        /// <param name="echo">Whether to echo the prompt.</param>
-        /// <param name="stop">A list of strings to stop generation when encountered.</param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
-        /// <param name="top_k">The top-k value to use for sampling.</param>
-        /// <returns></returns>
-        public Completion Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
+        public IEnumerable<CompletionChunk> CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
-            return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
+            return CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }
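After this hunk, `CreateCompletion` always streams; callers that want the full text concatenate the chunks themselves. A hedged consumption sketch (the chunk's `Choices[0].Text` shape is assumed to match `Completion`'s, per the `CompletionChoice[]` constructor argument above):

```csharp
// Hypothetical caller: stream chunks to the console as they arrive.
foreach (var chunk in model.CreateCompletion("Q: What is a llama? A:",
             max_tokens: 64, stop: new[] { "\n" }))
{
    Console.Write(chunk.Choices[0].Text);
}
```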
@@ -769,18 +706,18 @@ namespace LLama
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
-        public IEnumerable<CompletionChunk> StreamCall(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
+        public IEnumerable<CompletionChunk> Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
-            return CreateCompletionStream(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
+            return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }

        private ChatCompletion ConvertTextCompletionToChat(Completion completion)
        {
            return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
-                new[] { new ChatCompletionChoice(0, new ChatCompletionMessage("assistant", completion.Choices[0].Text, null),
+                new[] { new ChatCompletionChoice(0, new ChatCompletionMessage(ChatRole.Assistant, completion.Choices[0].Text),
                completion.Choices[0].FinishReason) }, completion.Usage);
        }
@@ -801,39 +738,6 @@ namespace LLama
            }
        }

-        /// <summary>
-        /// Generate a chat completion from a list of messages.
-        /// </summary>
-        /// <param name="messages">A list of messages to generate a response for.</param>
-        /// <param name="temperature">The temperature to use for sampling.</param>
-        /// <param name="top_p">The top-p value to use for sampling.</param>
-        /// <param name="top_k">The top-k value to use for sampling.</param>
-        /// <param name="stop">A list of strings to stop generation when encountered.</param>
-        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
-        /// <returns></returns>
-        public ChatCompletion CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
-            int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
-            float repeat_penalty = 1.1f)
-        {
-            if(stop is null)
-            {
-                stop = new string[0];
-            }
-            string GetRole(ChatCompletionMessage message)
-            {
-                return message.Role == "user" ? "Human" : "Assistant";
-            }
-            string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
-            var prompt = chatHistory + "### Assistant:";
-            var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
-            var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
-                repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
-            return ConvertTextCompletionToChat(completion);
-        }
-
        /// <summary>
        /// Generate a chat completion from a list of messages and yield return the result.
        /// </summary>
@@ -847,7 +751,7 @@ namespace LLama
        /// <param name="frequency_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <returns></returns>
-        public IEnumerable<ChatCompletionChunk> CreateChatCompletionStream(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
+        public IEnumerable<ChatCompletionChunk> CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
            int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
            float repeat_penalty = 1.1f)
        {
@@ -857,12 +761,13 @@ namespace LLama
            }
            string GetRole(ChatCompletionMessage message)
            {
-                return message.Role == "user" ? "Human" : "Assistant";
+                return message.Role == ChatRole.Human ? "Human" : "Assistant";
            }
            string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
            var prompt = chatHistory + "### Assistant:";
+            prompt = prompt.Substring(Math.Max(0, prompt.Length - max_tokens));
            var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
-            var completion = StreamCall(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
+            var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
                repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
            return ConvertTextCompletionChunksToChat(completion);
        }
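The chat path flattens the message list into an instruction-style transcript (`### Human:` / `### Assistant:`) and stops on either role marker, with the new line additionally truncating the prompt to its last `max_tokens` characters. A hedged sketch of what that flattening produces (message values are illustrative; the `ChatCompletionMessage(ChatRole, content)` shape follows the change above):

```csharp
// Hypothetical two-turn history, flattened the way CreateChatCompletion does.
var messages = new[]
{
    new ChatCompletionMessage(ChatRole.Human, "What is a llama?"),
    new ChatCompletionMessage(ChatRole.Assistant, "A South American camelid."),
    new ChatCompletionMessage(ChatRole.Human, "Where do they live?"),
};
// Joined as $"### {Role}:{Content}" plus the trailing assistant cue, the prompt becomes:
// "### Human:What is a llama?### Assistant:A South American camelid."
//   + "### Human:Where do they live?### Assistant:"
var chunks = model.CreateChatCompletion(messages, max_tokens: 128);
```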