Browse Source

Rebuild ChatSession class

- Saves with serialized ChatHistory of session
- Only allows use of ChatHistory.Message (instead of raw text)
   for easy post-processing with IHistoryTransform implementation
- Provides History Management methods
- Allows user to regenerate last assistant message
tags/0.9.1
Philipp Bauer 2 years ago
parent
commit
67e6d633fd
2 changed files with 483 additions and 202 deletions
  1. +449
    -196
      LLama/ChatSession.cs
  2. +34
    -6
      LLama/Common/ChatHistory.cs

+ 449
- 196
LLama/ChatSession.cs View File

@@ -1,243 +1,496 @@
using LLama.Abstractions;
using LLama.Common;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.InteractiveExecutor; using static LLama.InteractiveExecutor;


namespace LLama
namespace LLama;

/// <summary>
/// The main chat session class.
/// </summary>
public class ChatSession
{ {
private const string _modelStateFilename = "ModelState.st";
private const string _executorStateFilename = "ExecutorState.json";
private const string _hsitoryFilename = "ChatHistory.json";

/// <summary> /// <summary>
/// The main chat session class.
/// </summary>
public class ChatSession
{
private readonly ILLamaExecutor _executor;
private readonly ChatHistory _history;

private const string _executorStateFilename = "ExecutorState.json";
private const string _modelStateFilename = "ModelState.st";

/// <summary>
/// The executor for this session.
/// </summary>
public ILLamaExecutor Executor => _executor;
/// <summary>
/// The chat history for this session.
/// </summary>
public ChatHistory History => _history;
/// <summary>
/// The history transform used in this session.
/// </summary>
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public List<ITextTransform> InputTransformPipeline { get; set; } = new();
/// <summary>
/// The output transform used in this session.
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
///
/// </summary>
/// <param name="executor">The executor for this session</param>
public ChatSession(ILLamaExecutor executor)
{
_executor = executor;
_history = new ChatHistory();
}

/// <summary>
/// Use a custom history transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithHistoryTransform(IHistoryTransform transform)
{
HistoryTransform = transform;
return this;
}

/// <summary>
/// Add a text transform to the input transform pipeline.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession AddInputTransform(ITextTransform transform)
{
InputTransformPipeline.Add(transform);
return this;
}

/// <summary>
/// Use a custom output transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithOutputTransform(ITextStreamTransform transform)
{
OutputTransform = transform;
return this;
}

/// <summary>
///
/// </summary>
/// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
public virtual void SaveSession(string path)
{
if (!Directory.Exists(path))
{
Directory.CreateDirectory(path);
}
_executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
if (Executor is StatelessExecutor)
{
/// The executor for this session.
/// </summary>
public ILLamaExecutor Executor { get; private set; }


}
else if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
}
else
{
throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
}
/// <summary>
/// The chat history for this session.
/// </summary>
public ChatHistory History { get; private set; } = new();

/// <summary>
/// The history transform used in this session.
/// </summary>
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public List<ITextTransform> InputTransformPipeline { get; set; } = new();

/// <summary>
/// The output transform used in this session.
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
/// Create a new chat session.
/// </summary>
/// <param name="executor">The executor for this session</param>
public ChatSession(ILLamaExecutor executor)
{
// Check if executor has StatefulExecutorBase as base class
if (executor is not StatefulExecutorBase)
{
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
} }


/// <summary>
///
/// </summary>
/// <param name="path">The directory name to load the session.</param>
public virtual void LoadSession(string path)
Executor = executor;
}

/// <summary>
/// Create a new chat session with a custom history.
/// </summary>
/// <param name="executor"></param>
/// <param name="history"></param>
public ChatSession(ILLamaExecutor executor, ChatHistory history)
: this(executor)
{
History = history;
}

/// <summary>
/// Use a custom history transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithHistoryTransform(IHistoryTransform transform)
{
HistoryTransform = transform;
return this;
}

/// <summary>
/// Add a text transform to the input transform pipeline.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession AddInputTransform(ITextTransform transform)
{
InputTransformPipeline.Add(transform);
return this;
}

/// <summary>
/// Use a custom output transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithOutputTransform(ITextStreamTransform transform)
{
OutputTransform = transform;
return this;
}

/// <summary>
/// Save a session to a directory.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void SaveSession(string path)
{
if (string.IsNullOrWhiteSpace(path))
{ {
if (!Directory.Exists(path))
{
throw new FileNotFoundException($"Directory {path} does not exist.");
}
_executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
if (Executor is StatelessExecutor)
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}


}
else if (Executor is StatefulExecutorBase statefulExecutor)
if (Directory.Exists(path))
{
Directory.Delete(path, recursive: true);
}

Directory.CreateDirectory(path);

string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.SaveState(modelStateFilePath);

string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);

string historyFilepath = Path.Combine(path, _hsitoryFilename);
File.WriteAllText(historyFilepath, History.ToJson());
}

/// <summary>
/// Load a session from a directory.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void LoadSession(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}

if (!Directory.Exists(path))
{
throw new ArgumentException("Directory does not exist", nameof(path));
}

string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.LoadState(modelStateFilePath);

string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);

string historyFilepath = Path.Combine(path, _hsitoryFilename);
string historyJson = File.ReadAllText(historyFilepath);
History = ChatHistory.FromJson(historyJson)
?? throw new ArgumentException("History file is invalid", nameof(path));
}

/// <summary>
/// Add a message to the chat history.
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
public ChatSession AddMessage(ChatHistory.Message message)
{
// If current message is a system message, only allow the history to be empty
if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
{
throw new ArgumentException("Cannot add a system message after another message", nameof(message));
}

// If current message is a user message, only allow the history to be empty,
// or the previous message to be a system message or assistant message.
if (message.AuthorRole == AuthorRole.User)
{
ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
{ {
statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
} }
else
}

// If the current message is an assistant message,
// the previous message must be a user message.
if (message.AuthorRole == AuthorRole.Assistant)
{
ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
if (lastMessage is null
|| lastMessage.AuthorRole != AuthorRole.User)
{ {
throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
throw new ArgumentException("Assistant message must be preceeded with a user message", nameof(message));
} }
} }


/// <summary>
/// Generates a response for a given user prompt and manages history state for the user.
/// This will always pass the whole history to the model. Don't pass a whole history
/// to this method as the user prompt will be appended to the history of the current session.
/// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns>Returns generated text of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
History.AddMessage(message.AuthorRole, message.Content);
return this;
}

/// <summary>
/// Add a system message to the chat history.
/// </summary>
/// <param name="content"></param>
/// <returns></returns>
public ChatSession AddSystemMessage(string content)
=> AddMessage(new ChatHistory.Message(AuthorRole.System, content));

/// <summary>
/// Add an assistant message to the chat history.
/// </summary>
/// <param name="content"></param>
/// <returns></returns>
public ChatSession AddAssistantMessage(string content)
=> AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

/// <summary>
/// Add a user message to the chat history.
/// </summary>
/// <param name="content"></param>
/// <returns></returns>
public ChatSession AddUserMessage(string content)
=> AddMessage(new ChatHistory.Message(AuthorRole.User, content));

/// <summary>
/// Remove the last message from the chat history.
/// </summary>
/// <returns></returns>
public ChatSession RemoveLastMessage()
{
History.Messages.RemoveAt(History.Messages.Count - 1);
return this;
}

/// <summary>
/// Replace a user message with a new message and remove all messages after the new message.
/// This is useful when the user wants to edit a message and regenerate the response.
/// </summary>
/// <param name="oldMessage"></param>
/// <param name="newMessage"></param>
/// <returns></returns>
public ChatSession ReplaceUserMessage(
ChatHistory.Message oldMessage,
ChatHistory.Message newMessage)
{
if (oldMessage.AuthorRole != AuthorRole.User)
{
throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
}

if (newMessage.AuthorRole != AuthorRole.User)
{ {
foreach (var inputTransform in InputTransformPipeline)
prompt = inputTransform.Transform(prompt);
throw new ArgumentException("New message must be a user message", nameof(newMessage));
}


History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
int index = History.Messages.IndexOf(oldMessage);
if (index == -1)
{
throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
}


if (_executor is InteractiveExecutor executor)
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(AuthorRole.User, prompt));
}
History.Messages[index] = newMessage;
// Remove all messages after the new message
History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);
return this;
}


StringBuilder sb = new();
/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="message"></param>
/// <param name="inferenceParams"></param>
/// <param name="applyInputTransformPipeline"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public async IAsyncEnumerable<string> ChatAsync(
ChatHistory.Message message,
bool applyInputTransformPipeline,
IInferenceParams? inferenceParams = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// The message must be a user message
if (message.AuthorRole != AuthorRole.User)
{
throw new ArgumentException("Message must be a user message", nameof(message));
}


await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
// Apply input transform pipeline
if (applyInputTransformPipeline)
{
foreach (var inputTransform in InputTransformPipeline)
{ {
yield return textToken;
sb.Append(textToken);
message.Content = inputTransform.Transform(message.Content);
} }
}


string assistantMessage = sb.ToString();
// Add the user's message to the history
AddUserMessage(message.Content);

// Prepare prompt variable
string prompt;

// Check if the session history was restored from a previous session
// or added as part of new chat session history.
InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

// If "IsPromptRun" is true, the session was newly started.
if (state.IsPromptRun)
{
// If the session history was added as part of new chat session history,
// convert the complete history including the system message and manually added history
// to a prompt that adheres to the prompt template specified in the HistoryTransform class implementation.
prompt = HistoryTransform.HistoryToText(History);
}
else
{
// If the session was restored from a previous session,
// convert only the current message to the prompt with the prompt template
// specified in the HistoryTransform class implementation that is provided.
ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
prompt = HistoryTransform.HistoryToText(singleMessageHistory);
}

string assistantMessage = string.Empty;

await foreach (
string textToken
in ChatAsyncInternal(
prompt,
inferenceParams,
cancellationToken))
{
assistantMessage += textToken;
yield return textToken;
}


// Remove end tokens from the assistant message
// if defined in inferenceParams.AntiPrompts.
// We only want the response that was generated and not tokens
// that are delimiting the beginning or end of the response.
if (inferenceParams?.AntiPrompts != null)
// Add the assistant message to the history
AddAssistantMessage(assistantMessage);
}

/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="message"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IAsyncEnumerable<string> ChatAsync(
ChatHistory.Message message,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
return ChatAsync(
message,
applyInputTransformPipeline: true,
inferenceParams,
cancellationToken);
}

/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="history"></param>
/// <param name="applyInputTransformPipeline"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public IAsyncEnumerable<string> ChatAsync(
ChatHistory history,
bool applyInputTransformPipeline,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
?? throw new ArgumentException("History must contain at least one message", nameof(history));

foreach (
ChatHistory.Message message
in history.Messages.Take(history.Messages.Count - 1))
{
// Apply input transform pipeline
if (applyInputTransformPipeline
&& message.AuthorRole == AuthorRole.User)
{ {
foreach (var stopToken in inferenceParams.AntiPrompts)
foreach (
var inputTransform
in InputTransformPipeline)
{ {
assistantMessage = assistantMessage.Replace(stopToken, "").Trim();
message.Content = inputTransform.Transform(message.Content);
} }
} }


History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
AddMessage(message);
} }


/// <summary>
/// Generates a response for a given chat history. This method does not manage history state for the user.
/// If you want to e.g. truncate the history of a session to fit into the model's context window,
/// use this method and pass the truncated history to it. If you don't need this control, use the other
/// overload of this method that accepts a user prompt instead.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns>Returns generated text of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
return ChatAsync(
lastMessage,
applyInputTransformPipeline,
inferenceParams,
cancellationToken);
}

/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IAsyncEnumerable<string> ChatAsync(
ChatHistory history,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
return ChatAsync(
history,
applyInputTransformPipeline: true,
inferenceParams,
cancellationToken);
}

/// <summary>
/// Regenerate the last assistant message.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
InferenceParams? inferenceParams = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Make sure the last message is an assistant message (response from the LLM).
ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();

if (lastAssistantMessage is null
|| lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
{ {
if (history.Messages.Count == 0)
{
throw new ArgumentException("History must contain at least one message.");
}
throw new InvalidOperationException("Last message must be an assistant message");
}


string prompt;
if (_executor is InteractiveExecutor executor)
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
// Remove the last assistant message from the history.
RemoveLastMessage();


if (state.IsPromptRun)
{
prompt = HistoryTransform.HistoryToText(History);
}
else
{
ChatHistory.Message lastMessage = history.Messages.Last();
prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
}
}
else
{
ChatHistory.Message lastMessage = history.Messages.Last();
prompt = HistoryTransform.HistoryToText(HistoryTransform.TextToHistory(lastMessage.AuthorRole, lastMessage.Content));
}
// Get the last user message.
ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();


await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return textToken;
}
if (lastUserMessage is null
|| lastUserMessage.AuthorRole != AuthorRole.User)
{
throw new InvalidOperationException("Last message must be a user message");
} }


private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
// Remove the last user message from the history.
RemoveLastMessage();

// Regenerate the assistant message.
await foreach (
string textToken
in ChatAsync(
lastUserMessage,
applyInputTransformPipeline: false,
inferenceParams,
cancellationToken))
{ {
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
await foreach (var textToken in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
{
yield return textToken;
}
yield return textToken;
}
}

private async IAsyncEnumerable<string> ChatAsyncInternal(
string prompt,
IInferenceParams? inferenceParams = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);

await foreach (
string textToken
in OutputTransform
.TransformAsync(results)
.WithCancellation(cancellationToken))
{
yield return textToken;
} }
} }
}
}

+ 34
- 6
LLama/Common/ChatHistory.cs View File

@@ -1,4 +1,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Text.Json.Serialization;


namespace LLama.Common namespace LLama.Common
{ {
@@ -43,11 +46,14 @@ namespace LLama.Common
/// <summary> /// <summary>
/// Role of the message author, e.g. user/assistant/system /// Role of the message author, e.g. user/assistant/system
/// </summary> /// </summary>
[JsonConverter(typeof(JsonStringEnumConverter))]
[JsonPropertyName("author_role")]
public AuthorRole AuthorRole { get; set; } public AuthorRole AuthorRole { get; set; }


/// <summary> /// <summary>
/// Message content /// Message content
/// </summary> /// </summary>
[JsonPropertyName("content")]
public string Content { get; set; } public string Content { get; set; }


/// <summary> /// <summary>
@@ -65,15 +71,14 @@ namespace LLama.Common
/// <summary> /// <summary>
/// List of messages in the chat /// List of messages in the chat
/// </summary> /// </summary>
public List<Message> Messages { get; }
[JsonPropertyName("messages")]
public List<Message> Messages { get; set; } = new();


/// <summary> /// <summary>
/// Create a new instance of the chat content class /// Create a new instance of the chat content class
/// </summary> /// </summary>
public ChatHistory()
{
this.Messages = new List<Message>();
}
[JsonConstructor]
public ChatHistory() { }


/// <summary> /// <summary>
/// Add a message to the chat history /// Add a message to the chat history
@@ -84,6 +89,29 @@ namespace LLama.Common
{ {
this.Messages.Add(new Message(authorRole, content)); this.Messages.Add(new Message(authorRole, content));
} }
}


/// <summary>
/// Serialize the chat history to JSON
/// </summary>
/// <returns></returns>
public string ToJson()
{
return JsonSerializer.Serialize(
this,
new JsonSerializerOptions()
{
WriteIndented = true
});
}

/// <summary>
/// Deserialize a chat history from JSON
/// </summary>
/// <param name="json"></param>
/// <returns></returns>
public static ChatHistory? FromJson(string json)
{
return JsonSerializer.Deserialize<ChatHistory>(json);
}
}
} }

Loading…
Cancel
Save