diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 358d70c3..68c3c093 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -1,11 +1,14 @@
using LLama.Abstractions;
using LLama.Common;
+using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
+using static LLama.InteractiveExecutor;
namespace LLama
{
@@ -151,11 +154,17 @@ namespace LLama
History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
- string internalPrompt = HistoryTransform.HistoryToText(History);
+ if (_executor is InteractiveExecutor executor)
+ {
+ InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+ prompt = state.IsPromptRun
+ ? HistoryTransform.HistoryToText(History)
+ : prompt;
+ }
StringBuilder sb = new();
- await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
+ await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
@@ -190,14 +199,28 @@ namespace LLama
/// <returns></returns>
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- var prompt = HistoryTransform.HistoryToText(history);
+ if (history.Messages.Count == 0)
+ {
+ throw new ArgumentException("History must contain at least one message.");
+ }
- StringBuilder sb = new();
+ string prompt;
+ if (_executor is InteractiveExecutor executor)
+ {
+ InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+
+ prompt = state.IsPromptRun
+ ? HistoryTransform.HistoryToText(History)
+ : history.Messages.Last().Content;
+ }
+ else
+ {
+ prompt = history.Messages.Last().Content;
+ }
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
- sb.Append(result);
}
}
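
Usage sketch (not part of the diff): a minimal example of how the reworked ChatAsync overloads are expected to behave when ChatSession wraps a stateful InteractiveExecutor, assuming the LLamaSharp API of this era (ModelParams, LLamaWeights.LoadFromFile, InteractiveExecutor, InferenceParams); the model path, anti-prompt, and parameter values are placeholders and exact property names may differ between versions. On the first call the executor's IsPromptRun is true, so the full transformed history is evaluated; on later calls only the newest message is fed in, because the executor already holds the earlier context.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class ChatSessionSketch
{
    public static async Task RunAsync()
    {
        // Placeholder path; any GGUF model supported by LLamaSharp will do.
        var parameters = new ModelParams("models/model.gguf") { ContextSize = 1024 };
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);

        // ChatSession wraps a stateful InteractiveExecutor, so the evaluated
        // context survives across calls and only new tokens are processed.
        var executor = new InteractiveExecutor(context);
        var session = new ChatSession(executor);

        var inferenceParams = new InferenceParams
        {
            Temperature = 0.6f,
            AntiPrompts = new List<string> { "User:" }
        };

        // First call: IsPromptRun is true, so the whole (transformed) history
        // is sent to the executor.
        await foreach (var token in session.ChatAsync("Hello, who are you?", inferenceParams))
        {
            Console.Write(token);
        }

        // Second call: only this new message is fed to the executor; the
        // earlier turns are already part of its state.
        await foreach (var token in session.ChatAsync("Summarize that in one sentence.", inferenceParams))
        {
            Console.Write(token);
        }
    }
}

Note on the two overloads, as the diff shows: the string overload appends the user message to the session's own History before inferring, while the ChatHistory overload leaves history management to the caller and only reads the last message from the history it is given.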