diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 457e7e48..358d70c3 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -95,11 +95,11 @@ namespace LLama
Directory.CreateDirectory(path);
}
_executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
- if(Executor is StatelessExecutor)
+ if (Executor is StatelessExecutor)
{
}
- else if(Executor is StatefulExecutorBase statefulExecutor)
+ else if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
}
@@ -135,30 +135,54 @@ namespace LLama
}
/// <summary>
- /// Get the response from the LLama model. Note that prompt could not only be the preset words,
- /// but also the question you want to ask.
+ /// Generates a response for a given user prompt and manages history state for the user.
+ /// This will always pass the whole history to the model. Don't pass a whole history
+ /// to this method, as the user prompt will be appended to the history of the current session.
+ /// If more control is needed, use the other overload of this method, which accepts a ChatHistory object.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
- /// <returns></returns>
+ /// <returns>Returns generated tokens of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- foreach(var inputTransform in InputTransformPipeline)
+ foreach (var inputTransform in InputTransformPipeline)
prompt = inputTransform.Transform(prompt);
-
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
+
+ History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+
+ string internalPrompt = HistoryTransform.HistoryToText(History);
+
StringBuilder sb = new();
- await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
+
+ await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
+
+ string assistantMessage = sb.ToString();
+
+ // Remove anti-prompt (stop) tokens from the assistant message
+ // if they are defined in inferenceParams.AntiPrompts.
+ // We only want the generated response, not the tokens
+ // that delimit the beginning or end of the response.
+ if (inferenceParams?.AntiPrompts != null)
+ {
+ foreach (var stopToken in inferenceParams.AntiPrompts)
+ {
+ assistantMessage = assistantMessage.Replace(stopToken, "");
+ }
+ }
+
+ History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
}
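As the new doc comment spells out, this overload owns the history: the user prompt is appended, the whole history is rendered through HistoryTransform, and the cleaned-up assistant reply is appended afterwards. A hedged usage sketch, continuing the session from the sketch above (using "User:" as an anti-prompt is an assumption that depends on the model's prompt template):

// Sketch only; the anti-prompt depends on the model's prompt template.
var inferenceParams = new InferenceParams
{
    AntiPrompts = new List<string> { "User:" }
};

await foreach (var token in session.ChatAsync("Hello!", inferenceParams))
{
    Console.Write(token);
}
// session.History now ends with the user message and the assistant
// reply, with the "User:" stop token stripped out.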
/// <summary>
- /// Get the response from the LLama model with chat histories.
+ /// Generates a response for a given chat history. This method does not manage history state for the user.
+ /// If you want to, for example, truncate the history of a session to fit into the model's context window,
+ /// use this method and pass the truncated history to it. If you don't need this control, use the other
+ /// overload of this method, which accepts a user prompt instead.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
@@ -167,14 +191,14 @@ namespace LLama
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var prompt = HistoryTransform.HistoryToText(history);
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
+
StringBuilder sb = new();
+
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}
private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
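By contrast, the ChatHistory overload no longer touches session.History at all (note the removed AddRange calls), which is what makes caller-side truncation possible. A sketch of that pattern; the keep-the-last-N-messages policy here is a made-up example, not something this PR implements:

using System.Linq;

// Hypothetical truncation policy: keep only the most recent messages
// so the rendered prompt fits into the model's context window.
var truncated = new ChatHistory();
truncated.Messages.AddRange(session.History.Messages.TakeLast(8)); // 8 is arbitrary

await foreach (var token in session.ChatAsync(truncated, inferenceParams))
{
    Console.Write(token);
}
// session.History was not modified; record the turn yourself if you
// want it persisted.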