using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Extensions;
using Microsoft.Extensions.Logging;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// The LLama executor for instruct mode.
    /// Wraps every user input in an instruction prefix/suffix pair before evaluation,
    /// in the style of Alpaca-like instruction-tuned models.
    /// </summary>
    public class InstructExecutor : StatefulExecutorBase
    {
        // True until the first (prompt) input has been evaluated.
        private bool _is_prompt_run = true;
        private readonly string _instructionPrefix;
        // Tokenized forms of the instruction prefix/suffix, computed once in the constructor.
        private llama_token[] _inp_pfx;
        private llama_token[] _inp_sfx;

        /// <summary>
        /// Create an executor for instruct mode.
        /// </summary>
        /// <param name="context">The context to evaluate tokens with.</param>
        /// <param name="instructionPrefix">Text inserted before each instruction; also registered as an antiprompt.</param>
        /// <param name="instructionSuffix">Text inserted after each instruction.</param>
        /// <param name="logger">Optional logger passed to the base executor.</param>
        public InstructExecutor(LLamaContext context,
                                string instructionPrefix = "\n\n### Instruction:\n\n",
                                string instructionSuffix = "\n\n### Response:\n\n",
                                ILogger? logger = null)
            : base(context, logger)
        {
            // The prefix starts a new turn, so it is tokenized with BOS; the suffix is not.
            _inp_pfx = Context.Tokenize(instructionPrefix, true);
            _inp_sfx = Context.Tokenize(instructionSuffix, false);
            _instructionPrefix = instructionPrefix;
        }

        /// <inheritdoc />
        public override ExecutorBaseState GetStateData()
        {
            InstructExecutorState state = new()
            {
                ConsumedSessionCount = _n_session_consumed,
                EmbedInps = _embed_inps,
                IsPromptRun = _is_prompt_run,
                ConsumedTokensCount = _consumedTokensCount,
                Embeds = _embeds,
                LastTokens = _last_n_tokens.ToArray(),
                InputPrefixTokens = _inp_pfx,
                InputSuffixTokens = _inp_sfx,
                MatchingSessionTokensCount = _n_matching_session_tokens,
                PastTokensCount = _pastTokensCount,
                SessionFilePath = _pathSession,
                SessionTokens = _session_tokens,
                LastTokensCapacity = _last_n_tokens.Capacity,
                MirostatMu = MirostatMu
            };
            return state;
        }

        /// <inheritdoc />
        public override Task LoadState(ExecutorBaseState data)
        {
            if (data is InstructExecutorState state)
            {
                _n_session_consumed = state.ConsumedSessionCount;
                _embed_inps = state.EmbedInps;
                _is_prompt_run = state.IsPromptRun;
                _consumedTokensCount = state.ConsumedTokensCount;
                _embeds = state.Embeds;
                _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
                _inp_pfx = state.InputPrefixTokens;
                _inp_sfx = state.InputSuffixTokens;
                _n_matching_session_tokens = state.MatchingSessionTokensCount;
                _pastTokensCount = state.PastTokensCount;
                _pathSession = state.SessionFilePath;
                _session_tokens = state.SessionTokens;
            }
            else
            {
                throw new ArgumentException("Invalid state data type.");
            }

            return Task.CompletedTask;
        }

        /// <inheritdoc />
        public override async Task SaveState(string filename)
        {
            var state = (InstructExecutorState)GetStateData();
            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
            {
                await JsonSerializer.SerializeAsync(fs, state);
            }
        }

        /// <inheritdoc />
        public override async Task LoadState(string filename)
        {
            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
            {
                var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
                await LoadState(state);
            }
        }

        /// <inheritdoc />
        protected override Task<bool> GetLoopCondition(InferStateArgs args)
        {
            // Keep looping while tokens remain, or unconditionally on the very first (prompt) run.
            return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
        }

        /// <inheritdoc />
        protected override Task PreprocessInputs(string text, InferStateArgs args)
        {
            args.Antiprompts ??= new List<string>();
            // The instruction prefix doubles as an antiprompt so generation stops before a new turn.
            args.Antiprompts.Add(_instructionPrefix);
            if (_is_prompt_run)
            {
                // When running the first input (prompt) in interactive mode, we should specially process it.
                _embed_inps = Context.Tokenize(text, true).ToList();
            }
            else
            {
                if (!text.EndsWith("\n"))
                {
                    text += "\n";
                }
                _consumedTokensCount = _embed_inps.Count;
                // Surround the user's instruction with the prefix/suffix token sequences.
                _embed_inps.AddRange(_inp_pfx);
                var line_inp = Context.Tokenize(text, false);
                _embed_inps.AddRange(line_inp);
                _embed_inps.AddRange(_inp_sfx);
                args.RemainedTokens -= line_inp.Length;
            }
            return Task.CompletedTask;
        }

        /// <inheritdoc />
        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (_embed_inps.Count <= _consumedTokensCount)
            {
                // Stop and wait for input when an antiprompt appears at the end of recent output.
                if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                {
                    args.WaitForInput = true;
                    return (true, Array.Empty<string>());
                }

                if (_pastTokensCount > 0 && args.WaitForInput)
                {
                    return (true, new[] { "\n> " });
                }
            }

            // End-of-sequence token means the model has finished its response.
            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
            {
                args.WaitForInput = true;
            }

            // Token budget exhausted: refill it and hand control back to the user.
            if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
            {
                args.RemainedTokens = inferenceParams.MaxTokens;
                args.WaitForInput = true;
            }
            return (false, Array.Empty<string>());
        }

        /// <inheritdoc />
        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (_embeds.Count > 0)
            {
                _is_prompt_run = false;
                // Make room in the context window before evaluating if we would overflow it.
                if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                {
                    HandleRunOutOfContext(inferenceParams.TokensKeep);
                }

                TryReuseMathingPrefix();
                _pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

                if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
                {
                    _session_tokens.AddRange(_embeds);
                    _n_session_consumed = _session_tokens.Count;
                }
            }

            _embeds.Clear();

            if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
            {
                var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0
                    ? Context.ContextSize
                    : inferenceParams.RepeatLastTokensCount;

                // optionally save the session on first sample (for faster prompt loading next time)
                if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
                {
                    args.NeedToSaveSession = false;
                    SaveSessionFile(_pathSession);
                }

                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                var mu = MirostatMu;
                var id = Context.Sample(
                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
                    inferenceParams.Grammar, inferenceParams.MinP
                );
                MirostatMu = mu;

                _last_n_tokens.Enqueue(id);
                _embeds.Add(id);

                args.RemainedTokens--;
                args.ReturnValue = true;
            }
            else
            {
                // Feed pending prompt tokens into the batch, up to the configured batch size.
                while (_embed_inps.Count > _consumedTokensCount)
                {
                    _embeds.Add(_embed_inps[_consumedTokensCount]);
                    _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
                    _consumedTokensCount++;
                    if (_embeds.Count >= Context.Params.BatchSize)
                    {
                        break;
                    }
                }
            }

            return Task.CompletedTask;
        }

        /// <summary>
        /// The descriptor of the state of the instruct executor.
        /// </summary>
        public class InstructExecutorState : ExecutorBaseState
        {
            /// <summary>
            /// Whether the executor is running for the first time (running the prompt).
            /// </summary>
            [JsonPropertyName("is_prompt_run")]
            public bool IsPromptRun { get; set; }

            /// <summary>
            /// Instruction prefix tokens.
            /// </summary>
            [JsonPropertyName("inp_pfx")]
            public llama_token[] InputPrefixTokens { get; set; }

            /// <summary>
            /// Instruction suffix tokens.
            /// </summary>
            [JsonPropertyName("inp_sfx")]
            public llama_token[] InputSuffixTokens { get; set; }
        }
    }
}