using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.InteractiveExecutor;
using static LLama.LLamaContext;
using static LLama.StatefulExecutorBase;

namespace LLama;

/// <summary>
/// The main chat session class.
/// </summary>
public class ChatSession
{
    /// <summary>
    /// The filename for the serialized model state (KV cache, etc).
    /// </summary>
    public const string MODEL_STATE_FILENAME = "ModelState.st";

    /// <summary>
    /// The filename for the serialized executor state.
    /// </summary>
    public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";

    /// <summary>
    /// The filename for the serialized chat history.
    /// </summary>
    public const string HISTORY_STATE_FILENAME = "ChatHistory.json";

    /// <summary>
    /// The filename for the serialized input transform pipeline.
    /// </summary>
    public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";

    /// <summary>
    /// The filename for the serialized output transform.
    /// </summary>
    public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";

    /// <summary>
    /// The filename for the serialized history transform.
    /// </summary>
    public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";

    /// <summary>
    /// The executor for this session.
    /// </summary>
    public ILLamaExecutor Executor { get; private set; }

    /// <summary>
    /// The chat history for this session.
    /// </summary>
    public ChatHistory History { get; private set; } = new();

    /// <summary>
    /// The history transform used in this session.
    /// </summary>
    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

    /// <summary>
    /// The input transform pipeline used in this session.
    /// </summary>
    public List<ITextTransform> InputTransformPipeline { get; set; } = new();

    /// <summary>
    /// The output transform used in this session.
    /// </summary>
    public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

    /// <summary>
    /// Create a new chat session and preprocess history.
    /// </summary>
    /// <param name="executor">The executor for this session</param>
    /// <param name="history">History for this session</param>
    /// <returns>The new session, after the history text has been prefilled into the executor.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/>.</exception>
    public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
        ILLamaExecutor executor,
        ChatHistory history)
    {
        if (executor is not StatefulExecutorBase statefulExecutor)
        {
            throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
        }

        var session = new ChatSession(executor, history);
        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
        return session;
    }

    /// <summary>
    /// Create a new chat session.
    /// </summary>
    /// <param name="executor">The executor for this session</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/>.</exception>
    public ChatSession(ILLamaExecutor executor)
    {
        // Check if executor has StatefulExecutorBase as base class
        if (executor is not StatefulExecutorBase)
        {
            throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
        }

        Executor = executor;
    }

    /// <summary>
    /// Create a new chat session with a custom history.
    /// </summary>
    /// <param name="executor">The executor for this session</param>
    /// <param name="history">The history to use for this session</param>
    public ChatSession(ILLamaExecutor executor, ChatHistory history)
        : this(executor)
    {
        History = history;
    }

    /// <summary>
    /// Use a custom history transform.
    /// </summary>
    /// <param name="transform">The history transform to use</param>
    /// <returns>This session, to allow call chaining.</returns>
    public ChatSession WithHistoryTransform(IHistoryTransform transform)
    {
        HistoryTransform = transform;
        return this;
    }

    /// <summary>
    /// Add a text transform to the input transform pipeline.
    /// </summary>
    /// <param name="transform">The transform to append to the pipeline</param>
    /// <returns>This session, to allow call chaining.</returns>
    public ChatSession AddInputTransform(ITextTransform transform)
    {
        InputTransformPipeline.Add(transform);
        return this;
    }

    /// <summary>
    /// Use a custom output transform.
    /// </summary>
    /// <param name="transform">The output transform to use</param>
    /// <returns>This session, to allow call chaining.</returns>
    public ChatSession WithOutputTransform(ITextStreamTransform transform)
    {
        OutputTransform = transform;
        return this;
    }

    /// <summary>
    /// Save a session to a directory.
    /// </summary>
    /// <remarks>
    /// Delegates to <see cref="SessionState.Save"/>, which deletes any existing
    /// directory at <paramref name="path"/> before writing.
    /// </remarks>
    /// <param name="path">The directory to write the session files into</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="path"/> is null or whitespace.</exception>
    public void SaveSession(string path)
    {
        GetSessionState().Save(path);
    }

    /// <summary>
    /// Get the session state.
    /// </summary>
    /// <returns>SessionState object representing session state in-memory</returns>
    public SessionState GetSessionState()
    {
        var executorState = ((StatefulExecutorBase)Executor).GetStateData();

        // The context state is only captured once the executor reports past tokens.
        return new SessionState(
            executorState.PastTokensCount > 0
                ? Executor.Context.GetState()
                : null,
            executorState,
            History,
            InputTransformPipeline,
            OutputTransform,
            HistoryTransform);
    }

    /// <summary>
    /// Load a session from a session state.
    /// </summary>
    /// <param name="state">The state to restore</param>
    /// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
    public void LoadSession(SessionState state, bool loadTransforms = true)
    {
        if (Executor is StatefulExecutorBase statefulExecutor && state.ExecutorState is not null)
        {
            statefulExecutor.LoadState(state.ExecutorState);
        }

        if (state.ContextState is null)
        {
            // No saved context state: clear the KV cache instead of loading one.
            Executor.Context.NativeHandle.KvCacheClear();
        }
        else
        {
            Executor.Context.LoadState(state.ContextState);
        }

        History = new ChatHistory(state.History);

        if (loadTransforms)
        {
            // Clone so this session does not share transform instances with the state object.
            InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
            OutputTransform = state.OutputTransform.Clone();
            HistoryTransform = state.HistoryTransform.Clone();
        }
    }

    /// <summary>
    /// Load a session from a directory.
    /// </summary>
    /// <param name="path">The directory containing the saved session files</param>
    /// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
    /// <exception cref="ArgumentException">Thrown by <see cref="SessionState.Load"/> when the path or one of the session files is invalid.</exception>
    public void LoadSession(string path, bool loadTransforms = true)
    {
        var state = SessionState.Load(path);

        // Handle non-polymorphic serialization of executor state
        if (state.ExecutorState is null)
        {
            var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
            ((StatefulExecutorBase)Executor).LoadState(filename: executorPath);
        }

        LoadSession(state, loadTransforms);
    }

    /// <summary>
    /// Add a message to the chat history.
    /// </summary>
    /// <param name="message">The message to add</param>
    /// <returns>This session, to allow call chaining.</returns>
    /// <exception cref="ArgumentException">Thrown when the message role is not allowed at the current position in the history.</exception>
    public ChatSession AddMessage(ChatHistory.Message message)
    {
        // If current message is a system message, only allow the history to be empty
        if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
        {
            throw new ArgumentException("Cannot add a system message after another message", nameof(message));
        }

        // If current message is a user message, only allow the history to be empty,
        // or the previous message to be a system message or assistant message.
        if (message.AuthorRole == AuthorRole.User)
        {
            ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
            if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
            {
                throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
            }
        }

        // If the current message is an assistant message,
        // the previous message must be a user message.
        if (message.AuthorRole == AuthorRole.Assistant)
        {
            ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
            if (lastMessage is null || lastMessage.AuthorRole != AuthorRole.User)
            {
                throw new ArgumentException("Assistant message must be preceded with a user message", nameof(message));
            }
        }

        History.AddMessage(message.AuthorRole, message.Content);
        return this;
    }

    /// <summary>
    /// Add a system message to the chat history.
    /// </summary>
    /// <param name="content">The message text</param>
    /// <returns>This session, to allow call chaining.</returns>
    public ChatSession AddSystemMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.System, content));

    /// <summary>
    /// Add an assistant message to the chat history.
    /// </summary>
    /// <param name="content">The message text</param>
    /// <returns>This session, to allow call chaining.</returns>
    public ChatSession AddAssistantMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

    /// <summary>
    /// Add a user message to the chat history.
    /// </summary>
    /// <param name="content">The message text</param>
    /// <returns>This session, to allow call chaining.</returns>
    public ChatSession AddUserMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.User, content));

    /// <summary>
    /// Remove the last message from the chat history.
    /// </summary>
    /// <remarks>
    /// No emptiness check is performed: the index passed to <c>RemoveAt</c> is
    /// <c>Count - 1</c>, which is -1 when the history is empty.
    /// </remarks>
    /// <returns>This session, to allow call chaining.</returns>
    public ChatSession RemoveLastMessage()
    {
        History.Messages.RemoveAt(History.Messages.Count - 1);
        return this;
    }

    /// <summary>
    /// Compute KV cache for the message and add it to the chat history.
    /// </summary>
    /// <remarks>
    /// The message is stored in the history with its original content; the input
    /// transform pipeline is applied (for non-assistant messages) only to the text
    /// passed to the executor's prefill.
    /// </remarks>
    /// <param name="message">The message to add and process</param>
    /// <returns>This session, to allow call chaining.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the executor is not a <see cref="StatefulExecutorBase"/>.</exception>
    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
    {
        if (Executor is not StatefulExecutorBase statefulExecutor)
        {
            throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
        }

        AddMessage(message);

        var content = message.Content;
        if (message.AuthorRole != AuthorRole.Assistant)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                content = inputTransform.Transform(content);
            }
        }

        await statefulExecutor.PrefillPromptAsync(content);
        return this;
    }

    /// <summary>
    /// Compute KV cache for the system message and add it to the chat history.
    /// </summary>
    public Task<ChatSession> AddAndProcessSystemMessage(string content)
        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));

    /// <summary>
    /// Compute KV cache for the user message and add it to the chat history.
    /// </summary>
    public Task<ChatSession> AddAndProcessUserMessage(string content)
        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));

    /// <summary>
    /// Compute KV cache for the assistant message and add it to the chat history.
    /// </summary>
    public Task<ChatSession> AddAndProcessAssistantMessage(string content)
        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

    /// <summary>
    /// Replace a user message with a new message and remove all messages after the new message.
    /// This is useful when the user wants to edit a message and regenerate the response.
    /// </summary>
    /// <param name="oldMessage">The user message currently in the history</param>
    /// <param name="newMessage">The user message to put in its place</param>
    /// <returns>This session, to allow call chaining.</returns>
    /// <exception cref="ArgumentException">Thrown when either message is not a user message, or the old message is not in the history.</exception>
    public ChatSession ReplaceUserMessage(
        ChatHistory.Message oldMessage,
        ChatHistory.Message newMessage)
    {
        if (oldMessage.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
        }

        if (newMessage.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("New message must be a user message", nameof(newMessage));
        }

        int index = History.Messages.IndexOf(oldMessage);
        if (index == -1)
        {
            throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
        }

        History.Messages[index] = newMessage;

        // Remove all messages after the new message
        History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);

        return this;
    }

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <remarks>
    /// The passed-in <paramref name="message"/> is not modified; the input transform
    /// pipeline is applied to a copy of its content.
    /// </remarks>
    /// <param name="message">The user message to send</param>
    /// <param name="applyInputTransformPipeline">Whether to run the message content through <see cref="InputTransformPipeline"/></param>
    /// <param name="inferenceParams">Inference parameters forwarded to the executor</param>
    /// <param name="cancellationToken">Token used to cancel the generation</param>
    /// <returns>The generated text, streamed piece by piece.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="message"/> is not a user message.</exception>
    public async IAsyncEnumerable<string> ChatAsync(
        ChatHistory.Message message,
        bool applyInputTransformPipeline,
        IInferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // The message must be a user message
        if (message.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("Message must be a user message", nameof(message));
        }

        // Apply input transform pipeline to a local copy so the caller's message
        // object is left untouched (re-sending it must not transform it twice).
        string content = message.Content;
        if (applyInputTransformPipeline)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                content = inputTransform.Transform(content);
            }
        }

        // Add the user's message to the history
        AddUserMessage(content);

        // Prepare prompt variable
        string prompt;

        // Check if the session history was restored from a previous session
        // or added as part of new chat session history.
        // NOTE(review): this cast assumes the executor's state is an
        // InteractiveExecutorState; other state types would fail here - confirm.
        InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

        // If "IsPromptRun" is true, the session was newly started.
        if (state.IsPromptRun)
        {
            // If the session history was added as part of new chat session history,
            // convert the complete history including system message and manually added history
            // to a prompt that adheres to the prompt template specified in the HistoryTransform class implementation.
            prompt = HistoryTransform.HistoryToText(History);
        }
        else
        {
            // If the session was restored from a previous session,
            // convert only the current message to the prompt with the prompt template
            // specified in the HistoryTransform class implementation that is provided.
            ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, content);
            prompt = HistoryTransform.HistoryToText(singleMessageHistory);
        }

        // Accumulate the streamed pieces without re-copying the whole reply per token.
        var assistantMessage = new StringBuilder();

        await foreach (
            string textToken
            in ChatAsyncInternal(
                prompt,
                inferenceParams,
                cancellationToken))
        {
            assistantMessage.Append(textToken);
            yield return textToken;
        }

        // Add the assistant message to the history
        AddAssistantMessage(assistantMessage.ToString());
    }

    /// <summary>
    /// Chat with the model, applying the input transform pipeline.
    /// </summary>
    /// <param name="message">The user message to send</param>
    /// <param name="inferenceParams">Inference parameters forwarded to the executor</param>
    /// <param name="cancellationToken">Token used to cancel the generation</param>
    /// <returns>The generated text, streamed piece by piece.</returns>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory.Message message,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        return ChatAsync(
            message,
            applyInputTransformPipeline: true,
            inferenceParams,
            cancellationToken);
    }

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <remarks>
    /// All messages except the last are added to the session history immediately;
    /// the last message is then sent as the chat message. The messages in
    /// <paramref name="history"/> are not modified.
    /// </remarks>
    /// <param name="history">The messages to add; the last one is the message to respond to</param>
    /// <param name="applyInputTransformPipeline">Whether to run user message content through <see cref="InputTransformPipeline"/></param>
    /// <param name="inferenceParams">Inference parameters forwarded to the executor</param>
    /// <param name="cancellationToken">Token used to cancel the generation</param>
    /// <returns>The generated text, streamed piece by piece.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="history"/> contains no messages.</exception>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory history,
        bool applyInputTransformPipeline,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
            ?? throw new ArgumentException("History must contain at least one message", nameof(history));

        foreach (
            ChatHistory.Message message
            in history.Messages.Take(history.Messages.Count - 1))
        {
            // Apply input transform pipeline (user messages only) to a local copy,
            // leaving the caller's history untouched.
            string content = message.Content;
            if (applyInputTransformPipeline
                && message.AuthorRole == AuthorRole.User)
            {
                foreach (var inputTransform in InputTransformPipeline)
                {
                    content = inputTransform.Transform(content);
                }
            }

            AddMessage(new ChatHistory.Message(message.AuthorRole, content));
        }

        return ChatAsync(
            lastMessage,
            applyInputTransformPipeline,
            inferenceParams,
            cancellationToken);
    }

    /// <summary>
    /// Chat with the model, applying the input transform pipeline.
    /// </summary>
    /// <param name="history">The messages to add; the last one is the message to respond to</param>
    /// <param name="inferenceParams">Inference parameters forwarded to the executor</param>
    /// <param name="cancellationToken">Token used to cancel the generation</param>
    /// <returns>The generated text, streamed piece by piece.</returns>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory history,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        return ChatAsync(
            history,
            applyInputTransformPipeline: true,
            inferenceParams,
            cancellationToken);
    }

    /// <summary>
    /// Regenerate the last assistant message.
    /// </summary>
    /// <param name="inferenceParams">Inference parameters forwarded to the executor</param>
    /// <param name="cancellationToken">Token used to cancel the generation</param>
    /// <returns>The regenerated text, streamed piece by piece.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the history does not end with a user message followed by an assistant message.</exception>
    public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
        InferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Make sure the last message is an assistant message (response from the LLM).
        ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();

        if (lastAssistantMessage is null
            || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
        {
            throw new InvalidOperationException("Last message must be an assistant message");
        }

        // Remove the last assistant message from the history.
        RemoveLastMessage();

        // Get the last user message.
        ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();

        if (lastUserMessage is null
            || lastUserMessage.AuthorRole != AuthorRole.User)
        {
            throw new InvalidOperationException("Last message must be a user message");
        }

        // Remove the last user message from the history.
        RemoveLastMessage();

        // Regenerate the assistant message. The pipeline is skipped because the
        // message is re-sent with the content already stored in the history.
        await foreach (
            string textToken
            in ChatAsync(
                lastUserMessage,
                applyInputTransformPipeline: false,
                inferenceParams,
                cancellationToken))
        {
            yield return textToken;
        }
    }

    /// <summary>
    /// Run inference on the prompt and stream the result through the output transform.
    /// </summary>
    private async IAsyncEnumerable<string> ChatAsyncInternal(
        string prompt,
        IInferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);

        await foreach (
            string textToken
            in OutputTransform
                .TransformAsync(results)
                .WithCancellation(cancellationToken))
        {
            yield return textToken;
        }
    }
}

/// <summary>
/// The state of a chat session in-memory.
/// </summary>
public record SessionState
{
    /// <summary>
    /// Saved executor state for the session in JSON format.
    /// </summary>
    public ExecutorBaseState? ExecutorState { get; set; }

    /// <summary>
    /// Saved context state (KV cache) for the session.
    /// </summary>
    public State? ContextState { get; set; }

    /// <summary>
    /// The input transform pipeline used in this session.
    /// </summary>
    public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();

    /// <summary>
    /// The output transform used in this session.
    /// </summary>
    public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();

    /// <summary>
    /// The history transform used in this session.
    /// </summary>
    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

    /// <summary>
    /// The chat history messages for this session.
    /// </summary>
    public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();

    /// <summary>
    /// Create a new session state.
    /// </summary>
    /// <remarks>
    /// The history messages are copied into an array and every transform is cloned.
    /// </remarks>
    /// <param name="contextState">Saved context state, or null when there is none</param>
    /// <param name="executorState">Saved executor state</param>
    /// <param name="history">Chat history to snapshot</param>
    /// <param name="inputTransformPipeline">Input transforms to clone</param>
    /// <param name="outputTransform">Output transform to clone</param>
    /// <param name="historyTransform">History transform to clone</param>
    public SessionState(
        State? contextState,
        ExecutorBaseState executorState,
        ChatHistory history,
        List<ITextTransform> inputTransformPipeline,
        ITextStreamTransform outputTransform,
        IHistoryTransform historyTransform)
    {
        ContextState = contextState;
        ExecutorState = executorState;
        History = history.Messages.ToArray();
        InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray();
        OutputTransform = outputTransform.Clone();
        HistoryTransform = historyTransform.Clone();
    }

    /// <summary>
    /// Save the session state to folder.
    /// </summary>
    /// <remarks>
    /// If a directory already exists at <paramref name="path"/> it is deleted
    /// recursively before the new files are written.
    /// </remarks>
    /// <param name="path">The directory to write the session files into</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="path"/> is null or whitespace.</exception>
    public void Save(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }

        if (Directory.Exists(path))
        {
            Directory.Delete(path, recursive: true);
        }

        Directory.CreateDirectory(path);

        // The model state file is only written when a context state is present.
        string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
        var bytes = ContextState?.ToByteArray();
        if (bytes is not null)
        {
            File.WriteAllBytes(modelStateFilePath, bytes);
        }

        string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
        File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));

        string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
        File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson());

        string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
        File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline));

        string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
        File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform));

        string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
        File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform));
    }

    /// <summary>
    /// Load the session state from folder.
    /// </summary>
    /// <remarks>
    /// The model state file and the three transform files are optional; when a
    /// transform file is missing the corresponding default is used.
    /// </remarks>
    /// <param name="path">The directory containing the saved session files</param>
    /// <returns>The loaded session state.</returns>
    /// <exception cref="ArgumentException">Throws when session state is incorrect</exception>
    public static SessionState Load(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }

        if (!Directory.Exists(path))
        {
            throw new ArgumentException("Directory does not exist", nameof(path));
        }

        string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
        var contextState = File.Exists(modelStateFilePath)
            ? State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
            : null;

        string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
        var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));

        string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
        string historyJson = File.ReadAllText(historyFilepath);
        var history = ChatHistory.FromJson(historyJson)
            ?? throw new ArgumentException("History file is invalid", nameof(path));

        ITextTransform[] inputTransforms = LoadTransformOrDefault<ITextTransform[]>(
            Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME),
            "Input transform file is invalid",
            nameof(path),
            () => Array.Empty<ITextTransform>());

        ITextStreamTransform outputTransform = LoadTransformOrDefault<ITextStreamTransform>(
            Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME),
            "Output transform file is invalid",
            nameof(path),
            () => new LLamaTransforms.EmptyTextOutputStreamTransform());

        IHistoryTransform historyTransform = LoadTransformOrDefault<IHistoryTransform>(
            Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME),
            "History transform file is invalid",
            nameof(path),
            () => new LLamaTransforms.DefaultHistoryTransform());

        return new SessionState(
            contextState,
            executorState,
            history,
            inputTransforms.ToList(),
            outputTransform,
            historyTransform);
    }

    /// <summary>
    /// Deserialize a transform from a JSON file, or create the default when the file does not exist.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// Thrown with <paramref name="invalidMessage"/> when the file deserializes to null or
    /// is not valid JSON; in the latter case the <see cref="JsonException"/> is kept as the inner exception.
    /// </exception>
    private static T LoadTransformOrDefault<T>(
        string filePath,
        string invalidMessage,
        string paramName,
        Func<T> createDefault)
        where T : class
    {
        if (!File.Exists(filePath))
        {
            return createDefault();
        }

        try
        {
            return JsonSerializer.Deserialize<T>(File.ReadAllText(filePath))
                ?? throw new ArgumentException(invalidMessage, paramName);
        }
        catch (JsonException ex)
        {
            throw new ArgumentException(invalidMessage, paramName, ex);
        }
    }
}