using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
namespace LLama
{
using llama_token = Int32;
/// <summary>
/// This executor runs each inference as a one-time job. Previous inputs do not affect the
/// response to the current input.
/// </summary>
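/// <remarks>
/// A minimal usage sketch, assuming a ModelParams constructor that accepts a model path;
/// the path and prompt below are placeholders, not part of this library:
/// <code>
/// var model = new LLamaModel(new ModelParams("path/to/model.bin"));
/// var executor = new StatelessExecutor(model);
/// foreach (var piece in executor.Infer("Q: What is 2 + 2? A:", new InferenceParams { MaxTokens = 16 }))
/// {
///     Console.Write(piece);
/// }
/// </code>
/// </remarks>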
public class StatelessExecutor : ILLamaExecutor
{
private LLamaModel _model;
private byte[] _originalState;
/// <summary>
/// The model used by the executor when running the inference.
/// </summary>
public LLamaModel Model => _model;
/// <summary>
/// Create a stateless executor for the given model.
/// </summary>
/// <param name="model">The LLama model.</param>
public StatelessExecutor(LLamaModel model)
{
_model = model;
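// Evaluate a single-space prompt once to initialize the context, then snapshot the clean
// state so every call to Infer can start from the same point.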
var tokens = model.Tokenize(" ", true).ToArray();
Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads);
_originalState = model.GetStateData();
}
/// <inheritdoc />
public IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
int n_past = 1;
if(inferenceParams is null)
{
inferenceParams = new InferenceParams();
}
// The list constructor only sets capacity, so the repeat-penalty window must be
// pre-filled with zeros explicitly.
List<llama_token> lastTokens = new(inferenceParams.RepeatLastTokensCount);
for(int i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
{
lastTokens.Add(0);
}
List<llama_token> tokens = _model.Tokenize(text, true).ToList();
int n_prompt_tokens = tokens.Count;
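// Evaluate the whole prompt in one pass and record it in the repeat-penalty window.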
Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads);
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
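// Generation loop: sample at most max_tokens new tokens, restoring the saved state on cancellation.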
for(int i = 0; i < max_tokens; i++)
{
if (cancellationToken.IsCancellationRequested)
{
_model.LoadState(_originalState);
break;
}
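// Apply repetition, frequency and presence penalties over the recent token window,
// then sample the next token from the adjusted distribution.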
var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount;
var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
lastTokens.Add(id);
string response = Utils.TokenToString(id, _model.NativeHandle, _model.Encoding);
yield return response;
tokens.Clear();
tokens.Add(id);
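// Stop early if the text generated so far ends with any configured anti-prompt.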
if (inferenceParams.AntiPrompts is not null && inferenceParams.AntiPrompts.Count() > 0)
{
string last_output = "";
foreach (var token in lastTokens)
{
last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, token), _model.Encoding);
}
bool should_break = false;
foreach (var antiprompt in inferenceParams.AntiPrompts)
{
if (last_output.EndsWith(antiprompt))
{
should_break = true;
break;
}
}
if (should_break)
{
break;
}
}
// When the context window is about to overflow, keep the first TokensKeep tokens,
// reset n_past and re-insert roughly half of the most recent tokens so generation can continue.
if (n_past + tokens.Count > _model.ContextSize)
{
int n_left = n_past - inferenceParams.TokensKeep;
n_past = Math.Max(1, inferenceParams.TokensKeep);
// insert n_left/2 tokens at the start of embed from last_n_tokens
tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_model.ContextSize - n_left / 2 - tokens.Count));
}
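// Evaluate the newly sampled token (plus any re-inserted context) and advance n_past.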
n_past = _model.Eval(tokens.ToArray(), n_past);
}
_model.LoadState(_originalState);
}
/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
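// This async variant simply wraps the synchronous Infer; enumeration is not offloaded to another thread.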
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{
yield return result;
}
}
}
}