using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;
using Microsoft.Extensions.Logging;
namespace LLama
{
using llama_token = Int32;
/// <summary>
/// This executor infers each input as a one-time job. Previous inputs won't impact the
/// response to the current input.
/// </summary>
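/// <remarks>
/// Example usage (an illustrative sketch, assuming a GGUF model file at "model.gguf"):
/// <code>
/// var parameters = new ModelParams("model.gguf");
/// using var weights = LLamaWeights.LoadFromFile(parameters);
/// var executor = new StatelessExecutor(weights, parameters);
/// var inferenceParams = new InferenceParams { MaxTokens = 64 };
/// await foreach (var text in executor.InferAsync("Question: what is 2 + 2? Answer:", inferenceParams))
///     Console.Write(text);
/// </code>
/// </remarks>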
public class StatelessExecutor
: ILLamaExecutor
{
private readonly LLamaWeights _weights;
private readonly IContextParams _params;
private readonly ILogger? _logger;
/// <summary>
/// The context used by the executor when running the inference.
/// </summary>
public LLamaContext Context { get; private set; }
/// <summary>
/// Create a new stateless executor which will use the given model
/// </summary>
/// <param name="weights">The model weights to run inference with</param>
/// <param name="params">The parameters used to create contexts for inference</param>
/// <param name="logger">An optional logger</param>
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
_weights = weights;
_params = @params;
_logger = logger;
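// Create and immediately dispose a context so the Context property is never null;
// InferAsync creates a fresh context for every call.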
Context = _weights.CreateContext(_params, logger);
Context.Dispose();
}
/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();
// Create an inference context which will be disposed when this method exits
using var context = _weights.CreateContext(_params, _logger);
Context = context;
await foreach (var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
{
yield return item;
}
}
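/// <summary>
/// Run inference for the given prompt on the given context, streaming the generated text piece by piece.
/// </summary>
/// <param name="prompt">The prompt to run inference on</param>
/// <param name="context">The context to run inference with</param>
/// <param name="inferenceParams">The inference parameters (defaults are used if null)</param>
/// <param name="cancellationToken">A token which cancels the generation</param>
/// <returns>The generated text, one decoded token at a time</returns>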
public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Sanity check inference params
inferenceParams ??= new InferenceParams();
if (inferenceParams.TokensKeep > context.ContextSize)
throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");
// Keep track of the last N tokens emitted
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<llama_token>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);
// Tokenize the prompt
var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
lastTokens.AddRange(tokens);
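// n_past tracks how many token positions are occupied in the KV cache; the prompt is
// evaluated starting at position 1 below, so 1 + tokens.Count positions are used.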
var n_past = 1 + tokens.Count;
// Evaluate the prompt
await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);
// Begin loop, evaluating one token at a time
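// Mirostat sampling state (mu); updated by Sample on each iteration when mirostat sampling is enabled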
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
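// A custom sampling pipeline selects the next token directly from the raw logits;
// the recent tokens are passed along so it can apply repetition penalties.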
id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
// Sample a single token
id = context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
}
// Decode this token into text
var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
yield return decoded;
// Check if the generation should stop
if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
break;
lastTokens.Add(id);
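// Only the newly sampled token needs to be evaluated on the next iteration;
// the rest of the sequence is already in the KV cache.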
tokens.Clear();
tokens.Add(id);
// If we have run out of context, discard part of the conversation and shift the KV cache,
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
if (n_past + tokens.Count >= context.ContextSize)
{
var n_left = n_past - inferenceParams.TokensKeep - 1;
var n_discard = n_left / 2;
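// Remove n_discard tokens after the kept prefix from the KV cache, then shift the
// remaining tokens down by n_discard so the sequence stays contiguous.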
NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
}
// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
}
}
}
}