diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index d1504a08..2985bd5f 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -1,243 +1,496 @@
-using LLama.Abstractions;
-using LLama.Common;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Runtime.CompilerServices;
-using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using static LLama.InteractiveExecutor;
-namespace LLama
+namespace LLama;
+
+/// <summary>
+/// The main chat session class.
+/// </summary>
+public class ChatSession
 {
+    private const string _modelStateFilename = "ModelState.st";
+    private const string _executorStateFilename = "ExecutorState.json";
+    private const string _historyFilename = "ChatHistory.json";
+
     /// <summary>
-    /// The main chat session class.
-    /// </summary>
-    public class ChatSession
-    {
-        private readonly ILLamaExecutor _executor;
-        private readonly ChatHistory _history;
-
-        private const string _executorStateFilename = "ExecutorState.json";
-        private const string _modelStateFilename = "ModelState.st";
-
-        /// <summary>
-        /// The executor for this session.
-        /// </summary>
-        public ILLamaExecutor Executor => _executor;
-        /// <summary>
-        /// The chat history for this session.
-        /// </summary>
-        public ChatHistory History => _history;
-        /// <summary>
-        /// The history transform used in this session.
-        /// </summary>
-        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
-        /// <summary>
-        /// The input transform pipeline used in this session.
-        /// </summary>
-        public List<ITextTransform> InputTransformPipeline { get; set; } = new();
-        /// <summary>
-        /// The output transform used in this session.
-        /// </summary>
-        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
-
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="executor">The executor for this session</param>
-        public ChatSession(ILLamaExecutor executor)
-        {
-            _executor = executor;
-            _history = new ChatHistory();
-        }
-
-        /// <summary>
-        /// Use a custom history transform.
-        /// </summary>
-        /// <param name="transform"></param>
-        /// <returns></returns>
-        public ChatSession WithHistoryTransform(IHistoryTransform transform)
-        {
-            HistoryTransform = transform;
-            return this;
-        }
-
-        /// <summary>
-        /// Add a text transform to the input transform pipeline.
-        /// </summary>
-        /// <param name="transform"></param>
-        /// <returns></returns>
-        public ChatSession AddInputTransform(ITextTransform transform)
-        {
-            InputTransformPipeline.Add(transform);
-            return this;
-        }
-
-        /// <summary>
-        /// Use a custom output transform.
-        /// </summary>
-        /// <param name="transform"></param>
-        /// <returns></returns>
-        public ChatSession WithOutputTransform(ITextStreamTransform transform)
-        {
-            OutputTransform = transform;
-            return this;
-        }
-
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
-        public virtual void SaveSession(string path)
-        {
-            if (!Directory.Exists(path))
-            {
-                Directory.CreateDirectory(path);
-            }
-            _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
-            if (Executor is StatelessExecutor)
-            {
+    /// The executor for this session.
+    /// </summary>
+    public ILLamaExecutor Executor { get; private set; }
-            }
-            else if (Executor is StatefulExecutorBase statefulExecutor)
-            {
-                statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
-            }
-            else
-            {
-                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
-            }
+    /// <summary>
+    /// The chat history for this session.
+    /// </summary>
+    public ChatHistory History { get; private set; } = new();
+
+    /// <summary>
+    /// The history transform used in this session.
+    /// </summary>
+    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
+
+    /// <summary>
+    /// The input transform pipeline used in this session.
+    /// </summary>
+    public List<ITextTransform> InputTransformPipeline { get; set; } = new();
+
+    /// <summary>
+    /// The output transform used in this session.
+    /// </summary>
+    public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
+
+    /// <summary>
+    /// Create a new chat session.
+    /// </summary>
+    /// <param name="executor">The executor for this session</param>
+    public ChatSession(ILLamaExecutor executor)
+    {
+        // Check that the executor derives from StatefulExecutorBase
+        if (executor is not StatefulExecutorBase)
+        {
+            throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
         }
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="path">The directory name to load the session.</param>
-        public virtual void LoadSession(string path)
+        Executor = executor;
+    }
+
+    /// <summary>
+    /// Create a new chat session with a custom history.
+    /// </summary>
+    /// <param name="executor"></param>
+    /// <param name="history"></param>
+    public ChatSession(ILLamaExecutor executor, ChatHistory history)
+        : this(executor)
+    {
+        History = history;
+    }
+
+    /// <summary>
+    /// Use a custom history transform.
+    /// </summary>
+    /// <param name="transform"></param>
+    /// <returns></returns>
+    public ChatSession WithHistoryTransform(IHistoryTransform transform)
+    {
+        HistoryTransform = transform;
+        return this;
+    }
+
+    /// <summary>
+    /// Add a text transform to the input transform pipeline.
+    /// </summary>
+    /// <param name="transform"></param>
+    /// <returns></returns>
+    public ChatSession AddInputTransform(ITextTransform transform)
+    {
+        InputTransformPipeline.Add(transform);
+        return this;
+    }
+
+    /// <summary>
+    /// Use a custom output transform.
+    /// </summary>
+    /// <param name="transform"></param>
+    /// <returns></returns>
+    public ChatSession WithOutputTransform(ITextStreamTransform transform)
+    {
+        OutputTransform = transform;
+        return this;
+    }
+
+    /// <summary>
+    /// Save a session to a directory.
+    /// </summary>
+    /// <param name="path"></param>
+    /// <exception cref="ArgumentException"></exception>
+    public void SaveSession(string path)
+    {
+        if (string.IsNullOrWhiteSpace(path))
         {
-            if (!Directory.Exists(path))
-            {
-                throw new FileNotFoundException($"Directory {path} does not exist.");
-            }
-            _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
-            if (Executor is StatelessExecutor)
-            {
+            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+        }
-            }
-            else if (Executor is StatefulExecutorBase statefulExecutor)
+        if (Directory.Exists(path))
+        {
+            Directory.Delete(path, recursive: true);
+        }
+
+        Directory.CreateDirectory(path);
+
+        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
+        Executor.Context.SaveState(modelStateFilePath);
+
+        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
+        ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);
+
+        string historyFilepath = Path.Combine(path, _historyFilename);
+        File.WriteAllText(historyFilepath, History.ToJson());
+    }
+
+    /// <summary>
+    /// Load a session from a directory.
+    /// </summary>
+    /// <param name="path"></param>
+    /// <exception cref="ArgumentException"></exception>
+    public void LoadSession(string path)
+    {
+        if (string.IsNullOrWhiteSpace(path))
+        {
+            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+        }
+
+        if (!Directory.Exists(path))
+        {
+            throw new ArgumentException("Directory does not exist", nameof(path));
+        }
+
+        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
+        Executor.Context.LoadState(modelStateFilePath);
+
+        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
+        ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);
+
+        string historyFilepath = Path.Combine(path, _historyFilename);
+        string historyJson = File.ReadAllText(historyFilepath);
+        History = ChatHistory.FromJson(historyJson)
+            ?? throw new ArgumentException("History file is invalid", nameof(path));
+    }
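Usage note (not part of the patch): a minimal sketch of the save/load flow defined above. The executor variable and the target path are illustrative assumptions; SaveSession recreates the target directory and writes the three files named by the constants at the top of the class.

    // Sketch: persist a session and restore it later.
    // "executor" is assumed to be an InteractiveExecutor (or another StatefulExecutorBase) created elsewhere.
    var session = new ChatSession(executor);
    session.AddSystemMessage("You are a helpful assistant.");

    // ... chat for a while ...

    session.SaveSession("./chat-session");   // writes ModelState.st, ExecutorState.json and ChatHistory.json

    var restored = new ChatSession(executor);
    restored.LoadSession("./chat-session");  // restores context state, executor state and chat history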
throw new ArgumentException("History file is invalid", nameof(path)); + } + + /// + /// Add a message to the chat history. + /// + /// + /// + public ChatSession AddMessage(ChatHistory.Message message) + { + // If current message is a system message, only allow the history to be empty + if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0) + { + throw new ArgumentException("Cannot add a system message after another message", nameof(message)); + } + + // If current message is a user message, only allow the history to be empty, + // or the previous message to be a system message or assistant message. + if (message.AuthorRole == AuthorRole.User) + { + ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User) { - statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename)); + throw new ArgumentException("Cannot add a user message after another user message", nameof(message)); } - else + } + + // If the current message is an assistant message, + // the previous message must be a user message. + if (message.AuthorRole == AuthorRole.Assistant) + { + ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + if (lastMessage is null + || lastMessage.AuthorRole != AuthorRole.User) { - throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method."); + throw new ArgumentException("Assistant message must be preceeded with a user message", nameof(message)); } } - /// - /// Generates a response for a given user prompt and manages history state for the user. - /// This will always pass the whole history to the model. Don't pass a whole history - /// to this method as the user prompt will be appended to the history of the current session. - /// If more control is needed, use the other overload of this method that accepts a ChatHistory object. - /// - /// - /// - /// - /// Returns generated text of the assistant message. - public async IAsyncEnumerable ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + History.AddMessage(message.AuthorRole, message.Content); + return this; + } + + /// + /// Add a system message to the chat history. + /// + /// + /// + public ChatSession AddSystemMessage(string content) + => AddMessage(new ChatHistory.Message(AuthorRole.System, content)); + + /// + /// Add an assistant message to the chat history. + /// + /// + /// + public ChatSession AddAssistantMessage(string content) + => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + + /// + /// Add a user message to the chat history. + /// + /// + /// + public ChatSession AddUserMessage(string content) + => AddMessage(new ChatHistory.Message(AuthorRole.User, content)); + + /// + /// Remove the last message from the chat history. + /// + /// + public ChatSession RemoveLastMessage() + { + History.Messages.RemoveAt(History.Messages.Count - 1); + return this; + } + + /// + /// Replace a user message with a new message and remove all messages after the new message. + /// This is useful when the user wants to edit a message. And regenerate the response. 
+    /// <summary>
+    /// Replace a user message with a new message and remove all messages after the new message.
+    /// This is useful when the user wants to edit a message and regenerate the response.
+    /// </summary>
+    /// <param name="oldMessage"></param>
+    /// <param name="newMessage"></param>
+    public ChatSession ReplaceUserMessage(
+        ChatHistory.Message oldMessage,
+        ChatHistory.Message newMessage)
+    {
+        if (oldMessage.AuthorRole != AuthorRole.User)
+        {
+            throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
+        }
+
+        if (newMessage.AuthorRole != AuthorRole.User)
         {
-            foreach (var inputTransform in InputTransformPipeline)
-                prompt = inputTransform.Transform(prompt);
+            throw new ArgumentException("New message must be a user message", nameof(newMessage));
         }
-            History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+        int index = History.Messages.IndexOf(oldMessage);
+        if (index == -1)
+        {
+            throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
+        }
-            if (_executor is InteractiveExecutor executor)
-            {
-                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
-                prompt = state.IsPromptRun
-                    ? HistoryTransform.HistoryToText(History)
-                    : HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(AuthorRole.User, prompt));
-            }
+        History.Messages[index] = newMessage;
+
+        // Remove all messages after the new message
+        History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);
+
+        return this;
+    }
-            StringBuilder sb = new();
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="message"></param>
+    /// <param name="applyInputTransformPipeline"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="ArgumentException"></exception>
+    public async IAsyncEnumerable<string> ChatAsync(
+        ChatHistory.Message message,
+        bool applyInputTransformPipeline,
+        IInferenceParams? inferenceParams = null,
+        [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        // The message must be a user message
+        if (message.AuthorRole != AuthorRole.User)
+        {
+            throw new ArgumentException("Message must be a user message", nameof(message));
         }
-            await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
+        // Apply input transform pipeline
+        if (applyInputTransformPipeline)
+        {
+            foreach (var inputTransform in InputTransformPipeline)
             {
-                yield return textToken;
-                sb.Append(textToken);
+                message.Content = inputTransform.Transform(message.Content);
             }
+        }
-            string assistantMessage = sb.ToString();
+        // Add the user's message to the history
+        AddUserMessage(message.Content);
+
+        // Prepare prompt variable
+        string prompt;
+
+        // Check whether the session history was restored from a previous session
+        // or added as part of a new chat session history.
+        InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();
+
+        // If "IsPromptRun" is true, the session was newly started.
+        if (state.IsPromptRun)
+        {
+            // If the session history was added as part of a new chat session history,
+            // convert the complete history, including the system message and manually added history,
+            // to a prompt that adheres to the prompt template specified in the HistoryTransform class implementation.
+            prompt = HistoryTransform.HistoryToText(History);
+        }
+        else
+        {
+            // If the session was restored from a previous session,
+            // convert only the current message to a prompt using the prompt template
+            // specified in the provided HistoryTransform implementation.
+            ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
+            prompt = HistoryTransform.HistoryToText(singleMessageHistory);
+        }
+
+        string assistantMessage = string.Empty;
+
+        await foreach (
+            string textToken
+            in ChatAsyncInternal(
+                prompt,
+                inferenceParams,
+                cancellationToken))
+        {
+            assistantMessage += textToken;
+            yield return textToken;
+        }
-            // Remove end tokens from the assistant message
-            // if defined in inferenceParams.AntiPrompts.
-            // We only want the response that was generated and not tokens
-            // that are delimiting the beginning or end of the response.
-            if (inferenceParams?.AntiPrompts != null)
+        // Add the assistant message to the history
+        AddAssistantMessage(assistantMessage);
+    }
+
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="message"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    public IAsyncEnumerable<string> ChatAsync(
+        ChatHistory.Message message,
+        IInferenceParams? inferenceParams = null,
+        CancellationToken cancellationToken = default)
+    {
+        return ChatAsync(
+            message,
+            applyInputTransformPipeline: true,
+            inferenceParams,
+            cancellationToken);
+    }
+
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="history"></param>
+    /// <param name="applyInputTransformPipeline"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="ArgumentException"></exception>
+    public IAsyncEnumerable<string> ChatAsync(
+        ChatHistory history,
+        bool applyInputTransformPipeline,
+        IInferenceParams? inferenceParams = null,
+        CancellationToken cancellationToken = default)
+    {
+        ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
+            ?? throw new ArgumentException("History must contain at least one message", nameof(history));
+
+        foreach (
+            ChatHistory.Message message
+            in history.Messages.Take(history.Messages.Count - 1))
+        {
+            // Apply input transform pipeline
+            if (applyInputTransformPipeline
+                && message.AuthorRole == AuthorRole.User)
             {
-                foreach (var stopToken in inferenceParams.AntiPrompts)
+                foreach (
+                    var inputTransform
+                    in InputTransformPipeline)
                 {
-                    assistantMessage = assistantMessage.Replace(stopToken, "").Trim();
+                    message.Content = inputTransform.Transform(message.Content);
                 }
             }
-            History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
+            AddMessage(message);
         }
-        /// <summary>
-        /// Generates a response for a given chat history. This method does not manage history state for the user.
-        /// If you want to e.g. truncate the history of a session to fit into the model's context window,
-        /// use this method and pass the truncated history to it. If you don't need this control, use the other
-        /// overload of this method that accepts a user prompt instead.
-        /// </summary>
-        /// <param name="history"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns>Returns generated text of the assistant message.</returns>
-        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        return ChatAsync(
+            lastMessage,
+            applyInputTransformPipeline,
+            inferenceParams,
+            cancellationToken);
+    }
+
+    /// <summary>
+    /// Chat with the model.
+    /// </summary>
+    /// <param name="history"></param>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    public IAsyncEnumerable<string> ChatAsync(
+        ChatHistory history,
+        IInferenceParams? inferenceParams = null,
+        CancellationToken cancellationToken = default)
+    {
+        return ChatAsync(
+            history,
+            applyInputTransformPipeline: true,
+            inferenceParams,
+            cancellationToken);
+    }
+
+    /// <summary>
+    /// Regenerate the last assistant message.
+    /// </summary>
+    /// <param name="inferenceParams"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="InvalidOperationException"></exception>
+    public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
+        InferenceParams? inferenceParams = null,
+        [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        // Make sure the last message is an assistant message (response from the LLM).
+        ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();
+
+        if (lastAssistantMessage is null
+            || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
         {
-            if (history.Messages.Count == 0)
-            {
-                throw new ArgumentException("History must contain at least one message.");
-            }
+            throw new InvalidOperationException("Last message must be an assistant message");
+        }
-            string prompt;
-            if (_executor is InteractiveExecutor executor)
-            {
-                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+        // Remove the last assistant message from the history.
+        RemoveLastMessage();
-                if (state.IsPromptRun)
-                {
-                    prompt = HistoryTransform.HistoryToText(History);
-                }
-                else
-                {
-                    ChatHistory.Message lastMessage = history.Messages.Last();
-                    prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
-                }
-            }
-            else
-            {
-                ChatHistory.Message lastMessage = history.Messages.Last();
-                prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
-            }
+        // Get the last user message.
+        ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();
-            await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return textToken;
-            }
+        if (lastUserMessage is null
+            || lastUserMessage.AuthorRole != AuthorRole.User)
+        {
+            throw new InvalidOperationException("Last message must be a user message");
         }
-        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        // Remove the last user message from the history.
+        RemoveLastMessage();
+
+        // Regenerate the assistant message.
+        await foreach (
+            string textToken
+            in ChatAsync(
+                lastUserMessage,
+                applyInputTransformPipeline: false,
+                inferenceParams,
+                cancellationToken))
         {
-            var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
-            await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
-            {
-                yield return textToken;
-            }
+            yield return textToken;
+        }
+    }
+
+    private async IAsyncEnumerable<string> ChatAsyncInternal(
+        string prompt,
+        IInferenceParams? inferenceParams = null,
+        [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);
+
+        await foreach (
+            string textToken
+            in OutputTransform
+                .TransformAsync(results)
+                .WithCancellation(cancellationToken))
+        {
+            yield return textToken;
         }
     }
-}
\ No newline at end of file
+}
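Usage note (not part of the patch): a sketch of the reworked, message-based chat flow, assuming an InteractiveExecutor built over an existing LLamaContext; the prompt text and inference parameters are illustrative.

    // Sketch: build a session, stream one reply, then regenerate it.
    var session = new ChatSession(new InteractiveExecutor(context));
    session.AddSystemMessage("You are a helpful assistant.");

    var inferenceParams = new InferenceParams { AntiPrompts = new[] { "User:" } };

    await foreach (var token in session.ChatAsync(
        new ChatHistory.Message(AuthorRole.User, "Hello, how are you?"), inferenceParams))
    {
        Console.Write(token);
    }

    // The reply is now the last entry in session.History and can be regenerated in place.
    await foreach (var token in session.RegenerateAssistantMessageAsync())
    {
        Console.Write(token);
    }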
diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs
index 7224b314..3f038874 100644
--- a/LLama/Common/ChatHistory.cs
+++ b/LLama/Common/ChatHistory.cs
@@ -1,4 +1,7 @@
 using System.Collections.Generic;
+using System.IO;
+using System.Text.Json;
+using System.Text.Json.Serialization;
 
 namespace LLama.Common
 {
@@ -43,11 +46,14 @@ namespace LLama.Common
         /// <summary>
         /// Role of the message author, e.g. user/assistant/system
         /// </summary>
+        [JsonConverter(typeof(JsonStringEnumConverter))]
+        [JsonPropertyName("author_role")]
         public AuthorRole AuthorRole { get; set; }
 
         /// <summary>
         /// Message content
         /// </summary>
+        [JsonPropertyName("content")]
         public string Content { get; set; }
 
         /// <summary>
@@ -65,15 +71,14 @@ namespace LLama.Common
         /// <summary>
         /// List of messages in the chat
         /// </summary>
-        public List<Message> Messages { get; }
+        [JsonPropertyName("messages")]
+        public List<Message> Messages { get; set; } = new();
 
         /// <summary>
         /// Create a new instance of the chat content class
         /// </summary>
-        public ChatHistory()
-        {
-            this.Messages = new List<Message>();
-        }
+        [JsonConstructor]
+        public ChatHistory() { }
 
         /// <summary>
         /// Add a message to the chat history
@@ -84,6 +89,29 @@ namespace LLama.Common
         {
             this.Messages.Add(new Message(authorRole, content));
         }
-    }
+        /// <summary>
+        /// Serialize the chat history to JSON
+        /// </summary>
+        /// <returns></returns>
+        public string ToJson()
+        {
+            return JsonSerializer.Serialize(
+                this,
+                new JsonSerializerOptions()
+                {
+                    WriteIndented = true
+                });
+        }
+
+        /// <summary>
+        /// Deserialize a chat history from JSON
+        /// </summary>
+        /// <param name="json"></param>
+        /// <returns></returns>
+        public static ChatHistory? FromJson(string json)
+        {
+            return JsonSerializer.Deserialize<ChatHistory>(json);
+        }
+    }
 }
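Usage note (not part of the patch): a sketch of the ChatHistory JSON round trip added above; the file name is arbitrary and System.IO is assumed to be imported.

    // Sketch: serialize a history to indented JSON and read it back.
    var history = new ChatHistory();
    history.AddMessage(AuthorRole.User, "Why is the sky blue?");

    string json = history.ToJson();            // uses the "author_role" / "content" / "messages" names above
    File.WriteAllText("history.json", json);

    ChatHistory? restored = ChatHistory.FromJson(File.ReadAllText("history.json"));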