diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index d1504a08..2985bd5f 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -1,243 +1,496 @@
-using LLama.Abstractions;
-using LLama.Common;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
-using System.Text;
using System.Threading;
using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
using static LLama.InteractiveExecutor;
-namespace LLama
+namespace LLama;
+
+/// <summary>
+/// The main chat session class.
+/// </summary>
+public class ChatSession
{
+ private const string _modelStateFilename = "ModelState.st";
+ private const string _executorStateFilename = "ExecutorState.json";
+    private const string _historyFilename = "ChatHistory.json";
+
    /// <summary>
-    /// The main chat session class.
-    /// </summary>
- public class ChatSession
- {
- private readonly ILLamaExecutor _executor;
- private readonly ChatHistory _history;
-
- private const string _executorStateFilename = "ExecutorState.json";
- private const string _modelStateFilename = "ModelState.st";
-
-        /// <summary>
-        /// The executor for this session.
-        /// </summary>
-        public ILLamaExecutor Executor => _executor;
-        /// <summary>
-        /// The chat history for this session.
-        /// </summary>
-        public ChatHistory History => _history;
-        /// <summary>
-        /// The history transform used in this session.
-        /// </summary>
-        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
-        /// <summary>
-        /// The input transform pipeline used in this session.
-        /// </summary>
-        public List<ITextTransform> InputTransformPipeline { get; set; } = new();
-        /// <summary>
-        /// The output transform used in this session.
-        /// </summary>
-        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
-
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="executor">The executor for this session</param>
- public ChatSession(ILLamaExecutor executor)
- {
- _executor = executor;
- _history = new ChatHistory();
- }
-
-        /// <summary>
-        /// Use a custom history transform.
-        /// </summary>
-        /// <param name="transform"></param>
-        /// <returns></returns>
- public ChatSession WithHistoryTransform(IHistoryTransform transform)
- {
- HistoryTransform = transform;
- return this;
- }
-
-        /// <summary>
-        /// Add a text transform to the input transform pipeline.
-        /// </summary>
-        /// <param name="transform"></param>
-        /// <returns></returns>
- public ChatSession AddInputTransform(ITextTransform transform)
- {
- InputTransformPipeline.Add(transform);
- return this;
- }
-
-        /// <summary>
-        /// Use a custom output transform.
-        /// </summary>
-        /// <param name="transform"></param>
-        /// <returns></returns>
- public ChatSession WithOutputTransform(ITextStreamTransform transform)
- {
- OutputTransform = transform;
- return this;
- }
-
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
- public virtual void SaveSession(string path)
- {
- if (!Directory.Exists(path))
- {
- Directory.CreateDirectory(path);
- }
- _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
- if (Executor is StatelessExecutor)
- {
+ /// The executor for this session.
+    /// </summary>
+ public ILLamaExecutor Executor { get; private set; }
- }
- else if (Executor is StatefulExecutorBase statefulExecutor)
- {
- statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
- }
- else
- {
- throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
- }
+    /// <summary>
+    /// The chat history for this session.
+    /// </summary>
+ public ChatHistory History { get; private set; } = new();
+
+    /// <summary>
+    /// The history transform used in this session.
+    /// </summary>
+ public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
+
+    /// <summary>
+    /// The input transform pipeline used in this session.
+    /// </summary>
+    public List<ITextTransform> InputTransformPipeline { get; set; } = new();
+
+    /// <summary>
+    /// The output transform used in this session.
+    /// </summary>
+ public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
+
+    /// <summary>
+    /// Create a new chat session.
+    /// </summary>
+    /// <param name="executor">The executor for this session</param>
+ public ChatSession(ILLamaExecutor executor)
+ {
+        // The executor must be stateful, i.e. derive from StatefulExecutorBase.
+        if (executor is not StatefulExecutorBase)
+        {
+            throw new ArgumentException("Executor must derive from StatefulExecutorBase", nameof(executor));
}
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="path">The directory name to load the session.</param>
- public virtual void LoadSession(string path)
+ Executor = executor;
+ }
+
+    /// <summary>
+    /// Create a new chat session with a custom history.
+    /// </summary>
+    /// <param name="executor"></param>
+    /// <param name="history"></param>
+ public ChatSession(ILLamaExecutor executor, ChatHistory history)
+ : this(executor)
+ {
+ History = history;
+ }
+
+    /// <summary>
+    /// Use a custom history transform.
+    /// </summary>
+    /// <param name="transform"></param>
+    /// <returns></returns>
+ public ChatSession WithHistoryTransform(IHistoryTransform transform)
+ {
+ HistoryTransform = transform;
+ return this;
+ }
+
+    /// <summary>
+    /// Add a text transform to the input transform pipeline.
+    /// </summary>
+    /// <param name="transform"></param>
+    /// <returns></returns>
+ public ChatSession AddInputTransform(ITextTransform transform)
+ {
+ InputTransformPipeline.Add(transform);
+ return this;
+ }
+
+    /// <summary>
+    /// Use a custom output transform.
+    /// </summary>
+    /// <param name="transform"></param>
+    /// <returns></returns>
+ public ChatSession WithOutputTransform(ITextStreamTransform transform)
+ {
+ OutputTransform = transform;
+ return this;
+ }
+
+    /// <summary>
+    /// Save a session to a directory.
+    /// </summary>
+    /// <param name="path"></param>
+    /// <returns></returns>
+    /// <exception cref="ArgumentException"></exception>
+ public void SaveSession(string path)
+ {
+ if (string.IsNullOrWhiteSpace(path))
{
- if (!Directory.Exists(path))
- {
- throw new FileNotFoundException($"Directory {path} does not exist.");
- }
- _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
- if (Executor is StatelessExecutor)
- {
+ throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+ }
- }
- else if (Executor is StatefulExecutorBase statefulExecutor)
+ if (Directory.Exists(path))
+ {
+ Directory.Delete(path, recursive: true);
+ }
+
+ Directory.CreateDirectory(path);
+
+ string modelStateFilePath = Path.Combine(path, _modelStateFilename);
+ Executor.Context.SaveState(modelStateFilePath);
+
+ string executorStateFilepath = Path.Combine(path, _executorStateFilename);
+ ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);
+
+        string historyFilepath = Path.Combine(path, _historyFilename);
+ File.WriteAllText(historyFilepath, History.ToJson());
+ }
+
+    /// <summary>
+    /// Load a session from a directory.
+    /// </summary>
+    /// <param name="path"></param>
+    /// <returns></returns>
+    /// <exception cref="ArgumentException"></exception>
+ public void LoadSession(string path)
+ {
+ if (string.IsNullOrWhiteSpace(path))
+ {
+ throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+ }
+
+ if (!Directory.Exists(path))
+ {
+ throw new ArgumentException("Directory does not exist", nameof(path));
+ }
+
+ string modelStateFilePath = Path.Combine(path, _modelStateFilename);
+ Executor.Context.LoadState(modelStateFilePath);
+
+ string executorStateFilepath = Path.Combine(path, _executorStateFilename);
+ ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);
+
+        string historyFilepath = Path.Combine(path, _historyFilename);
+ string historyJson = File.ReadAllText(historyFilepath);
+ History = ChatHistory.FromJson(historyJson)
+ ?? throw new ArgumentException("History file is invalid", nameof(path));
+ }
+
+    /// <summary>
+    /// Add a message to the chat history.
+    /// </summary>
+    /// <param name="message"></param>
+    /// <returns></returns>
+ public ChatSession AddMessage(ChatHistory.Message message)
+ {
+        // A system message can only be added when the history is empty.
+ if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
+ {
+ throw new ArgumentException("Cannot add a system message after another message", nameof(message));
+ }
+
+        // A user message can only be added when the history is empty
+        // or when the previous message is a system or assistant message.
+ if (message.AuthorRole == AuthorRole.User)
+ {
+ ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
+ if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
{
- statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
+ throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
}
- else
+ }
+
+ // If the current message is an assistant message,
+ // the previous message must be a user message.
+ if (message.AuthorRole == AuthorRole.Assistant)
+ {
+ ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
+ if (lastMessage is null
+ || lastMessage.AuthorRole != AuthorRole.User)
{
- throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
+                throw new ArgumentException("Assistant message must be preceded by a user message", nameof(message));
}
}
-        /// <summary>
-        /// Generates a response for a given user prompt and manages history state for the user.
-        /// This will always pass the whole history to the model. Don't pass a whole history
-        /// to this method as the user prompt will be appended to the history of the current session.
-        /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
-        /// </summary>
-        /// <param name="prompt"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns>Returns generated text of the assistant message.</returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ History.AddMessage(message.AuthorRole, message.Content);
+ return this;
+ }
+
+    /// <summary>
+    /// Add a system message to the chat history.
+    /// </summary>
+    /// <param name="content"></param>
+    /// <returns></returns>
+ public ChatSession AddSystemMessage(string content)
+ => AddMessage(new ChatHistory.Message(AuthorRole.System, content));
+
+    /// <summary>
+    /// Add an assistant message to the chat history.
+    /// </summary>
+    /// <param name="content"></param>
+    /// <returns></returns>
+ public ChatSession AddAssistantMessage(string content)
+ => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+
+    /// <summary>
+    /// Add a user message to the chat history.
+    /// </summary>
+    /// <param name="content"></param>
+    /// <returns></returns>
+ public ChatSession AddUserMessage(string content)
+ => AddMessage(new ChatHistory.Message(AuthorRole.User, content));
+
+    /// <summary>
+    /// Remove the last message from the chat history.
+    /// </summary>
+    /// <returns></returns>
+ public ChatSession RemoveLastMessage()
+ {
+ History.Messages.RemoveAt(History.Messages.Count - 1);
+ return this;
+ }
+
+    /// <summary>
+    /// Replace a user message with a new message and remove all messages after the new message.
+    /// This is useful when the user wants to edit a message and regenerate the response.
+    /// </summary>
+    /// <param name="oldMessage"></param>
+    /// <param name="newMessage"></param>
+    /// <returns></returns>
+ public ChatSession ReplaceUserMessage(
+ ChatHistory.Message oldMessage,
+ ChatHistory.Message newMessage)
+ {
+ if (oldMessage.AuthorRole != AuthorRole.User)
+ {
+ throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
+ }
+
+ if (newMessage.AuthorRole != AuthorRole.User)
{
- foreach (var inputTransform in InputTransformPipeline)
- prompt = inputTransform.Transform(prompt);
+ throw new ArgumentException("New message must be a user message", nameof(newMessage));
+ }
- History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+ int index = History.Messages.IndexOf(oldMessage);
+ if (index == -1)
+ {
+ throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
+ }
- if (_executor is InteractiveExecutor executor)
- {
- InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
- prompt = state.IsPromptRun
- ? HistoryTransform.HistoryToText(History)
- : HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(AuthorRole.User, prompt));
- }
+ History.Messages[index] = newMessage;
+
+        // Remove all messages after the new message.
+ History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);
+
+ return this;
+ }
- StringBuilder sb = new();
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="message"></param>
+    /// <param name="applyInputTransformPipeline"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="ArgumentException"></exception>
+    public async IAsyncEnumerable<string> ChatAsync(
+ ChatHistory.Message message,
+ bool applyInputTransformPipeline,
+ IInferenceParams? inferenceParams = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ // The message must be a user message
+ if (message.AuthorRole != AuthorRole.User)
+ {
+ throw new ArgumentException("Message must be a user message", nameof(message));
+ }
- await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
+ // Apply input transform pipeline
+ if (applyInputTransformPipeline)
+ {
+ foreach (var inputTransform in InputTransformPipeline)
{
- yield return textToken;
- sb.Append(textToken);
+ message.Content = inputTransform.Transform(message.Content);
}
+ }
- string assistantMessage = sb.ToString();
+ // Add the user's message to the history
+ AddUserMessage(message.Content);
+
+ // Prepare prompt variable
+ string prompt;
+
+ // Check if the session history was restored from a previous session
+ // or added as part of new chat session history.
+ InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();
+
+ // If "IsPromptRun" is true, the session was newly started.
+ if (state.IsPromptRun)
+ {
+            // If the session history was added as part of a new chat session history,
+            // convert the complete history, including the system message and manually added messages,
+            // to a prompt that adheres to the prompt template specified in the HistoryTransform implementation.
+ prompt = HistoryTransform.HistoryToText(History);
+ }
+ else
+ {
+            // If the session was restored from a previous session,
+            // convert only the current message to a prompt using the prompt template
+            // specified in the provided HistoryTransform implementation.
+ ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
+ prompt = HistoryTransform.HistoryToText(singleMessageHistory);
+ }
+
+ string assistantMessage = string.Empty;
+
+ await foreach (
+ string textToken
+ in ChatAsyncInternal(
+ prompt,
+ inferenceParams,
+ cancellationToken))
+ {
+ assistantMessage += textToken;
+ yield return textToken;
+ }
- // Remove end tokens from the assistant message
- // if defined in inferenceParams.AntiPrompts.
- // We only want the response that was generated and not tokens
- // that are delimiting the beginning or end of the response.
- if (inferenceParams?.AntiPrompts != null)
+ // Add the assistant message to the history
+ AddAssistantMessage(assistantMessage);
+ }
+
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="message"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    public IAsyncEnumerable<string> ChatAsync(
+ ChatHistory.Message message,
+ IInferenceParams? inferenceParams = null,
+ CancellationToken cancellationToken = default)
+ {
+ return ChatAsync(
+ message,
+ applyInputTransformPipeline: true,
+ inferenceParams,
+ cancellationToken);
+ }
+
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="history"></param>
+    /// <param name="applyInputTransformPipeline"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="ArgumentException"></exception>
+    public IAsyncEnumerable<string> ChatAsync(
+ ChatHistory history,
+ bool applyInputTransformPipeline,
+ IInferenceParams? inferenceParams = null,
+ CancellationToken cancellationToken = default)
+ {
+ ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
+ ?? throw new ArgumentException("History must contain at least one message", nameof(history));
+
+ foreach (
+ ChatHistory.Message message
+ in history.Messages.Take(history.Messages.Count - 1))
+ {
+ // Apply input transform pipeline
+ if (applyInputTransformPipeline
+ && message.AuthorRole == AuthorRole.User)
{
- foreach (var stopToken in inferenceParams.AntiPrompts)
+ foreach (
+ var inputTransform
+ in InputTransformPipeline)
{
- assistantMessage = assistantMessage.Replace(stopToken, "").Trim();
+ message.Content = inputTransform.Transform(message.Content);
}
}
- History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
+ AddMessage(message);
}
-        /// <summary>
-        /// Generates a response for a given chat history. This method does not manage history state for the user.
-        /// If you want to e.g. truncate the history of a session to fit into the model's context window,
-        /// use this method and pass the truncated history to it. If you don't need this control, use the other
-        /// overload of this method that accepts a user prompt instead.
-        /// </summary>
-        /// <param name="history"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns>Returns generated text of the assistant message.</returns>
-        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ return ChatAsync(
+ lastMessage,
+ applyInputTransformPipeline,
+ inferenceParams,
+ cancellationToken);
+ }
+
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="history"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    public IAsyncEnumerable<string> ChatAsync(
+ ChatHistory history,
+ IInferenceParams? inferenceParams = null,
+ CancellationToken cancellationToken = default)
+ {
+ return ChatAsync(
+ history,
+ applyInputTransformPipeline: true,
+ inferenceParams,
+ cancellationToken);
+ }
+
+    /// <summary>
+    /// Regenerate the last assistant message.
+    /// </summary>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="InvalidOperationException"></exception>
+    public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
+ InferenceParams? inferenceParams = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+        // Make sure the last message is an assistant message (response from the LLM).
+ ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();
+
+ if (lastAssistantMessage is null
+ || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
{
- if (history.Messages.Count == 0)
- {
- throw new ArgumentException("History must contain at least one message.");
- }
+ throw new InvalidOperationException("Last message must be an assistant message");
+ }
- string prompt;
- if (_executor is InteractiveExecutor executor)
- {
- InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+ // Remove the last assistant message from the history.
+ RemoveLastMessage();
- if (state.IsPromptRun)
- {
- prompt = HistoryTransform.HistoryToText(History);
- }
- else
- {
- ChatHistory.Message lastMessage = history.Messages.Last();
- prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
- }
- }
- else
- {
- ChatHistory.Message lastMessage = history.Messages.Last();
- prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
- }
+ // Get the last user message.
+ ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();
- await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
- {
- yield return textToken;
- }
+ if (lastUserMessage is null
+ || lastUserMessage.AuthorRole != AuthorRole.User)
+ {
+ throw new InvalidOperationException("Last message must be a user message");
}
-        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ // Remove the last user message from the history.
+ RemoveLastMessage();
+
+ // Regenerate the assistant message.
+ await foreach (
+ string textToken
+ in ChatAsync(
+ lastUserMessage,
+ applyInputTransformPipeline: false,
+ inferenceParams,
+ cancellationToken))
{
- var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
- await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
- {
- yield return textToken;
- }
+ yield return textToken;
+ }
+ }
+
+    private async IAsyncEnumerable<string> ChatAsyncInternal(
+ string prompt,
+ IInferenceParams? inferenceParams = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);
+
+ await foreach (
+ string textToken
+ in OutputTransform
+ .TransformAsync(results)
+ .WithCancellation(cancellationToken))
+ {
+ yield return textToken;
}
}
-}
\ No newline at end of file
+}
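For reviewers, a minimal usage sketch (not part of the patch) of the reworked ChatSession API above, with usings omitted. The model-loading calls (ModelParams, LLamaWeights.LoadFromFile, CreateContext) follow the library's usual setup and are assumptions outside this diff; the model path and session directory are illustrative.

    // Build a stateful executor: ChatSession now requires a StatefulExecutorBase.
    var parameters = new ModelParams("path/to/model.gguf");        // illustrative path
    using var weights = LLamaWeights.LoadFromFile(parameters);     // assumed standard loading API
    using var context = weights.CreateContext(parameters);
    var executor = new InteractiveExecutor(context);

    // Seed the history (system message must come first) and stream a reply.
    var session = new ChatSession(executor)
        .AddSystemMessage("You are a helpful assistant.");

    await foreach (var token in session.ChatAsync(
        new ChatHistory.Message(AuthorRole.User, "Hello!"),
        new InferenceParams { AntiPrompts = new List<string> { "User:" } }))
    {
        Console.Write(token);
    }

    // Persist model state, executor state and chat history, then restore them later.
    session.SaveSession("./chat-session");
    session.LoadSession("./chat-session");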
diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs
index 7224b314..3f038874 100644
--- a/LLama/Common/ChatHistory.cs
+++ b/LLama/Common/ChatHistory.cs
@@ -1,4 +1,7 @@
using System.Collections.Generic;
+using System.IO;
+using System.Text.Json;
+using System.Text.Json.Serialization;
namespace LLama.Common
{
@@ -43,11 +46,14 @@ namespace LLama.Common
            /// <summary>
            /// Role of the message author, e.g. user/assistant/system
            /// </summary>
+ [JsonConverter(typeof(JsonStringEnumConverter))]
+ [JsonPropertyName("author_role")]
public AuthorRole AuthorRole { get; set; }
            /// <summary>
            /// Message content
            /// </summary>
+ [JsonPropertyName("content")]
public string Content { get; set; }
            /// <summary>
@@ -65,15 +71,14 @@ namespace LLama.Common
        /// <summary>
        /// List of messages in the chat
        /// </summary>
-        public List<Message> Messages { get; }
+ [JsonPropertyName("messages")]
+        public List<Message> Messages { get; set; } = new();
        /// <summary>
        /// Create a new instance of the chat content class
        /// </summary>
- public ChatHistory()
- {
-            this.Messages = new List<Message>();
- }
+ [JsonConstructor]
+ public ChatHistory() { }
        /// <summary>
/// Add a message to the chat history
@@ -84,6 +89,29 @@ namespace LLama.Common
{
this.Messages.Add(new Message(authorRole, content));
}
- }
+        /// <summary>
+        /// Serialize the chat history to JSON
+        /// </summary>
+        /// <returns></returns>
+ public string ToJson()
+ {
+ return JsonSerializer.Serialize(
+ this,
+ new JsonSerializerOptions()
+ {
+ WriteIndented = true
+ });
+ }
+
+        /// <summary>
+        /// Deserialize a chat history from JSON
+        /// </summary>
+        /// <param name="json"></param>
+        /// <returns></returns>
+ public static ChatHistory? FromJson(string json)
+ {
+            return JsonSerializer.Deserialize<ChatHistory>(json);
+ }
+ }
}
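A small round-trip sketch for the new ChatHistory serialization helpers added above (again not part of the patch; the file name is illustrative):

    var history = new ChatHistory();
    history.AddMessage(AuthorRole.System, "You are a helpful assistant.");
    history.AddMessage(AuthorRole.User, "Hi!");

    // ToJson() produces indented JSON using the "author_role", "content" and "messages" property names.
    string json = history.ToJson();
    File.WriteAllText("ChatHistory.json", json);

    // FromJson() may return null; LoadSession treats that as an invalid history file.
    ChatHistory? restored = ChatHistory.FromJson(File.ReadAllText("ChatHistory.json"));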