diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs index 2a9d85df..2d146eb8 100644 --- a/LLama.Examples/Program.cs +++ b/LLama.Examples/Program.cs @@ -28,7 +28,7 @@ if(version == 1) var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); //string prompt = " Qeustion: how to do binary search for an array in C#? Answer: "; - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337))); + InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" })); @@ -39,6 +39,17 @@ if(version == 1) Console.Write(text); } prompt = Console.ReadLine(); + if(prompt == "save") + { + session.SaveSession("./SessionState"); + Console.WriteLine("Saved session!"); + ex.Model.Dispose(); + ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" })); + session.LoadSession("./SessionState"); + Console.WriteLine("Loaded session!"); + prompt = Console.ReadLine(); + } } ex.Model.Dispose(); diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index d2398004..3ff36fc1 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -1,5 +1,6 @@ using LLama.Common; using System.Collections.Generic; +using System.IO; using System.Runtime.CompilerServices; using System.Text; using System.Threading; @@ -10,6 +11,8 @@ namespace LLama { private ILLamaExecutor _executor; private ChatHistory _history; + private static readonly string _executorStateFilename = "ExecutorState.json"; + private static readonly string _modelStateFilename = "ModelState.st"; public ILLamaExecutor Executor => _executor; public ChatHistory History => _history; public 
SessionParams Params { get; set; } @@ -42,6 +45,56 @@ namespace LLama return this; } + /// <summary> + /// Save the session state to the given directory. + /// </summary> + /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param> + public virtual void SaveSession(string path) + { + if (!Directory.Exists(path)) + { + Directory.CreateDirectory(path); + } + _executor.Model.SaveState(Path.Combine(path, _modelStateFilename)); + if(Executor is StatelessExecutor) + { + + } + else if(Executor is StatefulExecutorBase statefulExecutor) + { + statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename)); + } + else + { + throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method."); + } + } + + /// <summary> + /// Load a session state from the given directory. + /// </summary> + /// <param name="path">The directory name to load the session.</param> + public virtual void LoadSession(string path) + { + if (!Directory.Exists(path)) + { + throw new DirectoryNotFoundException($"Directory {path} does not exist."); + } + _executor.Model.LoadState(Path.Combine(path, _modelStateFilename)); + if (Executor is StatelessExecutor) + { + + } + else if (Executor is StatefulExecutorBase statefulExecutor) + { + statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename)); + } + else + { + throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method."); + } + } + /// <summary> /// Get the response from the LLama model with chat histories. /// </summary> diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 18ebd6e1..715f6b2d 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -12,7 +12,7 @@ using System.Threading; namespace LLama { using llama_token = Int32; - public abstract class ChatExecutorBase : ILLamaExecutor + public abstract class StatefulExecutorBase : ILLamaExecutor { protected readonly LLamaModel _model; protected ILLamaLogger?
_logger; @@ -26,7 +26,7 @@ namespace LLama protected List _session_tokens = new(); protected FixedSizeQuene _last_n_tokens; public LLamaModel Model => _model; - protected ChatExecutorBase(LLamaModel model, ILLamaLogger? logger = null) + protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null) { _model = model; _logger = logger; @@ -38,7 +38,7 @@ namespace LLama _last_n_tokens = new FixedSizeQuene(_model.ContextSize).FillWith(0); } - public unsafe ChatExecutorBase WithSessionFile(string filename) + public unsafe StatefulExecutorBase WithSessionFile(string filename) { _pathSession = filename; if (string.IsNullOrEmpty(filename)) diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index cf0903ad..b43ac7d1 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -11,7 +11,7 @@ using System.Text.Json.Serialization; namespace LLama { using llama_token = Int32; - public class InstructExecutor : ChatExecutorBase + public class InstructExecutor : StatefulExecutorBase { bool _is_prompt_run = true; llama_token[] _inp_pfx; diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index d329baf0..cdd0e47e 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -14,7 +14,7 @@ using System.Threading.Tasks; namespace LLama { using llama_token = Int32; - public class InteractiveExecutor : ChatExecutorBase + public class InteractiveExecutor : StatefulExecutorBase { bool _is_prompt_run = true; llama_token[] _llama_token_newline;