diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 9620dc4f..0a5accc5 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -178,17 +178,17 @@ public class ChatSession
/// Load a session from a session state.
///
///
+ /// If true loads transforms saved in the session state.
///
///
- public void LoadSession(SessionState state)
+ public void LoadSession(SessionState state, bool loadTransforms = true)
{
if (Executor is StatefulExecutorBase statefulExecutor)
{
- statefulExecutor.LoadState(state.ExecutorState);
- }
- else
- {
- throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state));
+ if (state.ExecutorState is not null)
+ {
+ statefulExecutor.LoadState(state.ExecutorState);
+ }
}
if (state.ContextState is null)
{
@@ -199,18 +199,22 @@ public class ChatSession
Executor.Context.LoadState(state.ContextState);
}
History = new ChatHistory(state.History);
- InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
- OutputTransform = state.OutputTransform.Clone();
- HistoryTransform = state.HistoryTransform.Clone();
+ if (loadTransforms)
+ {
+ InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
+ OutputTransform = state.OutputTransform.Clone();
+ HistoryTransform = state.HistoryTransform.Clone();
+ }
}
///
/// Load a session from a directory.
///
///
+ /// If true loads transforms saved in the session state.
///
///
- public void LoadSession(string path)
+ public void LoadSession(string path, bool loadTransforms = true)
{
var state = SessionState.Load(path);
// Handle non-polymorphic serialization of executor state
@@ -219,7 +223,7 @@ public class ChatSession
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
}
- LoadSession(state);
+ LoadSession(state, loadTransforms);
}
///