From b8cd5b7ee565b17feb81bbb3330866ad9df1616c Mon Sep 17 00:00:00 2001 From: eublefar Date: Thu, 21 Mar 2024 12:18:38 +0100 Subject: [PATCH] loadTransforms flag for LoadSession methods --- LLama/ChatSession.cs | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 9620dc4f..0a5accc5 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -178,17 +178,17 @@ public class ChatSession /// Load a session from a session state. /// /// + /// If true loads transforms saved in the session state. /// /// - public void LoadSession(SessionState state) + public void LoadSession(SessionState state, bool loadTransforms = true) { if (Executor is StatefulExecutorBase statefulExecutor) { - statefulExecutor.LoadState(state.ExecutorState); - } - else - { - throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state)); + if (state.ExecutorState is not null) + { + statefulExecutor.LoadState(state.ExecutorState); + } } if (state.ContextState is null) { @@ -199,18 +199,22 @@ public class ChatSession Executor.Context.LoadState(state.ContextState); } History = new ChatHistory(state.History); - InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); - OutputTransform = state.OutputTransform.Clone(); - HistoryTransform = state.HistoryTransform.Clone(); + if (loadTransforms) + { + InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); + OutputTransform = state.OutputTransform.Clone(); + HistoryTransform = state.HistoryTransform.Clone(); + } } /// /// Load a session from a directory. /// /// + /// If true loads transforms saved in the session state. /// /// - public void LoadSession(string path) + public void LoadSession(string path, bool loadTransforms = true) { var state = SessionState.Load(path); // Handle non-polymorphic serialization of executor state @@ -219,7 +223,7 @@ public class ChatSession var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); } - LoadSession(state); + LoadSession(state, loadTransforms); } ///