diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 7ee99590..748d2ef3 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -159,15 +159,15 @@ namespace LLama InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); prompt = state.IsPromptRun ? HistoryTransform.HistoryToText(History) - : prompt; + : HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(AuthorRole.User, prompt)); } StringBuilder sb = new(); - await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) + await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) { - yield return result; - sb.Append(result); + yield return textToken; + sb.Append(textToken); } string assistantMessage = sb.ToString(); @@ -180,7 +180,7 @@ namespace LLama { foreach (var stopToken in inferenceParams.AntiPrompts) { - assistantMessage = assistantMessage.Replace(stopToken, ""); + assistantMessage = assistantMessage.Replace(stopToken, "").Trim(); } } @@ -209,27 +209,37 @@ namespace LLama { InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); - prompt = state.IsPromptRun - ? HistoryTransform.HistoryToText(History) - : history.Messages.Last().Content; + if (state.IsPromptRun) + { + prompt = HistoryTransform.HistoryToText(History); + } + else + { + ChatHistory.Message lastMessage = history.Messages.Last(); + prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content)); + } } else { - prompt = history.Messages.Last().Content; + ChatHistory.Message lastMessage = history.Messages.Last(); + prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content)); } - await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) + await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken)) { - yield return result; + yield return textToken; } } private async IAsyncEnumerable ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Console.ForegroundColor = ConsoleColor.Gray; + Console.WriteLine(prompt); + var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken); - await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) + await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken)) { - yield return item; + yield return textToken; } } }