|
|
|
@@ -8,6 +8,7 @@ using System.Threading; |
|
|
|
using System.Threading.Tasks; |
|
|
|
using LLama.Abstractions; |
|
|
|
using LLama.Common; |
|
|
|
using static LLama.Common.ChatHistory; |
|
|
|
using static LLama.InteractiveExecutor; |
|
|
|
using static LLama.LLamaContext; |
|
|
|
using static LLama.StatefulExecutorBase; |
|
|
|
@@ -53,19 +54,16 @@ public class ChatSession |
|
|
|
/// </summary> |
|
|
|
/// <param name="executor">The executor for this session</param> |
|
|
|
/// <param name="history">History for this session</param> |
|
|
|
/// <param name="cancellationToken">Cancellation token to stop session pre-processing</param> |
|
|
|
/// <returns></returns> |
|
|
|
public static async Task<ChatSession> InitializeSessionFromHistoryAsync( |
|
|
|
ILLamaExecutor executor, |
|
|
|
ChatHistory history, |
|
|
|
CancellationToken cancellationToken = default) |
|
|
|
ILLamaExecutor executor, ChatHistory history) |
|
|
|
{ |
|
|
|
if (executor is not StatefulExecutorBase statefulExecutor) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); |
|
|
|
} |
|
|
|
var session = new ChatSession(executor, history); |
|
|
|
await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken); |
|
|
|
await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history)); |
|
|
|
return session; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -164,13 +162,13 @@ public class ChatSession |
|
|
|
/// <returns>SessionState object representing session state in-memory</returns> |
|
|
|
public SessionState GetSessionState() |
|
|
|
{ |
|
|
|
return new SessionState(Executor.Context.GetState(), ((StatefulExecutorBase)Executor).GetStateData()) |
|
|
|
{ |
|
|
|
InputTransformPipeline = InputTransformPipeline, |
|
|
|
OutputTransform = OutputTransform, |
|
|
|
HistoryTransform = HistoryTransform, |
|
|
|
History = History.ToJson() |
|
|
|
}; |
|
|
|
return new SessionState( |
|
|
|
Executor.Context.GetState(), |
|
|
|
((StatefulExecutorBase)Executor).GetStateData(), |
|
|
|
History, |
|
|
|
InputTransformPipeline, |
|
|
|
OutputTransform, |
|
|
|
HistoryTransform); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
@@ -183,18 +181,17 @@ public class ChatSession |
|
|
|
{ |
|
|
|
if (Executor is StatefulExecutorBase statefulExecutor) |
|
|
|
{ |
|
|
|
statefulExecutor.LoadState( |
|
|
|
JsonSerializer.Deserialize( |
|
|
|
state.ExecutorState, statefulExecutor.GetStateData().GetType() |
|
|
|
) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state)) |
|
|
|
); |
|
|
|
statefulExecutor.LoadState(state.ExecutorState); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state)); |
|
|
|
} |
|
|
|
Executor.Context.LoadState(state.ContextState); |
|
|
|
History = ChatHistory.FromJson(state.History) ?? new(); |
|
|
|
History = new ChatHistory(state.History); |
|
|
|
InputTransformPipeline = state.InputTransformPipeline.ToList(); |
|
|
|
OutputTransform = state.OutputTransform; |
|
|
|
HistoryTransform = state.HistoryTransform; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
@@ -288,7 +285,7 @@ public class ChatSession |
|
|
|
content = inputTransform.Transform(content); |
|
|
|
} |
|
|
|
|
|
|
|
await statefulExecutor.AddPromptAsync(content); |
|
|
|
await statefulExecutor.PrefillPromptAsync(content); |
|
|
|
|
|
|
|
History.AddMessage(AuthorRole.System, content); |
|
|
|
return this; |
|
|
|
@@ -593,41 +590,52 @@ public record SessionState |
|
|
|
/// <summary> |
|
|
|
/// Saved executor state for the session in JSON format. |
|
|
|
/// </summary> |
|
|
|
public string ExecutorState { get; init; } |
|
|
|
public ExecutorBaseState ExecutorState { get; set; } |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Saved context state (KV cache) for the session. |
|
|
|
/// </summary> |
|
|
|
public State ContextState { get; init; } |
|
|
|
public State ContextState { get; set; } |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// The input transform pipeline used in this session. |
|
|
|
/// </summary> |
|
|
|
public List<ITextTransform> InputTransformPipeline { get; init; } = new(); |
|
|
|
public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>(); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// The output transform used in this session. |
|
|
|
/// </summary> |
|
|
|
public ITextStreamTransform OutputTransform { get; init; } = new LLamaTransforms.EmptyTextOutputStreamTransform(); |
|
|
|
public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform(); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// The history transform used in this session. |
|
|
|
/// </summary> |
|
|
|
public IHistoryTransform HistoryTransform { get; init; } = new LLamaTransforms.DefaultHistoryTransform(); |
|
|
|
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// The JSON representation of the chat history for this session. |
|
|
|
/// The the chat history messages for this session. |
|
|
|
/// </summary> |
|
|
|
public string History { get; init; } = new ChatHistory().ToJson(); |
|
|
|
public Message[] History { get; set; } = Array.Empty<Message>(); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Create a new session state. |
|
|
|
/// </summary> |
|
|
|
/// <param name="contextState"></param> |
|
|
|
/// <param name="executorState"></param> |
|
|
|
public SessionState(State contextState, ExecutorBaseState executorState) |
|
|
|
/// <param name="history"></param> |
|
|
|
/// <param name="inputTransformPipeline"></param> |
|
|
|
/// <param name="outputTransform"></param> |
|
|
|
/// <param name="historyTransform"></param> |
|
|
|
public SessionState( |
|
|
|
State contextState, ExecutorBaseState executorState, |
|
|
|
ChatHistory history, List<ITextTransform> inputTransformPipeline, |
|
|
|
ITextStreamTransform outputTransform, IHistoryTransform historyTransform) |
|
|
|
{ |
|
|
|
ContextState = contextState; |
|
|
|
ExecutorState = JsonSerializer.Serialize(executorState); |
|
|
|
ExecutorState = executorState; |
|
|
|
History = history.Messages.ToArray(); |
|
|
|
InputTransformPipeline = inputTransformPipeline.ToArray(); |
|
|
|
OutputTransform = outputTransform; |
|
|
|
HistoryTransform = historyTransform; |
|
|
|
} |
|
|
|
} |