@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionStripRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionWithRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
@@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
 {
     public class GrammarJsonResponse
     {
-        public static void Run()
+        public static async Task Run()
         {
-            var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+            var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
             var grammar = Grammar.Parse(gbnf, "root");
             Console.Write("Please input your model path: ");
@@ -43,7 +43,7 @@ namespace LLama.Examples.NewVersion
             Console.ForegroundColor = ConsoleColor.White;
             Console.Write("Answer: ");
             prompt = $"Question: {prompt?.Trim()} Answer: ";
-            foreach (var text in ex.Infer(prompt, inferenceParams))
+            await foreach (var text in ex.InferAsync(prompt, inferenceParams))
             {
                 Console.Write(text);
             }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class InstructModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
             while (true)
             {
-                foreach (var text in executor.Infer(prompt, inferenceParams))
+                await foreach (var text in executor.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class SaveAndLoadSession
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class LoadAndSaveState
    {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             while (true)
             {
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class StatelessModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -35,7 +35,7 @@ namespace LLama.Examples.NewVersion
             Console.ForegroundColor = ConsoleColor.White;
             Console.Write("Answer: ");
             prompt = $"Question: {prompt?.Trim()} Answer: ";
-            foreach (var text in ex.Infer(prompt, inferenceParams))
+            await foreach (var text in ex.InferAsync(prompt, inferenceParams))
             {
                 Console.Write(text);
             }
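Every example above follows the same mechanical migration: void Run() becomes async Task Run(), and the blocking foreach over Infer(...)/Chat(...) becomes await foreach over InferAsync(...)/ChatAsync(...). A minimal caller-side sketch of the new shape, assuming a ChatSession built as in the examples (its construction lies outside the hunks shown here):

    using System;
    using System.Collections.Generic;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Common;

    static class ChatLoopSketch
    {
        // Streams one round of chat; ChatAsync yields tokens as an IAsyncEnumerable<string>.
        public static async Task ChatOnce(ChatSession session, string prompt)
        {
            var inferenceParams = new InferenceParams { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } };
            await foreach (var text in session.ChatAsync(prompt, inferenceParams))
            {
                Console.Write(text);
            }
        }
    }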
@@ -29,11 +29,11 @@
             if (choice == 0)
             {
-                ChatSessionWithRoleName.Run();
+                await ChatSessionWithRoleName.Run();
             }
             else if (choice == 1)
             {
-                ChatSessionStripRoleName.Run();
+                await ChatSessionStripRoleName.Run();
             }
             else if(choice == 2)
             {
@@ -41,19 +41,19 @@
             }
             else if(choice == 3)
             {
-                InstructModeExecute.Run();
+                await InstructModeExecute.Run();
             }
             else if(choice == 4)
             {
-                StatelessModeExecute.Run();
+                await StatelessModeExecute.Run();
             }
             else if(choice == 5)
             {
-                SaveAndLoadSession.Run();
+                await SaveAndLoadSession.Run();
             }
             else if(choice == 6)
             {
-                LoadAndSaveState.Run();
+                await LoadAndSaveState.Run();
             }
             else if(choice == 7)
             {
@@ -69,7 +69,7 @@
             }
             else if (choice == 10)
             {
-                GrammarJsonResponse.Run();
+                await GrammarJsonResponse.Run();
             }
             else if (choice == 11)
             {
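Because every example's Run() now returns a Task, the dispatching code has to await it, which in turn means the menu method (and ultimately Main) must be async as well. A sketch of that calling convention, with hypothetical method and variable names (only the example class names come from this diff):

    using System;
    using System.Threading.Tasks;
    using LLama.Examples.NewVersion;

    static class MenuSketch
    {
        public static async Task Main(string[] args)
        {
            Console.Write("Please input a number to choose an example: ");
            var choice = int.Parse(Console.ReadLine() ?? "0");

            // Each branch awaits the example so exceptions propagate and the
            // process does not exit before generation finishes.
            if (choice == 0)
                await ChatSessionWithRoleName.Run();
            else if (choice == 1)
                await ChatSessionStripRoleName.Run();
        }
    }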
@@ -41,7 +41,7 @@ namespace LLama.Unittest
         }
         [Fact]
-        public void SampleWithTrivialGrammar()
+        public async Task SampleWithTrivialGrammar()
         {
             // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
             // we can be confident it's not what the LLM would say if not constrained by the grammar!
@@ -66,7 +66,7 @@ namespace LLama.Unittest
                 Grammar = grammar,
             };
-            var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
+            var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();
             Assert.Equal("cat", result[0]);
         }
@@ -12,6 +12,7 @@
   <ItemGroup>
     <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
+    <PackageReference Include="System.Linq.Async" Version="6.0.1" />
     <PackageReference Include="xunit" Version="2.5.0" />
     <PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
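The new System.Linq.Async package reference supplies the ToListAsync() extension used by the updated tests: it drains an IAsyncEnumerable<T> into a List<T> so the result can be inspected synchronously afterwards. A small self-contained sketch of that pattern:

    using System.Collections.Generic;
    using System.Linq;                 // AsyncEnumerable extensions come from System.Linq.Async
    using System.Threading.Tasks;

    static class AsyncLinqSketch
    {
        // Collects a streamed token sequence into a list, as the tests do with InferAsync(...).ToListAsync().
        public static async Task<List<string>> CollectAsync(IAsyncEnumerable<string> tokens)
            => await tokens.ToListAsync();
    }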
@@ -27,15 +27,15 @@ namespace LLama.Unittest
         }
         [Fact]
-        public void Stateless()
+        public async Task Stateless()
         {
             var executor = new StatelessExecutor(_weights, _params);
             const string question = "Question. what is a cat?\nAnswer: ";
             var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
             _testOutputHelper.WriteLine(result1);
@@ -44,7 +44,7 @@ namespace LLama.Unittest
         }
         [Fact]
-        public void OutOfContext()
+        public async Task OutOfContext()
         {
             var executor = new StatelessExecutor(_weights, _params);
@@ -58,8 +58,8 @@ namespace LLama.Unittest
                 TokensKeep = question.Length,
             };
-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
             _testOutputHelper.WriteLine(result1);
@@ -18,7 +18,7 @@ namespace LLama.WebAPI.Controllers
         }
         [HttpPost("Send")]
-        public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
+        public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
         {
             return _service.Send(input);
         }
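Returning Task<string> is enough for ASP.NET Core to await the now-asynchronous service call. An equivalent, slightly more explicit form of the same action would await the service inside an async method (a sketch, not part of the diff):

    [HttpPost("Send")]
    public async Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
    {
        // Awaiting keeps the request thread free while the model streams its answer.
        return await _service.Send(input);
    }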
@@ -28,7 +28,7 @@ public class StatefulChatService : IDisposable
         _context?.Dispose();
     }
-    public string Send(SendMessageInput input)
+    public async Task<string> Send(SendMessageInput input)
     {
         var userInput = input.Text;
         if (!_continue)
@@ -42,13 +42,13 @@ public class StatefulChatService : IDisposable
         Console.Write(input.Text);
         Console.ForegroundColor = ConsoleColor.White;
-        var outputs = _session.Chat(userInput, new Common.InferenceParams()
+        var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
         {
             RepeatPenalty = 1.0f,
             AntiPrompts = new string[] { "User:" },
         });
         var result = "";
-        foreach (var output in outputs)
+        await foreach (var output in outputs)
         {
             Console.Write(output);
             result += output;
@@ -13,15 +13,6 @@ namespace LLama.Abstractions
         /// </summary>
         public LLamaContext Context { get; }
-        /// <summary>
-        /// Infers a response from the model.
-        /// </summary>
-        /// <param name="text">Your prompt</param>
-        /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
-        /// <returns></returns>
-        IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
         /// <summary>
         /// Asynchronously infers a response from the model.
         /// </summary>
@@ -7,13 +7,6 @@ namespace LLama.Abstractions
     /// </summary>
     public interface ITextStreamTransform
     {
-        /// <summary>
-        /// Takes a stream of tokens and transforms them, returning a new stream of tokens.
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <returns></returns>
-        IEnumerable<string> Transform(IEnumerable<string> tokens);
         /// <summary>
         /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
         /// </summary>
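With the synchronous Transform removed, a text-stream transform only has to implement the asynchronous member. A minimal pass-through implementation, assuming the surviving member is IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) (its exact signature is not shown in this hunk):

    using System.Collections.Generic;
    using LLama.Abstractions;

    // Hypothetical no-op transform illustrating the async-only interface.
    public class PassthroughTransform : ITextStreamTransform
    {
        public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
        {
            await foreach (var token in tokens)
                yield return token;   // forward every token unchanged
        }
    }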
@@ -134,26 +134,6 @@ namespace LLama
             }
         }
-        /// <summary>
-        /// Get the response from the LLama model with chat histories.
-        /// </summary>
-        /// <param name="history"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var prompt = HistoryTransform.HistoryToText(history);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
         /// <summary>
         /// Get the response from the LLama model. Note that prompt could not only be the preset words,
         /// but also the question you want to ask.
@@ -162,15 +142,14 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach(var inputTransform in InputTransformPipeline)
             {
                 prompt = inputTransform.Transform(prompt);
             }
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
             StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
+            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
@@ -198,35 +177,6 @@ namespace LLama
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }
-        /// <summary>
-        /// Get the response from the LLama model with chat histories asynchronously.
-        /// </summary>
-        /// <param name="prompt"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var inputTransform in InputTransformPipeline)
-            {
-                prompt = inputTransform.Transform(prompt);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
-            return OutputTransform.Transform(results);
-        }
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
@@ -10,6 +10,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text.Json.Serialization;
 using System.Threading;
+using System.Threading.Tasks;
 namespace LLama
 {
@@ -212,47 +213,53 @@ namespace LLama
         /// </summary>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected abstract bool GetLoopCondition(InferStateArgs args);
+        protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
         /// <summary>
         /// Preprocess the inputs before the inference.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="args"></param>
-        protected abstract void PreprocessInputs(string text, InferStateArgs args);
+        protected abstract Task PreprocessInputs(string text, InferStateArgs args);
         /// <summary>
         /// Do some post processing after the inference.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
         /// <param name="extraOutputs"></param>
         /// <returns></returns>
-        protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
+        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract void SaveState(string filename);
+        public abstract Task SaveState(string filename);
         /// <summary>
         /// Get the current state data.
         /// </summary>
         /// <returns></returns>
         public abstract ExecutorBaseState GetStateData();
         /// <summary>
         /// Load the state from data.
         /// </summary>
         /// <param name="data"></param>
-        public abstract void LoadState(ExecutorBaseState data);
+        public abstract Task LoadState(ExecutorBaseState data);
         /// <summary>
         /// Load the state from a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract void LoadState(string filename);
+        public abstract Task LoadState(string filename);
         /// <summary>
@@ -262,12 +269,12 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             inferenceParams ??= new InferenceParams();
-            InferStateArgs args = new InferStateArgs()
+            var args = new InferStateArgs
             {
                 Antiprompts = inferenceParams.AntiPrompts.ToList(),
                 RemainedTokens = inferenceParams.MaxTokens,
@@ -276,15 +283,15 @@ namespace LLama
                 NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
             };
-            PreprocessInputs(text, args);
+            await PreprocessInputs(text, args);
-            while (GetLoopCondition(args))
+            while (await GetLoopCondition(args))
             {
                 if (cancellationToken.IsCancellationRequested)
                 {
                     break;
                 }
-                InferInternal(inferenceParams, args);
+                await InferInternal(inferenceParams, args);
                 if (args.ReturnValue)
                 {
@@ -292,8 +299,8 @@ namespace LLama
                         yield return Context.TokenToString(id);
                 }
-                var breakGeneration = PostProcess(inferenceParams, args, out var extraOutputs);
-                if (extraOutputs is not null)
+                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                if (extraOutputs is { Count: > 0 })
                 {
                     foreach (var item in extraOutputs)
                     {
@@ -307,21 +314,6 @@ namespace LLama
             }
         }
-        /// <summary>
-        /// Execute the inference asynchronously.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
         /// <summary>
         /// State arguments that are used in single inference
         /// </summary>
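Most of the executor internals have no real asynchronous work, so the overrides below satisfy the new Task-returning signatures by wrapping ready results: Task.FromResult for values, Task.CompletedTask when there is nothing to return, and a tuple in place of the old out parameter on PostProcess. An illustrative summary of those three wrappings (a standalone sketch, not code from this diff):

    using System;
    using System.Collections.Generic;
    using System.Threading.Tasks;

    static class TaskWrappingSketch
    {
        // GetLoopCondition-style member: wrap an already-computed bool.
        public static Task<bool> LoopCondition(bool keepGoing) => Task.FromResult(keepGoing);

        // PreprocessInputs/InferInternal-style member: nothing awaited, so return a completed task.
        public static Task Preprocess(string text) => Task.CompletedTask;

        // PostProcess-style member: the old out parameter becomes part of the tuple result.
        public static Task<(bool breakGeneration, IReadOnlyList<string> extraOutputs)> PostProcess(bool stop)
            => Task.FromResult<(bool, IReadOnlyList<string>)>((stop, Array.Empty<string>()));
    }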
@@ -5,9 +5,9 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading.Tasks;
 using LLama.Extensions;
 namespace LLama
@@ -60,7 +60,7 @@ namespace LLama
             return state;
         }
         /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
         {
             if(data is InstructExecutorState state)
             {
@@ -81,34 +81,37 @@ namespace LLama
             {
                 throw new ArgumentException("Invalid state data type.");
             }
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
         {
             var state = (InstructExecutorState)GetStateData();
             using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
             }
         }
         /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
-                var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
+                await LoadState(state);
             }
         }
         /// <inheritdoc />
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
         {
-            return args.RemainedTokens != 0 || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
         }
         /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
         {
             args.Antiprompts ??= new List<string>();
             args.Antiprompts.Add(_instructionPrefix);
@@ -133,23 +136,24 @@ namespace LLama
                 args.RemainedTokens -= line_inp.Length;
             }
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
         {
-            extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
                     args.WaitForInput = true;
-                    return true;
+                    return (true, Array.Empty<string>());
                 }
                 if (_pastTokensCount > 0 && args.WaitForInput)
                 {
-                    extraOutputs = new[] { "\n> " };
-                    return true;
+                    return (true, new[] { "\n> " });
                 }
             }
@@ -163,10 +167,11 @@ namespace LLama
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return false;
+            return (false, Array.Empty<string>());
         }
         /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -230,6 +235,8 @@ namespace LLama
                     }
                 }
             }
+            return Task.CompletedTask;
         }
         /// <summary>
         /// The desciptor of the state of the instruct executor.
@@ -7,7 +7,7 @@ using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
-using System.Text;
+using System.Threading.Tasks;
 using LLama.Extensions;
 namespace LLama
@@ -51,7 +51,7 @@ namespace LLama
             return state;
         }
         /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
         {
             if (data is InteractiveExecutorState state)
             {
@@ -68,23 +68,25 @@ namespace LLama
             }
             else
                 throw new ArgumentException("Invalid state data type.");
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
         {
-            InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
-            using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+            var state = (InteractiveExecutorState)GetStateData();
+            using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
             }
         }
         /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
         {
-            using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
+            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
-                var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
+                await LoadState(state);
             }
         }
@@ -92,13 +94,13 @@ namespace LLama
         /// Define whether to continue the loop to generate responses.
         /// </summary>
         /// <returns></returns>
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
         {
-            return args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
         }
         /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
        {
             if (_is_prompt_run)
             {
@@ -115,6 +117,8 @@ namespace LLama
                 _embed_inps.AddRange(line_inp);
                 args.RemainedTokens -= line_inp.Length;
             }
+            return Task.CompletedTask;
         }
         /// <summary>
@@ -122,24 +126,21 @@ namespace LLama
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        /// <param name="extraOutputs"></param>
         /// <returns></returns>
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
         {
-            extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                     args.WaitForInput = true;
                 if (_pastTokensCount > 0 && args.WaitForInput)
-                    return true;
+                    return (true, Array.Empty<string>());
             }
             if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
             {
-                extraOutputs = new[] { " [end of text]\n" };
-                return true;
+                return (true, new[] { " [end of text]\n" });
             }
             if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -147,11 +148,12 @@ namespace LLama
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return false;
+            return (false, Array.Empty<string>());
         }
         /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -55,7 +55,7 @@ namespace LLama
         }
         /// <inheritdoc />
-        public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             using var context = _weights.CreateContext(_params);
             Context = context;
@@ -140,14 +140,5 @@ namespace LLama
         {
             return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
         }
-        /// <inheritdoc />
-        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
     }
 }
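With the synchronous Infer gone, the stateless executor is consumed exactly like the others. A caller-side sketch matching the updated unit tests (the weights and parameters are assumed to be loaded elsewhere):

    using System;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Common;

    static class StatelessConsumerSketch
    {
        public static async Task AskAsync(LLamaWeights weights, ModelParams parameters)
        {
            var executor = new StatelessExecutor(weights, parameters);
            var inferenceParams = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

            // Tokens arrive as they are generated; InferAsync returns IAsyncEnumerable<string>.
            await foreach (var token in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
                Console.Write(token);
        }
    }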