From 6ea40d15461a5243bfdd453bd6d2ede3cdaa5eaa Mon Sep 17 00:00:00 2001
From: Philipp Bauer
Date: Wed, 8 Nov 2023 13:18:32 -0600
Subject: [PATCH] Use full history only when the ChatSession runs the first time

---
 LLama/ChatSession.cs | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 358d70c3..68c3c093 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -1,11 +1,14 @@
 using LLama.Abstractions;
 using LLama.Common;
+using System;
 using System.Collections.Generic;
 using System.IO;
+using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
+using static LLama.InteractiveExecutor;
 
 namespace LLama
 {
@@ -151,11 +154,17 @@ namespace LLama
 
             History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
 
-            string internalPrompt = HistoryTransform.HistoryToText(History);
+            if (_executor is InteractiveExecutor executor)
+            {
+                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+                prompt = state.IsPromptRun
+                    ? HistoryTransform.HistoryToText(History)
+                    : prompt;
+            }
 
             StringBuilder sb = new();
 
-            await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
+            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
@@ -190,14 +199,28 @@ namespace LLama
         /// <returns></returns>
         public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            var prompt = HistoryTransform.HistoryToText(history);
+            if (history.Messages.Count == 0)
+            {
+                throw new ArgumentException("History must contain at least one message.");
+            }
 
-            StringBuilder sb = new();
+            string prompt;
+            if (_executor is InteractiveExecutor executor)
+            {
+                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+
+                prompt = state.IsPromptRun
+                    ? HistoryTransform.HistoryToText(History)
+                    : history.Messages.Last().Content;
+            }
+            else
+            {
+                prompt = history.Messages.Last().Content;
+            }
 
             await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
-                sb.Append(result);
             }
         }
 
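
Illustrative usage: a minimal sketch of what this change means for callers, assuming
the LLamaSharp API surface of this period (ModelParams, LLamaWeights, LLamaContext,
InteractiveExecutor, ChatSession); the model path and parameter values below are
placeholders, not taken from the patch. On the first ChatAsync call the executor
state's IsPromptRun is true, so the session sends the full transformed history to
the executor; on every later call it sends only the newest message, because the
executor already holds the evaluated context.

    using System;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Common;

    public static class ChatSessionDemo
    {
        public static async Task Main()
        {
            // Placeholder model path; any local GGUF chat model would do.
            var parameters = new ModelParams("models/llama-2-7b-chat.gguf");
            using var model = LLamaWeights.LoadFromFile(parameters);
            using var context = model.CreateContext(parameters);

            var session = new ChatSession(new InteractiveExecutor(context));

            // First run: IsPromptRun is true, so ChatAsync feeds the executor
            // HistoryTransform.HistoryToText(History), i.e. the whole conversation.
            await foreach (var token in session.ChatAsync("Hello!", new InferenceParams()))
                Console.Write(token);

            // Later runs: IsPromptRun is false, so only this new message is
            // evaluated; earlier turns already live in the executor's context.
            await foreach (var token in session.ChatAsync("Tell me more.", new InferenceParams()))
                Console.Write(token);
        }
    }

Before this patch, every call re-sent the full history, so the executor evaluated
the same tokens again on each turn; skipping that after the first run avoids the
redundant prompt processing.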