diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index c9217ae0..5651343f 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -21,5 +21,11 @@ namespace LLama.Abstractions /// The chat history as plain text. /// The updated history. ChatHistory TextToHistory(AuthorRole role, string text); + + /// + /// Copy the transform. + /// + /// + IHistoryTransform Clone(); } } diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs index 2725214f..2b63299d 100644 --- a/LLama/Abstractions/ITextStreamTransform.cs +++ b/LLama/Abstractions/ITextStreamTransform.cs @@ -13,5 +13,11 @@ namespace LLama.Abstractions /// /// IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens); + + /// + /// Copy the transform. + /// + /// + ITextStreamTransform Clone(); } } diff --git a/LLama/Abstractions/ITextTransform.cs b/LLama/Abstractions/ITextTransform.cs index ac196644..0bfeeb7f 100644 --- a/LLama/Abstractions/ITextTransform.cs +++ b/LLama/Abstractions/ITextTransform.cs @@ -17,5 +17,11 @@ /// /// string Transform(string text); + + /// + /// Copy the transform. + /// + /// + ITextTransform Clone(); } } diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 1cc7d29e..b4117842 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -189,9 +189,9 @@ public class ChatSession } Executor.Context.LoadState(state.ContextState); History = new ChatHistory(state.History); - InputTransformPipeline = state.InputTransformPipeline.ToList(); - OutputTransform = state.OutputTransform; - HistoryTransform = state.HistoryTransform; + InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); + OutputTransform = state.OutputTransform.Clone(); + HistoryTransform = state.HistoryTransform.Clone(); } /// @@ -634,8 +634,8 @@ public record SessionState ContextState = contextState; ExecutorState = executorState; History = history.Messages.ToArray(); - InputTransformPipeline = inputTransformPipeline.ToArray(); - OutputTransform = outputTransform; - HistoryTransform = historyTransform; + InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray(); + OutputTransform = outputTransform.Clone(); + HistoryTransform = historyTransform.Clone(); } } \ No newline at end of file diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index 29c16c18..1ac0a79b 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -47,6 +47,12 @@ namespace LLama _isInstructMode = isInstructMode; } + /// + public IHistoryTransform Clone() + { + return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode); + } + /// public virtual string HistoryToText(ChatHistory history) { @@ -116,6 +122,12 @@ namespace LLama { return text.Trim(); } + + /// + public ITextTransform Clone() + { + return new NaiveTextInputTransform(); + } } /// @@ -129,6 +141,12 @@ namespace LLama { return tokens; } + + /// + public ITextStreamTransform Clone() + { + return new EmptyTextOutputStreamTransform(); + } } /// @@ -157,6 +175,12 @@ namespace LLama _removeAllMatchedTokens = removeAllMatchedTokens; } + /// + public ITextStreamTransform Clone() + { + return new KeywordTextOutputStreamTransform(_keywords, _maxKeywordLength, _removeAllMatchedTokens); + } + /// public async IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens) {