diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs
index 2a9d85df..2d146eb8 100644
--- a/LLama.Examples/Program.cs
+++ b/LLama.Examples/Program.cs
@@ -28,7 +28,7 @@ if(version == 1)
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
//string prompt = " Qeustion: how to do binary search for an array in C#? Answer: ";
- InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337)));
+ InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }));
@@ -39,6 +39,17 @@ if(version == 1)
Console.Write(text);
}
prompt = Console.ReadLine();
+ if(prompt == "save")
+ {
+ session.SaveSession("./SessionState");
+ Console.WriteLine("Saved session!");
+ ex.Model.Dispose();
+ ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
+ session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }));
+ session.LoadSession("./SessionState");
+ Console.WriteLine("Loaded session!");
+ prompt = Console.ReadLine();
+ }
}
ex.Model.Dispose();
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index d2398004..3ff36fc1 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -1,5 +1,6 @@
using LLama.Common;
using System.Collections.Generic;
+using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
@@ -10,6 +11,8 @@ namespace LLama
{
private ILLamaExecutor _executor;
private ChatHistory _history;
+ private static readonly string _executorStateFilename = "ExecutorState.json";
+ private static readonly string _modelStateFilename = "ModelState.st";
public ILLamaExecutor Executor => _executor;
public ChatHistory History => _history;
public SessionParams Params { get; set; }
@@ -42,6 +45,56 @@ namespace LLama
return this;
}
+ ///
+ ///
+ ///
+ /// The directory name to save the session. If the directory does not exist, a new directory will be created.
+ public virtual void SaveSession(string path)
+ {
+ if (!Directory.Exists(path))
+ {
+ Directory.CreateDirectory(path);
+ }
+ _executor.Model.SaveState(Path.Combine(path, _modelStateFilename));
+ if(Executor is StatelessExecutor)
+ {
+ // A stateless executor carries no per-session state; only the model state (saved above) is needed.
+ }
+ else if(Executor is StatefulExecutorBase statefulExecutor)
+ {
+ statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
+ }
+ else
+ {
+ throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
+ }
+ }
+
+ ///
+ ///
+ ///
+ /// The directory name to load the session.
+ public virtual void LoadSession(string path)
+ {
+ if (!Directory.Exists(path))
+ {
+ throw new DirectoryNotFoundException($"Directory {path} does not exist.");
+ }
+ _executor.Model.LoadState(Path.Combine(path, _modelStateFilename));
+ if (Executor is StatelessExecutor)
+ {
+ // A stateless executor carries no per-session state; only the model state (loaded above) is needed.
+ }
+ else if (Executor is StatefulExecutorBase statefulExecutor)
+ {
+ statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
+ }
+ else
+ {
+ throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
+ }
+ }
+
///
/// Get the response from the LLama model with chat histories.
///
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 18ebd6e1..715f6b2d 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -12,7 +12,7 @@ using System.Threading;
namespace LLama
{
using llama_token = Int32;
- public abstract class ChatExecutorBase : ILLamaExecutor
+ public abstract class StatefulExecutorBase : ILLamaExecutor
{
protected readonly LLamaModel _model;
protected ILLamaLogger? _logger;
@@ -26,7 +26,7 @@ namespace LLama
protected List _session_tokens = new();
protected FixedSizeQuene _last_n_tokens;
public LLamaModel Model => _model;
- protected ChatExecutorBase(LLamaModel model, ILLamaLogger? logger = null)
+ protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null)
{
_model = model;
_logger = logger;
@@ -38,7 +38,7 @@ namespace LLama
_last_n_tokens = new FixedSizeQuene(_model.ContextSize).FillWith(0);
}
- public unsafe ChatExecutorBase WithSessionFile(string filename)
+ public unsafe StatefulExecutorBase WithSessionFile(string filename)
{
_pathSession = filename;
if (string.IsNullOrEmpty(filename))
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index cf0903ad..b43ac7d1 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -11,7 +11,7 @@ using System.Text.Json.Serialization;
namespace LLama
{
using llama_token = Int32;
- public class InstructExecutor : ChatExecutorBase
+ public class InstructExecutor : StatefulExecutorBase
{
bool _is_prompt_run = true;
llama_token[] _inp_pfx;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index d329baf0..cdd0e47e 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -14,7 +14,7 @@ using System.Threading.Tasks;
namespace LLama
{
using llama_token = Int32;
- public class InteractiveExecutor : ChatExecutorBase
+ public class InteractiveExecutor : StatefulExecutorBase
{
bool _is_prompt_run = true;
llama_token[] _llama_token_newline;