Browse Source

Make state editable by the user, add deepcopy to fields that require it

tags/0.11.0
eublefar 1 year ago
parent
commit
5f3803d23c
2 changed files with 46 additions and 28 deletions
  1. +36
    -28
      LLama/ChatSession.cs
  2. +10
    -0
      LLama/Common/ChatHistory.cs

+ 36
- 28
LLama/ChatSession.cs View File

@@ -8,6 +8,7 @@ using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.Common.ChatHistory;
using static LLama.InteractiveExecutor;
using static LLama.LLamaContext;
using static LLama.StatefulExecutorBase;
@@ -53,19 +54,16 @@ public class ChatSession
/// </summary>
/// <param name="executor">The executor for this session</param>
/// <param name="history">History for this session</param>
/// <param name="cancellationToken">Cancellation token to stop session pre-processing</param>
/// <returns></returns>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
ILLamaExecutor executor,
ChatHistory history,
CancellationToken cancellationToken = default)
ILLamaExecutor executor, ChatHistory history)
{
if (executor is not StatefulExecutorBase statefulExecutor)
{
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
}
var session = new ChatSession(executor, history);
await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
return session;
}

@@ -164,13 +162,13 @@ public class ChatSession
/// <returns>SessionState object representing session state in-memory</returns>
public SessionState GetSessionState()
{
return new SessionState(Executor.Context.GetState(), ((StatefulExecutorBase)Executor).GetStateData())
{
InputTransformPipeline = InputTransformPipeline,
OutputTransform = OutputTransform,
HistoryTransform = HistoryTransform,
History = History.ToJson()
};
return new SessionState(
Executor.Context.GetState(),
((StatefulExecutorBase)Executor).GetStateData(),
History,
InputTransformPipeline,
OutputTransform,
HistoryTransform);
}

/// <summary>
@@ -183,18 +181,17 @@ public class ChatSession
{
if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.LoadState(
JsonSerializer.Deserialize(
state.ExecutorState, statefulExecutor.GetStateData().GetType()
) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
);
statefulExecutor.LoadState(state.ExecutorState);
}
else
{
throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state));
}
Executor.Context.LoadState(state.ContextState);
History = ChatHistory.FromJson(state.History) ?? new();
History = new ChatHistory(state.History);
InputTransformPipeline = state.InputTransformPipeline.ToList();
OutputTransform = state.OutputTransform;
HistoryTransform = state.HistoryTransform;
}

/// <summary>
@@ -288,7 +285,7 @@ public class ChatSession
content = inputTransform.Transform(content);
}

await statefulExecutor.AddPromptAsync(content);
await statefulExecutor.PrefillPromptAsync(content);

History.AddMessage(AuthorRole.System, content);
return this;
@@ -593,41 +590,52 @@ public record SessionState
/// <summary>
/// Saved executor state for the session in JSON format.
/// </summary>
public string ExecutorState { get; init; }
public ExecutorBaseState ExecutorState { get; set; }

/// <summary>
/// Saved context state (KV cache) for the session.
/// </summary>
public State ContextState { get; init; }
public State ContextState { get; set; }

/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public List<ITextTransform> InputTransformPipeline { get; init; } = new();
public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();

/// <summary>
/// The output transform used in this session.
/// </summary>
public ITextStreamTransform OutputTransform { get; init; } = new LLamaTransforms.EmptyTextOutputStreamTransform();
public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
/// The history transform used in this session.
/// </summary>
public IHistoryTransform HistoryTransform { get; init; } = new LLamaTransforms.DefaultHistoryTransform();
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
/// <summary>
/// The JSON representation of the chat history for this session.
/// The the chat history messages for this session.
/// </summary>
public string History { get; init; } = new ChatHistory().ToJson();
public Message[] History { get; set; } = Array.Empty<Message>();

/// <summary>
/// Create a new session state.
/// </summary>
/// <param name="contextState"></param>
/// <param name="executorState"></param>
public SessionState(State contextState, ExecutorBaseState executorState)
/// <param name="history"></param>
/// <param name="inputTransformPipeline"></param>
/// <param name="outputTransform"></param>
/// <param name="historyTransform"></param>
public SessionState(
State contextState, ExecutorBaseState executorState,
ChatHistory history, List<ITextTransform> inputTransformPipeline,
ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
{
ContextState = contextState;
ExecutorState = JsonSerializer.Serialize(executorState);
ExecutorState = executorState;
History = history.Messages.ToArray();
InputTransformPipeline = inputTransformPipeline.ToArray();
OutputTransform = outputTransform;
HistoryTransform = historyTransform;
}
}

+ 10
- 0
LLama/Common/ChatHistory.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;

@@ -80,6 +81,15 @@ namespace LLama.Common
[JsonConstructor]
public ChatHistory() { }

/// <summary>
/// Create a new instance of the chat history from array of messages
/// </summary>
/// <param name="messageHistory"></param>
public ChatHistory(Message[] messageHistory)
{
this.Messages = messageHistory.ToList();
}

/// <summary>
/// Add a message to the chat history
/// </summary>


Loading…
Cancel
Save