| @@ -1,11 +1,14 @@ | |||||
| using LLama.Abstractions; | using LLama.Abstractions; | ||||
| using LLama.Common; | using LLama.Common; | ||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading; | using System.Threading; | ||||
| using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
| using static LLama.InteractiveExecutor; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| @@ -151,11 +154,17 @@ namespace LLama | |||||
| History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt)); | History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt)); | ||||
| string internalPrompt = HistoryTransform.HistoryToText(History); | |||||
| if (_executor is InteractiveExecutor executor) | |||||
| { | |||||
| InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); | |||||
| prompt = state.IsPromptRun | |||||
| ? HistoryTransform.HistoryToText(History) | |||||
| : prompt; | |||||
| } | |||||
| StringBuilder sb = new(); | StringBuilder sb = new(); | ||||
| await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken)) | |||||
| await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) | |||||
| { | { | ||||
| yield return result; | yield return result; | ||||
| sb.Append(result); | sb.Append(result); | ||||
| @@ -190,14 +199,28 @@ namespace LLama | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | ||||
| { | { | ||||
| var prompt = HistoryTransform.HistoryToText(history); | |||||
| if (history.Messages.Count == 0) | |||||
| { | |||||
| throw new ArgumentException("History must contain at least one message."); | |||||
| } | |||||
| StringBuilder sb = new(); | |||||
| string prompt; | |||||
| if (_executor is InteractiveExecutor executor) | |||||
| { | |||||
| InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); | |||||
| prompt = state.IsPromptRun | |||||
| ? HistoryTransform.HistoryToText(History) | |||||
| : history.Messages.Last().Content; | |||||
| } | |||||
| else | |||||
| { | |||||
| prompt = history.Messages.Last().Content; | |||||
| } | |||||
| await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) | await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) | ||||
| { | { | ||||
| yield return result; | yield return result; | ||||
| sb.Append(result); | |||||
| } | } | ||||
| } | } | ||||