
feat: add ILLamaExecutor.InferAsync.

tags/v0.4.0
Yaohui Liu, 2 years ago
commit 5679e08718
GPG Key ID: E86D01E1809BD23E (no known key found for this signature in database)
22 changed files with 450 additions and 292 deletions
  1. LLama.Examples/Old/ChatSession.cs (+4, -4)
  2. LLama.Examples/Old/ChatWithLLamaModel.cs (+3, -3)
  3. LLama.Examples/Old/GetEmbeddings.cs (+3, -3)
  4. LLama.Examples/Old/InstructMode.cs (+3, -3)
  5. LLama.Examples/Old/SaveAndLoadState.cs (+3, -3)
  6. LLama.Examples/Program.cs (+3, -3)
  7. LLama.WebAPI/Services/ChatService.cs (+1, -1)
  8. LLama/Abstractions/Params/SessionParams.cs (+5, -0)
  9. LLama/ILLamaExecutor.cs (+4, -1)
  10. LLama/LLamaEmbedder.cs (+80, -0)
  11. LLama/LLamaExecutorBase.cs (+119, -1)
  12. LLama/LLamaInstructExecutor.cs (+98, -131)
  13. LLama/LLamaInteractExecutor.cs (+113, -129)
  14. LLama/LLamaModel.cs (+1, -1)
  15. LLama/LLamaSharp.csproj (+3, -2)
  16. LLama/OldVersion/ChatSession.cs (+1, -1)
  17. LLama/OldVersion/IChatModel.cs (+1, -1)
  18. LLama/OldVersion/LLamaEmbedder.cs (+1, -1)
  19. LLama/OldVersion/LLamaModel.cs (+1, -1)
  20. LLama/OldVersion/LLamaParams.cs (+1, -1)
  21. LLama/OldVersion/LLamaTypes.cs (+1, -1)
  22. LLama/OldVersion/Utils.cs (+1, -1)

LLama.Examples/Old/ChatSession.cs (+4, -4)

@@ -3,17 +3,17 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
-using LLama.Old;
+using LLama.OldVersion;
 
 namespace LLama.Examples
 {
     public class ChatSession
     {
-        LLama.Old.ChatSession<LLama.Old.LLamaModel> _session;
+        LLama.OldVersion.ChatSession<LLama.OldVersion.LLamaModel> _session;
         public ChatSession(string modelPath, string promptFilePath, string[] antiprompt)
         {
-            LLama.Old.LLamaModel model = new(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false));
-            _session = new ChatSession<LLama.Old.LLamaModel>(model)
+            LLama.OldVersion.LLamaModel model = new(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false));
+            _session = new ChatSession<LLama.OldVersion.LLamaModel>(model)
                 .WithPromptFile(promptFilePath)
                 .WithAntiprompt(antiprompt);
         }


LLama.Examples/Old/ChatWithLLamaModel.cs (+3, -3)

@@ -3,16 +3,16 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
-using LLama.Old;
+using LLama.OldVersion;
 
 namespace LLama.Examples.Old
 {
     public class ChatWithLLamaModel
     {
-        LLama.Old.LLamaModel _model;
+        LLama.OldVersion.LLamaModel _model;
         public ChatWithLLamaModel(string modelPath, string promptFilePath, string[] antiprompt)
         {
-            _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: antiprompt.ToList(),
+            _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: antiprompt.ToList(),
                 repeat_penalty: 1.0f)).WithPromptFile(promptFilePath);
         }




LLama.Examples/Old/GetEmbeddings.cs (+3, -3)

@@ -3,16 +3,16 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
-using LLama.Old;
+using LLama.OldVersion;
 
 namespace LLama.Examples
 {
     public class GetEmbeddings
     {
-        LLamaEmbedder _embedder;
+        LLama.OldVersion.LLamaEmbedder _embedder;
         public GetEmbeddings(string modelPath)
         {
-            _embedder = new LLamaEmbedder(new LLamaParams(model: modelPath));
+            _embedder = new LLama.OldVersion.LLamaEmbedder(new LLamaParams(model: modelPath));
         }
 
         public void Run(string text)


LLama.Examples/Old/InstructMode.cs (+3, -3)

@@ -3,16 +3,16 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
-using LLama.Old;
+using LLama.OldVersion;
 
 namespace LLama.Examples.Old
 {
     public class InstructMode
     {
-        LLama.Old.LLamaModel _model;
+        LLama.OldVersion.LLamaModel _model;
        public InstructMode(string modelPath, string promptFile)
        {
-            _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true,
+            _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true,
                repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPromptFile(promptFile);
        }




LLama.Examples/Old/SaveAndLoadState.cs (+3, -3)

@@ -3,16 +3,16 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
-using LLama.Old;
+using LLama.OldVersion;
 
 namespace LLama.Examples
 {
     public class SaveAndLoadState: IDisposable
     {
-        LLama.Old.LLamaModel _model;
+        LLama.OldVersion.LLamaModel _model;
        public SaveAndLoadState(string modelPath, string prompt)
        {
-            _model = new LLama.Old.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true,
+            _model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true,
                repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPrompt(prompt);
        }




LLama.Examples/Program.cs (+3, -3)

@@ -26,12 +26,12 @@ if(version == 1)
     Console.WriteLine("The examples for new versions are under working now. We'll soon update the examples." +
         " Thank you for your support!");
     string modelPath = "D:\\development\\llama\\weights\\wizard-vicuna-13B.ggmlv3.q4_1.bin";
-    var prompt = File.ReadAllText("Assets/dan.txt").Trim();
-    LLamaInstructExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024)));
+    var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
+    LLamaInteractExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337)));
 
     while (true)
     {
-        foreach (var text in ex.Infer(prompt, new SessionParams() { Temperature = 0.6f }))
+        await foreach (var text in ex.InferAsync(prompt, new SessionParams() { Temperature = 0.6f, AntiPrompts = new List<string>{ "user:" } }, default(CancellationToken)))
        {
            Console.Write(text);
        }


LLama.WebAPI/Services/ChatService.cs (+1, -1)

@@ -1,4 +1,4 @@
-using LLama.Old;
+using LLama.OldVersion;
 using LLama.WebAPI.Models;
 
 namespace LLama.WebAPI.Services;


LLama/Abstractions/Params/SessionParams.cs (+5, -0)

@@ -20,6 +20,11 @@ namespace LLama.Abstractions.Params
         /// logit bias for specific tokens
         /// </summary>
         public Dictionary<llama_token, float>? LogitBias { get; set; } = null;
+
+        /// <summary>
+        /// Sequences where the model will stop generating further tokens.
+        /// </summary>
+        public IList<string> AntiPrompts { get; set; } = Array.Empty<string>();
         /// <summary>
         /// path to file for saving/loading model eval state
         /// </summary>


LLama/ILLamaExecutor.cs (+4, -1)

@@ -2,11 +2,14 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using System.Threading;
 
 namespace LLama
 {
     public interface ILLamaExecutor
     {
-        IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null);
+        IEnumerable<string> Infer(string text, SessionParams? sessionParams = null);
+
+        IAsyncEnumerable<string> InferAsync(string text, SessionParams? sessionParams = null, CancellationToken token = default);
     }
 }
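For orientation, a minimal consumption sketch of the reshaped interface (illustrative only, not part of this commit; the model path is a placeholder, and antiprompts now travel through SessionParams.AntiPrompts instead of a separate argument):

// Illustrative sketch: consuming ILLamaExecutor synchronously and asynchronously.
using System;
using System.Collections.Generic;
using LLama;
using LLama.Abstractions.Params;

var model = new LLamaModel(new ModelParams("path/to/model.bin", contextSize: 1024)); // placeholder path
ILLamaExecutor executor = new LLamaInteractExecutor(model);
var sessionParams = new SessionParams { Temperature = 0.6f, AntiPrompts = new List<string> { "user:" } };

// Existing synchronous streaming API (the antiprompts argument was removed in this commit):
foreach (var token in executor.Infer("Hello!", sessionParams))
    Console.Write(token);

// New asynchronous streaming API added by this commit:
await foreach (var token in executor.InferAsync("Hello!", sessionParams))
    Console.Write(token);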

LLama/LLamaEmbedder.cs (+80, -0)

@@ -0,0 +1,80 @@
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using LLama.Exceptions;
+using LLama.Abstractions.Params;
+using System.Linq;
+
+namespace LLama
+{
+    public class LLamaEmbedder : IDisposable
+    {
+        SafeLLamaContextHandle _ctx;
+
+        /// <summary>
+        /// Warning: must ensure the original model has params.embedding = true;
+        /// </summary>
+        /// <param name="ctx"></param>
+        internal LLamaEmbedder(SafeLLamaContextHandle ctx)
+        {
+            _ctx = ctx;
+        }
+
+        public LLamaEmbedder(ModelParams @params)
+        {
+            @params.EmbeddingMode = true;
+            _ctx = Utils.InitLLamaContextFromModelParams(@params);
+        }
+
+        /// <summary>
+        /// Get the embeddings of the text.
+        /// </summary>
+        /// <param name="text"></param>
+        /// <param name="threads">Threads used for inference.</param>
+        /// <param name="addBos">Add bos to the text.</param>
+        /// <param name="encoding"></param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
+        {
+            if (threads == -1)
+            {
+                threads = Math.Max(Environment.ProcessorCount / 2, 1);
+            }
+            int n_past = 0;
+            if (addBos)
+            {
+                text = text.Insert(0, " ");
+            }
+            var embed_inp = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding));
+
+            // TODO(Rinne): deal with log of prompt
+
+            if (embed_inp.Count() > 0)
+            {
+                var embed_inp_array = embed_inp.ToArray();
+                if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, threads) != 0)
+                {
+                    throw new RuntimeError("Failed to eval.");
+                }
+            }
+
+            int n_embed = NativeApi.llama_n_embd(_ctx);
+            var embeddings = NativeApi.llama_get_embeddings(_ctx);
+            if (embeddings == null)
+            {
+                return new float[0];
+            }
+            var span = new Span<float>(embeddings, n_embed);
+            float[] res = new float[n_embed];
+            span.CopyTo(res.AsSpan());
+            return res;
+        }
+
+        public void Dispose()
+        {
+            _ctx.Dispose();
+        }
+    }
+}
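A short usage sketch for the new embedder (illustrative only; the model path is a placeholder, and EmbeddingMode is forced on by the constructor as shown above):

// Illustrative sketch: computing an embedding with the new LLamaEmbedder.
using System;
using LLama;
using LLama.Abstractions.Params;

using var embedder = new LLamaEmbedder(new ModelParams("path/to/model.bin", contextSize: 512)); // placeholder path
float[] embedding = embedder.GetEmbeddings("Hello, LLamaSharp!");
Console.WriteLine($"Embedding length: {embedding.Length}");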

LLama/LLamaExecutorBase.cs (+119, -1)

@@ -6,7 +6,10 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Runtime.CompilerServices;
 using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace LLama
 {
@@ -106,6 +109,121 @@ namespace LLama
             }
         }
 
-        public abstract IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null);
+        protected abstract bool GetLoopCondition(InferStateArgs args);
+        protected abstract void PreprocessInputs(string text, InferStateArgs args);
+        protected abstract bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
+        protected abstract void InferInternal(SessionParams sessionParams, InferStateArgs args);
+        public virtual IEnumerable<string> Infer(string text, SessionParams? sessionParams = null)
+        {
+            if (sessionParams is null)
+            {
+                sessionParams = new SessionParams();
+            }
+
+            InferStateArgs args = new InferStateArgs()
+            {
+                Antiprompts = sessionParams.AntiPrompts,
+                RemainedTokens = sessionParams.ResponseTokensCount,
+                ReturnValue = false,
+                WaitForInput = false,
+                NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
+            };
+
+            PreprocessInputs(text, args);
+
+            while (GetLoopCondition(args))
+            {
+                InferInternal(sessionParams, args);
+
+                if (args.ReturnValue)
+                {
+                    foreach (var item in _model.GenerateResult(_embeds))
+                    {
+                        yield return item;
+                    }
+                }
+
+                var breakGeneration = PostProcess(sessionParams, args, out var extraOutputs);
+                if (extraOutputs is not null)
+                {
+                    foreach (var item in extraOutputs)
+                    {
+                        yield return item;
+                    }
+                }
+                if (breakGeneration)
+                {
+                    break;
+                }
+            }
+        }
+        public virtual async IAsyncEnumerable<string> InferAsync(string text, SessionParams? sessionParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+            // make this delay only to make the async method consistent with what it's expected to be
+            //await Task.Delay(1);
+
+            if (sessionParams is null)
+            {
+                sessionParams = new SessionParams();
+            }
+
+            InferStateArgs args = new InferStateArgs()
+            {
+                Antiprompts = sessionParams.AntiPrompts,
+                RemainedTokens = sessionParams.ResponseTokensCount,
+                ReturnValue = false,
+                WaitForInput = false,
+                NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
+            };
+
+            PreprocessInputs(text, args);
+
+            while (GetLoopCondition(args))
+            {
+                if (cancellationToken.IsCancellationRequested)
+                {
+                    break;
+                }
+
+                InferInternal(sessionParams, args);
+
+                if (args.ReturnValue)
+                {
+                    foreach (var item in _model.GenerateResult(_embeds))
+                    {
+                        yield return item;
+                    }
+                }
+
+                var breakGeneration = PostProcess(sessionParams, args, out var extraOutputs);
+                if (extraOutputs is not null)
+                {
+                    foreach (var item in extraOutputs)
+                    {
+                        yield return item;
+                    }
+                }
+                if (breakGeneration)
+                {
+                    break;
+                }
+            }
+        }
+
+        /// <summary>
+        /// State arguments that are used in single inference
+        /// </summary>
+        protected class InferStateArgs
+        {
+            public IList<string>? Antiprompts { get; set; }
+            /// <summary>
+            /// Tokens count remained to be used. (n_remain)
+            /// </summary>
+            public int RemainedTokens { get; set; }
+            public bool ReturnValue { get; set; }
+            public bool WaitForInput { get; set; }
+            public bool NeedToSaveSession { get; set; }
+        }
     }
 }
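InferAsync checks the CancellationToken at the top of every generation-loop iteration, so a caller can stop a long generation early. A hedged sketch below reuses the executor `ex` and `prompt` from the Program.cs example above; the 10-second timeout is purely illustrative:

// Illustrative sketch: cancelling InferAsync mid-generation after a timeout.
using System;
using System.Threading;
using LLama.Abstractions.Params;

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
await foreach (var token in ex.InferAsync(prompt, new SessionParams { Temperature = 0.6f }, cts.Token))
{
    Console.Write(token); // generation stops quietly once the token is cancelled between iterations
}

If the token is already cancelled before iteration starts, the method throws OperationCanceledException instead of yielding anything, per the ThrowIfCancellationRequested call shown above.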

LLama/LLamaInstructExecutor.cs (+98, -131)

@@ -11,28 +11,28 @@ namespace LLama
     public class LLamaInstructExecutor : LLamaExecutorBase
     {
         bool _prompt_run = true;
-        readonly IEnumerable<llama_token> _llama_token_newline;
         readonly IEnumerable<llama_token> _inp_pfx;
         readonly IEnumerable<llama_token> _inp_sfx;
         public LLamaInstructExecutor(LLamaModel model, string inputPrefix = "\n\n### Instruction:\n\n",
             string inputSuffix = "\n\n### Response:\n\n") : base(model)
         {
-            _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding);
             _inp_pfx = _model.Tokenize(inputPrefix, true);
             _inp_sfx = _model.Tokenize(inputSuffix, false);
         }
 
-        /// <summary>
-        /// process the text and return the tokens consumed.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="sessionParams"></param>
-        /// <param name="encoding"></param>
-        /// <param name="is_antiprompt"></param>
-        /// <returns></returns>
-        protected virtual int ProcessTextBeforeInfer(string text, SessionParams sessionParams)
+        protected override bool GetLoopCondition(InferStateArgs args)
         {
-            if (text.Length > 1)
+            return args.RemainedTokens != 0 || _prompt_run;
+        }
+        protected override void PreprocessInputs(string text, InferStateArgs args)
+        {
+            if (_prompt_run)
+            {
+                // When running the first input (prompt) in inteactive mode, we should specially process it.
+                text = " " + text;
+                _embed_inps = _model.Tokenize(text, true).ToList();
+            }
+            else
             {
                 if (!text.EndsWith("\n"))
                 {
@@ -46,153 +46,120 @@ namespace LLama
 
                 _embed_inps.AddRange(_inp_sfx);
 
-                return line_inp.Count();
-            }
-            else
-            {
-                return 0;
-            }
-        }
-
-        public override IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null)
-        {
-            if (sessionParams is null)
-            {
-                sessionParams = new SessionParams();
-            }
-            // if n_remain < 0, the response will be generated endlessly.
-            int n_remain = sessionParams.ResponseTokensCount;
-            bool return_value = false;
-            bool wait_for_input = false;
-            bool need_to_save_session = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count;
-
-            if (_prompt_run)
-            {
-                // When running the first input (prompt) in inteactive mode, we should specially process it.
-                text = " " + text;
-                _embed_inps = _model.Tokenize(text, true).ToList();
-            }
-            else
-            {
-                n_remain -= ProcessTextBeforeInfer(text, sessionParams);
-            }
-
-            while (n_remain != 0 || _prompt_run)
-            {
-                if (_embeds.Count > 0)
-                {
-                    _prompt_run = false;
-                    if (_pastTokensCount + _embeds.Count > _model.ContextSize)
-                    {
-                        HandleRunOutOfContext(sessionParams.TokensToKeep);
-                    }
-
-                    TryReuseMathingPrefix();
-                    _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount);
-
-                    if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
-                    {
-                        _session_tokens.AddRange(_embeds);
-                        _n_session_consumed = _session_tokens.Count;
-                    }
-                }
-
-                _embeds.Clear();
-
-                if (_embed_inps.Count <= _consumedTokensCount && !wait_for_input)
-                {
-                    var temp = sessionParams.Temperature;
-                    var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK;
-                    var top_p = sessionParams.TopK;
-                    var tfs_z = sessionParams.TfsZ;
-                    var typical_p = sessionParams.TypicalP;
-                    var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount;
-                    var repeat_penalty = sessionParams.RepeatPenalty;
-                    var alpha_presence = sessionParams.PresencePenalty;
-                    var alpha_frequency = sessionParams.FrequencyPenalty;
-                    var mirostat = sessionParams.Mirostat;
-                    var mirostat_tau = sessionParams.MirostatTau;
-                    var mirostat_eta = sessionParams.MirostatEta;
-                    var penalize_nl = sessionParams.PenalizeNL;
-
-                    // optionally save the session on first sample (for faster prompt loading next time)
-                    if (!string.IsNullOrEmpty(_pathSession) && need_to_save_session)
-                    {
-                        need_to_save_session = false;
-                        SaveSessionFile(_pathSession);
-                    }
-
-                    var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n,
-                        repeat_penalty, alpha_frequency, alpha_presence, penalize_nl);
-
-                    var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p,
-                        tfs_z, typical_p);
-
-                    _last_n_tokens.Enqueue(id);
-
-                    _embeds.Add(id);
-
-                    n_remain--;
-                    return_value = true;
-                }
-                else
-                {
-                    while (_embed_inps.Count > _consumedTokensCount)
-                    {
-                        _embeds.Add(_embed_inps[_consumedTokensCount]);
-                        _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
-                        _consumedTokensCount++;
-                        if (_embeds.Count >= _model.Params.BatchSize)
-                        {
-                            break;
-                        }
-                    }
-                }
-
-                if (return_value)
-                {
-                    foreach (var item in _model.GenerateResult(_embeds))
-                    {
-                        yield return item;
-                    }
-                }
-
-                if (_embed_inps.Count <= _consumedTokensCount)
-                {
-                    if (antiprompts is not null && antiprompts.Count() > 0)
-                    {
-                        string last_output = "";
-                        foreach (var id in _last_n_tokens)
-                        {
-                            last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
-                        }
-
-                        foreach (var antiprompt in antiprompts)
-                        {
-                            if (last_output.EndsWith(antiprompt))
-                            {
-                                wait_for_input = true;
-                                break;
-                            }
-                        }
-                    }
-
-                    if (_pastTokensCount > 0 && wait_for_input)
-                    {
-                        yield return "\n> ";
-                        break;
-                    }
-                }
-
-                if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
-                {
-                    wait_for_input = true;
-                }
-
-                if (n_remain <= 0 && sessionParams.ResponseTokensCount != -1)
-                {
-                    n_remain = sessionParams.ResponseTokensCount;
-                    wait_for_input = true;
-                }
-            }
-        }
+                args.RemainedTokens -= line_inp.Count();
+            }
+        }
+
+        protected override bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        {
+            extraOutputs = null;
+            if (_embed_inps.Count <= _consumedTokensCount)
+            {
+                if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
+                {
+                    string last_output = "";
+                    foreach (var id in _last_n_tokens)
+                    {
+                        last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
+                    }
+
+                    foreach (var antiprompt in args.Antiprompts)
+                    {
+                        if (last_output.EndsWith(antiprompt))
+                        {
+                            args.WaitForInput = true;
+                            return true;
+                        }
+                    }
+                }
+
+                if (_pastTokensCount > 0 && args.WaitForInput)
+                {
+                    extraOutputs = new string[] { "\n> " };
+                    return true;
+                }
+            }
+
+            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
+            {
+                args.WaitForInput = true;
+            }
+
+            if (args.RemainedTokens <= 0 && sessionParams.ResponseTokensCount != -1)
+            {
+                args.RemainedTokens = sessionParams.ResponseTokensCount;
+                args.WaitForInput = true;
+            }
+            return false;
+        }
+        protected override void InferInternal(SessionParams sessionParams, InferStateArgs args)
+        {
+            if (_embeds.Count > 0)
+            {
+                _prompt_run = false;
+                if (_pastTokensCount + _embeds.Count > _model.ContextSize)
+                {
+                    HandleRunOutOfContext(sessionParams.TokensToKeep);
+                }
+
+                TryReuseMathingPrefix();
+                _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount);
+
+                if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
+                {
+                    _session_tokens.AddRange(_embeds);
+                    _n_session_consumed = _session_tokens.Count;
+                }
+            }
+
+            _embeds.Clear();
+
+            if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
+            {
+                var temp = sessionParams.Temperature;
+                var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK;
+                var top_p = sessionParams.TopK;
+                var tfs_z = sessionParams.TfsZ;
+                var typical_p = sessionParams.TypicalP;
+                var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount;
+                var repeat_penalty = sessionParams.RepeatPenalty;
+                var alpha_presence = sessionParams.PresencePenalty;
+                var alpha_frequency = sessionParams.FrequencyPenalty;
+                var mirostat = sessionParams.Mirostat;
+                var mirostat_tau = sessionParams.MirostatTau;
+                var mirostat_eta = sessionParams.MirostatEta;
+                var penalize_nl = sessionParams.PenalizeNL;
+
+                // optionally save the session on first sample (for faster prompt loading next time)
+                if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
+                {
+                    args.NeedToSaveSession = false;
+                    SaveSessionFile(_pathSession);
+                }
+
+                var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n,
+                    repeat_penalty, alpha_frequency, alpha_presence, penalize_nl);
+
+                var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p,
+                    tfs_z, typical_p);
+
+                _last_n_tokens.Enqueue(id);
+
+                _embeds.Add(id);
+
+                args.RemainedTokens--;
+                args.ReturnValue = true;
+            }
+            else
+            {
+                while (_embed_inps.Count > _consumedTokensCount)
+                {
+                    _embeds.Add(_embed_inps[_consumedTokensCount]);
+                    _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
+                    _consumedTokensCount++;
+                    if (_embeds.Count >= _model.Params.BatchSize)
+                    {
+                        break;
+                    }
+                }
+            }
+        }


LLama/LLamaInteractExecutor.cs (+113, -129)

@@ -3,7 +3,10 @@ using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using System.Runtime.CompilerServices;
 using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace LLama
 {
@@ -22,43 +25,16 @@ namespace LLama
         }
 
         /// <summary>
-        /// process the text and return the tokens consumed.
+        /// Define whether to continue the loop to generate responses.
         /// </summary>
-        /// <param name="text"></param>
-        /// <param name="sessionParams"></param>
-        /// <param name="encoding"></param>
-        /// <param name="is_antiprompt"></param>
        /// <returns></returns>
-        protected virtual int ProcessTextBeforeInfer(string text, SessionParams sessionParams)
+        protected override bool GetLoopCondition(InferStateArgs args)
        {
-            if (text.Length > 1)
-            {
-                if (!text.EndsWith("\n"))
-                {
-                    text += "\n";
-                }
-                var line_inp = _model.Tokenize(text, false);
-                _embed_inps.AddRange(line_inp);
-                return line_inp.Count();
-            }
-            else
-            {
-                return 0;
-            }
+            return args.RemainedTokens != 0 && !args.WaitForInput || _prompt_run;
        }
 
-        public override IEnumerable<string> Infer(string text, SessionParams? sessionParams = null, IEnumerable<string>? antiprompts = null)
+        protected override void PreprocessInputs(string text, InferStateArgs args)
        {
-            if (sessionParams is null)
-            {
-                sessionParams = new SessionParams();
-            }
-            // if n_remain < 0, the response will be generated endlessly.
-            int n_remain = sessionParams.ResponseTokensCount;
-            bool return_value = false;
-            bool wait_for_input = false;
-            bool need_to_save_session = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count;
-
            if (_prompt_run)
            {
                // When running the first input (prompt) in inteactive mode, we should specially process it.
@@ -67,135 +43,143 @@ namespace LLama
             }
             else
             {
-                n_remain -= ProcessTextBeforeInfer(text, sessionParams);
-            }
-
-            while (n_remain != 0 && !wait_for_input || _prompt_run)
-            {
-                if (_embeds.Count > 0)
-                {
-                    _prompt_run = false;
-                    if (_pastTokensCount + _embeds.Count > _model.ContextSize)
-                    {
-                        HandleRunOutOfContext(sessionParams.TokensToKeep);
-                    }
-
-                    TryReuseMathingPrefix();
-                    _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount);
-
-                    if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
-                    {
-                        _session_tokens.AddRange(_embeds);
-                        _n_session_consumed = _session_tokens.Count;
-                    }
-                }
-
-                _embeds.Clear();
-
-                if (_embed_inps.Count <= _consumedTokensCount && !wait_for_input)
-                {
-                    var temp = sessionParams.Temperature;
-                    var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK;
-                    var top_p = sessionParams.TopK;
-                    var tfs_z = sessionParams.TfsZ;
-                    var typical_p = sessionParams.TypicalP;
-                    var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount;
-                    var repeat_penalty = sessionParams.RepeatPenalty;
-                    var alpha_presence = sessionParams.PresencePenalty;
-                    var alpha_frequency = sessionParams.FrequencyPenalty;
-                    var mirostat = sessionParams.Mirostat;
-                    var mirostat_tau = sessionParams.MirostatTau;
-                    var mirostat_eta = sessionParams.MirostatEta;
-                    var penalize_nl = sessionParams.PenalizeNL;
-
-                    // optionally save the session on first sample (for faster prompt loading next time)
-                    if (!string.IsNullOrEmpty(_pathSession) && need_to_save_session)
-                    {
-                        need_to_save_session = false;
-                        SaveSessionFile(_pathSession);
-                    }
-
-                    var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n,
-                        repeat_penalty, alpha_frequency, alpha_presence, penalize_nl);
-
-                    var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p,
-                        tfs_z, typical_p);
-
-                    _last_n_tokens.Enqueue(id);
-
-                    if (id == NativeApi.llama_token_eos())
-                    {
-                        id = _llama_token_newline.First();
-                        if (antiprompts is not null && antiprompts.Count() > 0)
-                        {
-                            var first_antiprompt = _model.Tokenize(antiprompts.First(), false);
-                            _embed_inps.AddRange(first_antiprompt);
-                        }
-                    }
-
-                    _embeds.Add(id);
-
-                    n_remain--;
-                    return_value = true;
-                }
-                else
-                {
-                    while (_embed_inps.Count > _consumedTokensCount)
-                    {
-                        _embeds.Add(_embed_inps[_consumedTokensCount]);
-                        _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
-                        _consumedTokensCount++;
-                        if (_embeds.Count >= _model.Params.BatchSize)
-                        {
-                            break;
-                        }
-                    }
-                }
-
-                if (return_value)
-                {
-                    foreach (var item in _model.GenerateResult(_embeds))
-                    {
-                        yield return item;
-                    }
-                }
-
-                if (_embed_inps.Count <= _consumedTokensCount)
-                {
-                    if (antiprompts is not null && antiprompts.Count() > 0)
-                    {
-                        string last_output = "";
-                        foreach (var id in _last_n_tokens)
-                        {
-                            last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
-                        }
-
-                        foreach (var antiprompt in antiprompts)
-                        {
-                            if (last_output.EndsWith(antiprompt))
-                            {
-                                wait_for_input = true;
-                                break;
-                            }
-                        }
-                    }
-
-                    if (_pastTokensCount > 0 && wait_for_input)
-                    {
-                        break;
-                    }
-                }
-
-                if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
-                {
-                    yield return " [end of text]\n";
-                    break;
-                }
-
-                if (n_remain <= 0 && sessionParams.ResponseTokensCount != -1)
-                {
-                    n_remain = sessionParams.ResponseTokensCount;
-                    wait_for_input = true;
-                }
-            }
-        }
+                if (!text.EndsWith("\n"))
+                {
+                    text += "\n";
+                }
+                var line_inp = _model.Tokenize(text, false);
+                _embed_inps.AddRange(line_inp);
+                args.RemainedTokens -= line_inp.Count();
+            }
+        }
+
+        /// <summary>
+        /// Return whether to break the generation.
+        /// </summary>
+        /// <param name="args"></param>
+        /// <returns></returns>
+        protected override bool PostProcess(SessionParams sessionParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        {
+            extraOutputs = null;
+            if (_embed_inps.Count <= _consumedTokensCount)
+            {
+                if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
+                {
+                    string last_output = "";
+                    foreach (var id in _last_n_tokens)
+                    {
+                        last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
+                    }
+
+                    foreach (var antiprompt in args.Antiprompts)
+                    {
+                        if (last_output.EndsWith(antiprompt))
+                        {
+                            args.WaitForInput = true;
+                            break;
+                        }
+                    }
+                }
+
+                if (_pastTokensCount > 0 && args.WaitForInput)
+                {
+                    return true;
+                }
+            }
+
+            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
+            {
+                extraOutputs = new string[] { " [end of text]\n" };
+                return true;
+            }
+
+            if (args.RemainedTokens <= 0 && sessionParams.ResponseTokensCount != -1)
+            {
+                args.RemainedTokens = sessionParams.ResponseTokensCount;
+                args.WaitForInput = true;
+            }
+            return false;
+        }
+
+        protected override void InferInternal(SessionParams sessionParams, InferStateArgs args)
+        {
+            if (_embeds.Count > 0)
+            {
+                _prompt_run = false;
+                if (_pastTokensCount + _embeds.Count > _model.ContextSize)
+                {
+                    HandleRunOutOfContext(sessionParams.TokensToKeep);
+                }
+
+                TryReuseMathingPrefix();
+                _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount);
+
+                if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
+                {
+                    _session_tokens.AddRange(_embeds);
+                    _n_session_consumed = _session_tokens.Count;
+                }
+            }
+
+            _embeds.Clear();
+
+            if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
+            {
+                var temp = sessionParams.Temperature;
+                var top_k = sessionParams.TopK <= 0 ? NativeApi.llama_n_vocab(_model.NativeHandle) : sessionParams.TopK;
+                var top_p = sessionParams.TopK;
+                var tfs_z = sessionParams.TfsZ;
+                var typical_p = sessionParams.TypicalP;
+                var repeat_last_n = sessionParams.RepeatLastTokensCount < 0 ? _model.ContextSize : sessionParams.RepeatLastTokensCount;
+                var repeat_penalty = sessionParams.RepeatPenalty;
+                var alpha_presence = sessionParams.PresencePenalty;
+                var alpha_frequency = sessionParams.FrequencyPenalty;
+                var mirostat = sessionParams.Mirostat;
+                var mirostat_tau = sessionParams.MirostatTau;
+                var mirostat_eta = sessionParams.MirostatEta;
+                var penalize_nl = sessionParams.PenalizeNL;
+
+                // optionally save the session on first sample (for faster prompt loading next time)
+                if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
+                {
+                    args.NeedToSaveSession = false;
+                    SaveSessionFile(_pathSession);
+                }
+
+                var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, sessionParams.LogitBias, repeat_last_n,
+                    repeat_penalty, alpha_frequency, alpha_presence, penalize_nl);
+
+                var id = _model.Sample(tokenDataArray, temp, mirostat, mirostat_tau, mirostat_eta, top_k, top_p,
+                    tfs_z, typical_p);
+
+                _last_n_tokens.Enqueue(id);
+
+                if (id == NativeApi.llama_token_eos())
+                {
+                    id = _llama_token_newline.First();
+                    if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
+                    {
+                        var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false);
+                        _embed_inps.AddRange(first_antiprompt);
+                    }
+                }
+
+                _embeds.Add(id);
+
+                args.RemainedTokens--;
+                args.ReturnValue = true;
+            }
+            else
+            {
+                while (_embed_inps.Count > _consumedTokensCount)
+                {
+                    _embeds.Add(_embed_inps[_consumedTokensCount]);
+                    _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
+                    _consumedTokensCount++;
+                    if (_embeds.Count >= _model.Params.BatchSize)
+                    {
+                        break;
+                    }
+                }
+            }
+        }

LLama/LLamaModel.cs (+1, -1)

@@ -1,7 +1,7 @@
 using LLama.Abstractions.Params;
 using LLama.Exceptions;
 using LLama.Native;
-using LLama.Old;
+using LLama.OldVersion;
 using LLama.Types;
 using LLama.Extensions;
 using System;


LLama/LLamaSharp.csproj (+3, -2)

@@ -8,7 +8,7 @@
     <Platforms>AnyCPU;x64</Platforms>
     <AllowUnsafeBlocks>True</AllowUnsafeBlocks>
 
-    <Version>0.3.0</Version>
+    <Version>0.4.0</Version>
     <Authors>Yaohui Liu, Haiping Chen</Authors>
     <Company>SciSharp STACK</Company>
     <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -21,7 +21,7 @@
       The .NET binding of LLama.cpp, providing APIs to run the model and deploy it on Web. For model weights to run, please go to https://github.com/SciSharp/LLamaSharp for more information.
     </Description>
     <PackageReleaseNotes>
-      LLamaSharp 0.3.0 supports loading and saving session state, tokenization and detokenization. Besides, since 0.3.0, `LLamaModelV1` is dropped.
+      LLamaSharp 0.4.0 supports better APIs than v0.3.0. Note that many break changes were made in this version. APIs of v0.3.0 were moved to LLama.Old namespace.
     </PackageReleaseNotes>
     <PackageLicenseExpression>MIT</PackageLicenseExpression>
     <PackageOutputPath>packages</PackageOutputPath>
@@ -41,6 +41,7 @@
   <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
     <PackageReference Include="IsExternalInit" Version="1.0.3" PrivateAssets="all" />
     <PackageReference Include="System.Memory" Version="4.5.4" PrivateAssets="all" />
+    <PackageReference Include="System.Linq.Async" VersionOverride="[6.0.1, )" />
   </ItemGroup>
 
   <ItemGroup>

LLama/Old/ChatSession.cs → LLama/OldVersion/ChatSession.cs (+1, -1)

@@ -3,7 +3,7 @@ using System.Collections.Generic;
 using System.IO;
 using System.Text;
 
-namespace LLama.Old
+namespace LLama.OldVersion
 {
     public class ChatSession<T> where T : IChatModel
     {

LLama/Old/IChatModel.cs → LLama/OldVersion/IChatModel.cs (+1, -1)

@@ -2,7 +2,7 @@
 using System.Collections.Generic;
 using System.Text;
 
-namespace LLama.Old
+namespace LLama.OldVersion
 {
     public interface IChatModel
     {

LLama/Old/LLamaEmbedder.cs → LLama/OldVersion/LLamaEmbedder.cs (+1, -1)

@@ -4,7 +4,7 @@ using System.Collections.Generic;
 using System.Text;
 using LLama.Exceptions;
 
-namespace LLama.Old
+namespace LLama.OldVersion
 {
     public class LLamaEmbedder : IDisposable
     {

LLama/Old/LLamaModel.cs → LLama/OldVersion/LLamaModel.cs (+1, -1)

@@ -9,7 +9,7 @@ using System.IO;
 using System.Linq;
 using System.Text;
 
-namespace LLama.Old
+namespace LLama.OldVersion
 {
     using llama_token = Int32;
     public class LLamaModel : IChatModel, IDisposable

LLama/Old/LLamaParams.cs → LLama/OldVersion/LLamaParams.cs (+1, -1)

@@ -1,7 +1,7 @@
 using System;
 using System.Collections.Generic;
 
-namespace LLama.Old
+namespace LLama.OldVersion
 {
     using llama_token = Int32;
     public struct LLamaParams

LLama/Old/LLamaTypes.cs → LLama/OldVersion/LLamaTypes.cs (+1, -1)

@@ -2,7 +2,7 @@
 using System.Collections.Generic;
 using System.Text;
 
-namespace LLama.Old
+namespace LLama.OldVersion
 {
     public enum ChatRole
     {

LLama/Old/Utils.cs → LLama/OldVersion/Utils.cs (+1, -1)

@@ -8,7 +8,7 @@ using System.Linq;
 using System.Runtime.InteropServices;
 using System.IO;
 
-namespace LLama.Old
+namespace LLama.OldVersion
 {
     using llama_token = Int32;
     internal static class Utils
