using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using LLama.Extensions;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// This executor infers each input as a one-time job. Previous inputs won't impact the
    /// response to the current input.
    /// </summary>
    public class StatelessExecutor : ILLamaExecutor
    {
        private readonly LLamaWeights _weights;
        private readonly IModelParams _params;

        /// <summary>
        /// The context used by the executor when running the inference.
        /// </summary>
        public LLamaContext Context { get; private set; }

        /// <summary>
        /// Create a new stateless executor which will use the given model
        /// </summary>
        /// <param name="weights"></param>
        /// <param name="params"></param>
        public StatelessExecutor(LLamaWeights weights, IModelParams @params)
        {
            _weights = weights;
            _params = @params;

            // Create and immediately dispose a context, so the Context property is never null.
            // A fresh context is created for every call to InferAsync.
            Context = _weights.CreateContext(_params);
            Context.Dispose();
        }

        /// <summary>
        /// Create a new stateless executor which will use the model used to create the given context
        /// </summary>
        /// <param name="context"></param>
        [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
        public StatelessExecutor(LLamaContext context)
        {
            _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
            _params = context.Params;

            Context = _weights.CreateContext(_params);
            Context.Dispose();
        }

        /// <inheritdoc />
        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Dispose the context left over from the previous call, if it was not cleaned up already
            if (!Context.NativeHandle.IsClosed)
                Context.Dispose();

            // Create a fresh context for this inference, disposed automatically when this method completes
            using var context = _weights.CreateContext(_params);
            Context = context;

            if (inferenceParams != null)
            {
                if (inferenceParams.TokensKeep > Context.ContextSize)
                    throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
            }

            cancellationToken.ThrowIfCancellationRequested();

            var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
            var n_past = 1;
            inferenceParams ??= new InferenceParams();

            var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
            for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
                lastTokens.Add(0);

            // Evaluate the prompt
            var tokens = Context.Tokenize(text).ToList();
            var n_prompt_tokens = tokens.Count;

            Context.Eval(tokens, n_past);

            lastTokens.AddRange(tokens);
            n_past += n_prompt_tokens;

            var mu = (float?)null;
            var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for (var i = 0; i < max_tokens; i++)
            {
                if (cancellationToken.IsCancellationRequested)
                    break;

                var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0
                    ? Context.ContextSize
                    : inferenceParams.RepeatLastTokensCount;

                // Penalize repeated tokens, then sample the next token
                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);

                lastTokens.Add(id);

                var response = Context.TokenToString(id);
                yield return response;

                tokens.Clear();
                tokens.Add(id);

                if (EndsWithAntiprompt(lastTokens, antiprompts))
                    break;

                // When we run out of context, keep the first TokensKeep tokens and re-evaluate the
                // most recent half of the remaining tokens, based on this logic:
                // https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
                if (n_past + tokens.Count > Context.ContextSize)
                {
                    var n_left = n_past - inferenceParams.TokensKeep;

                    n_past = Math.Max(1, inferenceParams.TokensKeep);

                    tokens.Clear();
                    tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
                }

                n_past = Context.Eval(tokens, n_past);
            }
        }

        /// <summary>
        /// Check if the given tokens list ends with any of the antiprompts
        /// </summary>
        /// <param name="tokens"></param>
        /// <param name="antiprompts"></param>
        /// <returns></returns>
        private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts)
        {
            return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
        }
    }
}
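
// Usage sketch (illustrative, not part of this file): one way to drive the stateless
// executor end-to-end. The model path "model.gguf" and all parameter values below are
// assumptions for the example, not values this library prescribes.
//
//     var @params = new ModelParams("model.gguf") { ContextSize = 1024 };
//     using var weights = LLamaWeights.LoadFromFile(@params);
//     var executor = new StatelessExecutor(weights, @params);
//
//     var inferenceParams = new InferenceParams { MaxTokens = 64, AntiPrompts = new[] { "Question:" } };
//     await foreach (var token in executor.InferAsync("Question: What is 2 + 2?\nAnswer:", inferenceParams))
//         Console.Write(token);
//
// Because the executor is stateless, each InferAsync call evaluates its prompt from
// scratch in a fresh context; nothing carries over between calls.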