@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionStripRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
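Every example program follows the same conversion: `Run` becomes `async Task`, and the streaming loop switches from `foreach` over `Chat`/`Infer` to `await foreach` over the new `ChatAsync`/`InferAsync` overloads, which return `IAsyncEnumerable<string>`. A minimal sketch of the calling pattern (the session setup is assumed to be built as in the examples, and the prompt is a placeholder):

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class AsyncChatDemo
{
    // Assumes `session` was constructed as in the examples above
    // (model weights -> context -> InteractiveExecutor -> ChatSession).
    public static async Task StreamReplyAsync(ChatSession session, string prompt)
    {
        var inferenceParams = new InferenceParams
        {
            Temperature = 0.6f,
            AntiPrompts = new List<string> { "User:" }
        };

        // ChatAsync yields tokens as an IAsyncEnumerable<string>,
        // so each piece of text is awaited as it is generated.
        await foreach (var text in session.ChatAsync(prompt, inferenceParams))
        {
            Console.Write(text);
        }
    }
}
```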
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionWithRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
@@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
 {
     public class GrammarJsonResponse
     {
-        public static void Run()
+        public static async Task Run()
         {
-            var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+            var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
             var grammar = Grammar.Parse(gbnf, "root");

             Console.Write("Please input your model path: ");
@@ -43,7 +43,7 @@ namespace LLama.Examples.NewVersion
                 Console.ForegroundColor = ConsoleColor.White;
                 Console.Write("Answer: ");
                 prompt = $"Question: {prompt?.Trim()} Answer: ";
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class InstructModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
             while (true)
             {
-                foreach (var text in executor.Infer(prompt, inferenceParams))
+                await foreach (var text in executor.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class SaveAndLoadSession
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class LoadAndSaveState
     {
-        public static void Run()
+        public static async Task Run()
        {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             while (true)
             {
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class StatelessModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -35,7 +35,7 @@ namespace LLama.Examples.NewVersion
             Console.ForegroundColor = ConsoleColor.White;
             Console.Write("Answer: ");
             prompt = $"Question: {prompt?.Trim()} Answer: ";
-            foreach (var text in ex.Infer(prompt, inferenceParams))
+            await foreach (var text in ex.InferAsync(prompt, inferenceParams))
             {
                 Console.Write(text);
             }
@@ -29,11 +29,11 @@
             if (choice == 0)
             {
-                ChatSessionWithRoleName.Run();
+                await ChatSessionWithRoleName.Run();
             }
             else if (choice == 1)
             {
-                ChatSessionStripRoleName.Run();
+                await ChatSessionStripRoleName.Run();
             }
             else if(choice == 2)
             {
@@ -41,19 +41,19 @@
             }
             else if(choice == 3)
             {
-                InstructModeExecute.Run();
+                await InstructModeExecute.Run();
             }
             else if(choice == 4)
             {
-                StatelessModeExecute.Run();
+                await StatelessModeExecute.Run();
             }
             else if(choice == 5)
             {
-                SaveAndLoadSession.Run();
+                await SaveAndLoadSession.Run();
             }
             else if(choice == 6)
             {
-                LoadAndSaveState.Run();
+                await LoadAndSaveState.Run();
             }
             else if(choice == 7)
             {
@@ -69,7 +69,7 @@
             }
             else if (choice == 10)
             {
-                GrammarJsonResponse.Run();
+                await GrammarJsonResponse.Run();
             }
             else if (choice == 11)
             {
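Awaiting each example's `Run` only compiles if the surrounding menu method is itself asynchronous. The entry point is not shown in this hunk; a hedged sketch of the shape it presumably takes (names and structure are illustrative, not taken from the PR):

```csharp
// Hypothetical entry point shape; the actual menu method is not part of this hunk.
public static async Task Main(string[] args)
{
    Console.Write("Please input a number to choose an example to run: ");
    var choice = int.Parse(Console.ReadLine() ?? "0");

    if (choice == 0)
    {
        await ChatSessionWithRoleName.Run();
    }
    else if (choice == 1)
    {
        await ChatSessionStripRoleName.Run();
    }
    // ... the remaining branches follow the same pattern.
}
```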
@@ -41,7 +41,7 @@ namespace LLama.Unittest
         }

         [Fact]
-        public void SampleWithTrivialGrammar()
+        public async Task SampleWithTrivialGrammar()
         {
             // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
             // we can be confident it's not what the LLM would say if not constrained by the grammar!
@@ -66,7 +66,7 @@ namespace LLama.Unittest
                 Grammar = grammar,
             };

-            var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
+            var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();

             Assert.Equal("cat", result[0]);
         }
@@ -12,6 +12,7 @@
   <ItemGroup>
     <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
+    <PackageReference Include="System.Linq.Async" Version="6.0.1" />
     <PackageReference Include="xunit" Version="2.5.0" />
     <PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
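The new `System.Linq.Async` package supplies the LINQ extensions for `IAsyncEnumerable<T>` (such as `ToListAsync`) that the updated tests rely on now that the executors stream asynchronously. A small sketch of how a streamed response is collected, mirroring the test code (the prompt is a placeholder):

```csharp
using System.Linq;               // System.Linq.Async adds ToListAsync for IAsyncEnumerable<T>
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class AsyncCollectDemo
{
    // Drains the async token stream into a list and joins it into one string,
    // the same pattern the updated unit tests use.
    public static async Task<string> CollectAsync(StatelessExecutor executor, InferenceParams @params, string prompt)
    {
        var tokens = await executor.InferAsync(prompt, @params).ToListAsync();
        return string.Join("", tokens);
    }
}
```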
@@ -27,15 +27,15 @@ namespace LLama.Unittest
         }

         [Fact]
-        public void Stateless()
+        public async Task Stateless()
         {
             var executor = new StatelessExecutor(_weights, _params);

             const string question = "Question. what is a cat?\nAnswer: ";
             var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

             _testOutputHelper.WriteLine(result1);
@@ -44,7 +44,7 @@ namespace LLama.Unittest
         }

         [Fact]
-        public void OutOfContext()
+        public async Task OutOfContext()
         {
             var executor = new StatelessExecutor(_weights, _params);
@@ -58,8 +58,8 @@ namespace LLama.Unittest
                 TokensKeep = question.Length,
             };

-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

             _testOutputHelper.WriteLine(result1);
@@ -18,7 +18,7 @@ namespace LLama.WebAPI.Controllers
         }

         [HttpPost("Send")]
-        public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
+        public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
         {
             return _service.Send(input);
         }
@@ -28,7 +28,7 @@ public class StatefulChatService : IDisposable
        _context?.Dispose();
    }

-    public string Send(SendMessageInput input)
+    public async Task<string> Send(SendMessageInput input)
    {
        var userInput = input.Text;
        if (!_continue)
@@ -42,13 +42,13 @@ public class StatefulChatService : IDisposable
        Console.Write(input.Text);
        Console.ForegroundColor = ConsoleColor.White;

-        var outputs = _session.Chat(userInput, new Common.InferenceParams()
+        var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
        {
            RepeatPenalty = 1.0f,
            AntiPrompts = new string[] { "User:" },
        });

        var result = "";
-        foreach (var output in outputs)
+        await foreach (var output in outputs)
        {
            Console.Write(output);
            result += output;
@@ -13,15 +13,6 @@ namespace LLama.Abstractions
        /// </summary>
        public LLamaContext Context { get; }

-        /// <summary>
-        /// Infers a response from the model.
-        /// </summary>
-        /// <param name="text">Your prompt</param>
-        /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
-        /// <returns></returns>
-        IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
-
        /// <summary>
        /// Asynchronously infers a response from the model.
        /// </summary>
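With the synchronous `Infer` removed from `ILLamaExecutor`, the interface is async-only, so code written against the abstraction consumes `InferAsync` directly. A minimal consumer sketch (the prompt and parameter values are placeholders):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;

public static class ExecutorConsumerDemo
{
    // Works with any ILLamaExecutor implementation now that the interface exposes only InferAsync.
    public static async Task PrintResponseAsync(ILLamaExecutor executor, string prompt, CancellationToken token = default)
    {
        var inferenceParams = new InferenceParams { MaxTokens = 64 };

        await foreach (var text in executor.InferAsync(prompt, inferenceParams, token))
        {
            Console.Write(text);
        }
    }
}
```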
@@ -7,13 +7,6 @@ namespace LLama.Abstractions
    /// </summary>
    public interface ITextStreamTransform
    {
-        /// <summary>
-        /// Takes a stream of tokens and transforms them, returning a new stream of tokens.
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <returns></returns>
-        IEnumerable<string> Transform(IEnumerable<string> tokens);
-
        /// <summary>
        /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
        /// </summary>
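Custom output transforms likewise keep only the asynchronous member. The remaining method is not visible in this hunk; assuming it is `TransformAsync(IAsyncEnumerable<string>)`, as used elsewhere in the library, a custom transform might look like this sketch:

```csharp
using System.Collections.Generic;
using LLama.Abstractions;

// Hedged sketch of a custom transform after the synchronous Transform is removed.
// The member name TransformAsync is assumed; it is not shown in this hunk.
public class UppercaseTransform : ITextStreamTransform
{
    public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
    {
        await foreach (var token in tokens)
        {
            yield return token.ToUpperInvariant();
        }
    }
}
```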
@@ -134,26 +134,6 @@ namespace LLama
             }
         }

-        /// <summary>
-        /// Get the response from the LLama model with chat histories.
-        /// </summary>
-        /// <param name="history"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var prompt = HistoryTransform.HistoryToText(history);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-
         /// <summary>
         /// Get the response from the LLama model. Note that prompt could not only be the preset words,
         /// but also the question you want to ask.
@@ -162,15 +142,14 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach(var inputTransform in InputTransformPipeline)
-            {
                 prompt = inputTransform.Transform(prompt);
-            }
+
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
             StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
+            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
@@ -198,35 +177,6 @@ namespace LLama
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }

-        /// <summary>
-        /// Get the response from the LLama model with chat histories asynchronously.
-        /// </summary>
-        /// <param name="prompt"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var inputTransform in InputTransformPipeline)
-            {
-                prompt = inputTransform.Transform(prompt);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-
-        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
-            return OutputTransform.Transform(results);
-        }
-
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
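Because `ChatAsync` takes its `CancellationToken` via `[EnumeratorCancellation]`, callers can stop a long generation part-way through by flowing a token into the enumeration with `WithCancellation`. A sketch (the 30-second timeout is an arbitrary example value):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;

public static class ChatCancellationDemo
{
    // WithCancellation forwards the token into ChatAsync via [EnumeratorCancellation],
    // so generation stops between tokens once the token is cancelled.
    public static async Task ChatWithTimeoutAsync(ChatSession session, string prompt)
    {
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

        await foreach (var text in session.ChatAsync(prompt).WithCancellation(cts.Token))
        {
            Console.Write(text);
        }
    }
}
```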
@@ -10,6 +10,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text.Json.Serialization;
 using System.Threading;
+using System.Threading.Tasks;

 namespace LLama
 {
@@ -212,47 +213,53 @@ namespace LLama
        /// </summary>
        /// <param name="args"></param>
        /// <returns></returns>
-        protected abstract bool GetLoopCondition(InferStateArgs args);
+        protected abstract Task<bool> GetLoopCondition(InferStateArgs args);

        /// <summary>
        /// Preprocess the inputs before the inference.
        /// </summary>
        /// <param name="text"></param>
        /// <param name="args"></param>
-        protected abstract void PreprocessInputs(string text, InferStateArgs args);
+        protected abstract Task PreprocessInputs(string text, InferStateArgs args);

        /// <summary>
        /// Do some post processing after the inference.
        /// </summary>
        /// <param name="inferenceParams"></param>
        /// <param name="args"></param>
-        /// <param name="extraOutputs"></param>
        /// <returns></returns>
-        protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
+        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);

        /// <summary>
        /// The core inference logic.
        /// </summary>
        /// <param name="inferenceParams"></param>
        /// <param name="args"></param>
-        protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);

        /// <summary>
        /// Save the current state to a file.
        /// </summary>
        /// <param name="filename"></param>
-        public abstract void SaveState(string filename);
+        public abstract Task SaveState(string filename);

        /// <summary>
        /// Get the current state data.
        /// </summary>
        /// <returns></returns>
        public abstract ExecutorBaseState GetStateData();

        /// <summary>
        /// Load the state from data.
        /// </summary>
        /// <param name="data"></param>
-        public abstract void LoadState(ExecutorBaseState data);
+        public abstract Task LoadState(ExecutorBaseState data);

        /// <summary>
        /// Load the state from a file.
        /// </summary>
        /// <param name="filename"></param>
-        public abstract void LoadState(string filename);
+        public abstract Task LoadState(string filename);

        /// <summary>
@@ -262,12 +269,12 @@ namespace LLama
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns></returns>
-        public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            cancellationToken.ThrowIfCancellationRequested();
            inferenceParams ??= new InferenceParams();

-            InferStateArgs args = new InferStateArgs()
+            var args = new InferStateArgs
            {
                Antiprompts = inferenceParams.AntiPrompts.ToList(),
                RemainedTokens = inferenceParams.MaxTokens,
@@ -276,15 +283,15 @@ namespace LLama
                NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
            };

-            PreprocessInputs(text, args);
+            await PreprocessInputs(text, args);

-            while (GetLoopCondition(args))
+            while (await GetLoopCondition(args))
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    break;
                }
-                InferInternal(inferenceParams, args);
+                await InferInternal(inferenceParams, args);

                if (args.ReturnValue)
                {
@@ -292,8 +299,8 @@ namespace LLama
                    yield return Context.TokenToString(id);
                }

-                var breakGeneration = PostProcess(inferenceParams, args, out var extraOutputs);
-                if (extraOutputs is not null)
+                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                if (extraOutputs is { Count: > 0 })
                {
                    foreach (var item in extraOutputs)
                    {
@@ -307,21 +314,6 @@ namespace LLama
            }
        }

-        /// <summary>
-        /// Execute the inference asynchronously.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
-
        /// <summary>
        /// State arguments that are used in single inference
        /// </summary>
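Two details of this base-class change are worth noting. First, C# does not allow `ref`/`out` parameters on `async` methods, which is presumably why `PostProcess` now returns a `(bool, IReadOnlyList<string>)` tuple instead of passing the extra outputs through an `out` parameter. Second, overrides that have no real asynchronous work simply wrap their results with `Task.FromResult` or `Task.CompletedTask`, as the executor subclasses below do. A self-contained illustration of the `out`-to-tuple move:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public static class OutVersusTupleDemo
{
    // The old shape: a bool result plus extra outputs via `out`.
    // This cannot become an async method, because async methods may not have out parameters.
    public static bool PostProcessSync(out IEnumerable<string>? extraOutputs)
    {
        extraOutputs = new[] { "\n> " };
        return true;
    }

    // The async-friendly replacement carries both values in a tuple,
    // matching the new PostProcess signature in the executor base class.
    public static Task<(bool breakGeneration, IReadOnlyList<string> extraOutputs)> PostProcessAsync()
    {
        return Task.FromResult<(bool, IReadOnlyList<string>)>((true, new[] { "\n> " }));
    }
}
```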
@@ -5,9 +5,9 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading.Tasks;
 using LLama.Extensions;

 namespace LLama
@@ -60,7 +60,7 @@ namespace LLama
            return state;
        }
        /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
        {
            if(data is InstructExecutorState state)
            {
@@ -81,34 +81,37 @@ namespace LLama
            {
                throw new ArgumentException("Invalid state data type.");
            }
+            return Task.CompletedTask;
        }
        /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
        {
            var state = (InstructExecutorState)GetStateData();
            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
            {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
            }
        }
        /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
        {
            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
            {
-                var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
+                await LoadState(state);
            }
        }

        /// <inheritdoc />
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
        {
-            return args.RemainedTokens != 0 || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
        }

        /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
        {
            args.Antiprompts ??= new List<string>();
            args.Antiprompts.Add(_instructionPrefix);
@@ -133,23 +136,24 @@ namespace LLama
                args.RemainedTokens -= line_inp.Length;
            }
+            return Task.CompletedTask;
        }

        /// <inheritdoc />
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
        {
-            extraOutputs = null;
            if (_embed_inps.Count <= _consumedTokensCount)
            {
                if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                {
                    args.WaitForInput = true;
-                    return true;
+                    return (true, Array.Empty<string>());
                }

                if (_pastTokensCount > 0 && args.WaitForInput)
                {
-                    extraOutputs = new[] { "\n> " };
-                    return true;
+                    return (true, new[] { "\n> " });
                }
            }
@@ -163,10 +167,11 @@ namespace LLama
                args.RemainedTokens = inferenceParams.MaxTokens;
                args.WaitForInput = true;
            }
-            return false;
+            return (false, Array.Empty<string>());
        }

        /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (_embeds.Count > 0)
            {
@@ -230,6 +235,8 @@ namespace LLama
                    }
                }
            }
+
+            return Task.CompletedTask;
        }
        /// <summary>
        /// The desciptor of the state of the instruct executor.
@@ -7,7 +7,7 @@ using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
-using System.Text;
+using System.Threading.Tasks;
 using LLama.Extensions;

 namespace LLama
@@ -51,7 +51,7 @@ namespace LLama
            return state;
        }
        /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
        {
            if (data is InteractiveExecutorState state)
            {
@@ -68,23 +68,25 @@ namespace LLama
            }
            else
                throw new ArgumentException("Invalid state data type.");
+            return Task.CompletedTask;
        }
        /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
        {
-            InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
-            using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+            var state = (InteractiveExecutorState)GetStateData();
+            using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
            {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
            }
        }
        /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
        {
-            using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
+            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
            {
-                var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
+                await LoadState(state);
            }
        }
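Since `SaveState` and `LoadState` now return `Task` and serialize with `JsonSerializer.SerializeAsync`/`DeserializeAsync`, callers await them. A small sketch (the file name is a placeholder):

```csharp
using System.Threading.Tasks;
using LLama;

public static class ExecutorStateDemo
{
    // SaveState/LoadState are now awaitable; the executor state round-trips through a JSON file.
    public static async Task RoundTripStateAsync(InteractiveExecutor executor)
    {
        await executor.SaveState("executor_state.json");
        await executor.LoadState("executor_state.json");
    }
}
```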
@@ -92,13 +94,13 @@ namespace LLama
        /// Define whether to continue the loop to generate responses.
        /// </summary>
        /// <returns></returns>
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
        {
-            return args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
        }

        /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
        {
            if (_is_prompt_run)
            {
@@ -115,6 +117,8 @@ namespace LLama
                _embed_inps.AddRange(line_inp);
                args.RemainedTokens -= line_inp.Length;
            }
+
+            return Task.CompletedTask;
        }

        /// <summary>
@@ -122,24 +126,21 @@ namespace LLama
        /// </summary>
        /// <param name="inferenceParams"></param>
        /// <param name="args"></param>
-        /// <param name="extraOutputs"></param>
        /// <returns></returns>
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
        {
-            extraOutputs = null;
            if (_embed_inps.Count <= _consumedTokensCount)
            {
                if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                    args.WaitForInput = true;

                if (_pastTokensCount > 0 && args.WaitForInput)
-                    return true;
+                    return (true, Array.Empty<string>());
            }

            if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
            {
-                extraOutputs = new[] { " [end of text]\n" };
-                return true;
+                return (true, new[] { " [end of text]\n" });
            }

            if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -147,11 +148,12 @@ namespace LLama
                args.RemainedTokens = inferenceParams.MaxTokens;
                args.WaitForInput = true;
            }
-            return false;
+            return (false, Array.Empty<string>());
        }

        /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (_embeds.Count > 0)
            {
@@ -55,7 +55,7 @@ namespace LLama
        }

        /// <inheritdoc />
-        public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            using var context = _weights.CreateContext(_params);
            Context = context;
@@ -140,14 +140,5 @@ namespace LLama
            {
                return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
            }
-
-        /// <inheritdoc />
-        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
    }
 }