| @@ -28,7 +28,7 @@ if(version == 1) | |||||
| var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); | ||||
| //string prompt = " Question: how to do binary search for an array in C#? Answer: "; | //string prompt = " Question: how to do binary search for an array in C#? Answer: "; | ||||
| InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337))); | |||||
| InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); | |||||
| ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" })); | ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" })); | ||||
| @@ -39,6 +39,17 @@ if(version == 1) | |||||
| Console.Write(text); | Console.Write(text); | ||||
| } | } | ||||
| prompt = Console.ReadLine(); | prompt = Console.ReadLine(); | ||||
| if(prompt == "save") | |||||
| { | |||||
| session.SaveSession("./SessionState"); | |||||
| Console.WriteLine("Saved session!"); | |||||
| ex.Model.Dispose(); | |||||
| ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); | |||||
| session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" })); | |||||
| session.LoadSession("./SessionState"); | |||||
| Console.WriteLine("Loaded session!"); | |||||
| prompt = Console.ReadLine(); | |||||
| } | |||||
| } | } | ||||
| ex.Model.Dispose(); | ex.Model.Dispose(); | ||||
| @@ -1,5 +1,6 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading; | using System.Threading; | ||||
| @@ -10,6 +11,8 @@ namespace LLama | |||||
| { | { | ||||
| private ILLamaExecutor _executor; | private ILLamaExecutor _executor; | ||||
| private ChatHistory _history; | private ChatHistory _history; | ||||
| private static readonly string _executorStateFilename = "ExecutorState.json"; | |||||
| private static readonly string _modelStateFilename = "ModelState.st"; | |||||
| public ILLamaExecutor Executor => _executor; | public ILLamaExecutor Executor => _executor; | ||||
| public ChatHistory History => _history; | public ChatHistory History => _history; | ||||
| public SessionParams Params { get; set; } | public SessionParams Params { get; set; } | ||||
| @@ -42,6 +45,56 @@ namespace LLama | |||||
| return this; | return this; | ||||
| } | } | ||||
/// <summary>
/// Persists the current session to a directory: the model state is always written,
/// and a stateful executor additionally writes its own inference state.
/// </summary>
/// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
/// <exception cref="System.NotImplementedException">
/// Thrown when the executor is a custom type this base implementation does not know how to persist.
/// </exception>
public virtual void SaveSession(string path)
{
    if (!Directory.Exists(path))
    {
        Directory.CreateDirectory(path);
    }

    // The model state is saved for every executor type.
    _executor.Model.SaveState(Path.Combine(path, _modelStateFilename));

    if (Executor is StatelessExecutor)
    {
        // A stateless executor carries no inference state of its own,
        // so the model state written above is the whole session.
    }
    else if (Executor is StatefulExecutorBase statefulExecutor)
    {
        // Stateful executors also persist their internal state (e.g. consumed tokens).
        statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
    }
    else
    {
        // Custom executors must override this method to persist their own state.
        throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
    }
}
/// <summary>
/// Restores a session previously written by <see cref="SaveSession(string)"/>:
/// the model state is always loaded, and a stateful executor additionally
/// reloads its own inference state.
/// </summary>
/// <param name="path">The directory name to load the session.</param>
/// <exception cref="FileNotFoundException">
/// Thrown when <paramref name="path"/> does not exist.
/// NOTE(review): DirectoryNotFoundException would describe this failure more precisely,
/// but the type is kept so existing callers that catch FileNotFoundException still work.
/// </exception>
/// <exception cref="System.NotImplementedException">
/// Thrown when the executor is a custom type this base implementation does not know how to restore.
/// </exception>
public virtual void LoadSession(string path)
{
    if (!Directory.Exists(path))
    {
        throw new FileNotFoundException($"Directory {path} does not exist.");
    }

    // The model state is restored for every executor type.
    _executor.Model.LoadState(Path.Combine(path, _modelStateFilename));

    if (Executor is StatelessExecutor)
    {
        // A stateless executor carries no inference state of its own,
        // so restoring the model state above completes the load.
    }
    else if (Executor is StatefulExecutorBase statefulExecutor)
    {
        // Stateful executors also restore their internal state (e.g. consumed tokens).
        statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
    }
    else
    {
        // Custom executors must override this method to restore their own state.
        throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
    }
}
| /// <summary> | /// <summary> | ||||
| /// Get the response from the LLama model with chat histories. | /// Get the response from the LLama model with chat histories. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -12,7 +12,7 @@ using System.Threading; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| public abstract class ChatExecutorBase : ILLamaExecutor | |||||
| public abstract class StatefulExecutorBase : ILLamaExecutor | |||||
| { | { | ||||
| protected readonly LLamaModel _model; | protected readonly LLamaModel _model; | ||||
| protected ILLamaLogger? _logger; | protected ILLamaLogger? _logger; | ||||
| @@ -26,7 +26,7 @@ namespace LLama | |||||
| protected List<llama_token> _session_tokens = new(); | protected List<llama_token> _session_tokens = new(); | ||||
| protected FixedSizeQuene<llama_token> _last_n_tokens; | protected FixedSizeQuene<llama_token> _last_n_tokens; | ||||
| public LLamaModel Model => _model; | public LLamaModel Model => _model; | ||||
| protected ChatExecutorBase(LLamaModel model, ILLamaLogger? logger = null) | |||||
| protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null) | |||||
| { | { | ||||
| _model = model; | _model = model; | ||||
| _logger = logger; | _logger = logger; | ||||
| @@ -38,7 +38,7 @@ namespace LLama | |||||
| _last_n_tokens = new FixedSizeQuene<llama_token>(_model.ContextSize).FillWith(0); | _last_n_tokens = new FixedSizeQuene<llama_token>(_model.ContextSize).FillWith(0); | ||||
| } | } | ||||
| public unsafe ChatExecutorBase WithSessionFile(string filename) | |||||
| public unsafe StatefulExecutorBase WithSessionFile(string filename) | |||||
| { | { | ||||
| _pathSession = filename; | _pathSession = filename; | ||||
| if (string.IsNullOrEmpty(filename)) | if (string.IsNullOrEmpty(filename)) | ||||
| @@ -11,7 +11,7 @@ using System.Text.Json.Serialization; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| public class InstructExecutor : ChatExecutorBase | |||||
| public class InstructExecutor : StatefulExecutorBase | |||||
| { | { | ||||
| bool _is_prompt_run = true; | bool _is_prompt_run = true; | ||||
| llama_token[] _inp_pfx; | llama_token[] _inp_pfx; | ||||
| @@ -14,7 +14,7 @@ using System.Threading.Tasks; | |||||
| namespace LLama | namespace LLama | ||||
| { | { | ||||
| using llama_token = Int32; | using llama_token = Int32; | ||||
| public class InteractiveExecutor : ChatExecutorBase | |||||
| public class InteractiveExecutor : StatefulExecutorBase | |||||
| { | { | ||||
| bool _is_prompt_run = true; | bool _is_prompt_run = true; | ||||
| llama_token[] _llama_token_newline; | llama_token[] _llama_token_newline; | ||||