| @@ -1,11 +1,14 @@ | |||
| using LLama.Abstractions; | |||
| using LLama.Common; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| using static LLama.InteractiveExecutor; | |||
| namespace LLama | |||
| { | |||
| @@ -151,11 +154,17 @@ namespace LLama | |||
| History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt)); | |||
| string internalPrompt = HistoryTransform.HistoryToText(History); | |||
| if (_executor is InteractiveExecutor executor) | |||
| { | |||
| InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); | |||
| prompt = state.IsPromptRun | |||
| ? HistoryTransform.HistoryToText(History) | |||
| : prompt; | |||
| } | |||
| StringBuilder sb = new(); | |||
| await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken)) | |||
| await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) | |||
| { | |||
| yield return result; | |||
| sb.Append(result); | |||
| @@ -190,14 +199,28 @@ namespace LLama | |||
| /// <returns></returns> | |||
| public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) | |||
| { | |||
| var prompt = HistoryTransform.HistoryToText(history); | |||
| if (history.Messages.Count == 0) | |||
| { | |||
| throw new ArgumentException("History must contain at least one message."); | |||
| } | |||
| StringBuilder sb = new(); | |||
| string prompt; | |||
| if (_executor is InteractiveExecutor executor) | |||
| { | |||
| InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); | |||
| prompt = state.IsPromptRun | |||
| ? HistoryTransform.HistoryToText(History) | |||
| : history.Messages.Last().Content; | |||
| } | |||
| else | |||
| { | |||
| prompt = history.Messages.Last().Content; | |||
| } | |||
| await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) | |||
| { | |||
| yield return result; | |||
| sb.Append(result); | |||
| } | |||
| } | |||