Browse Source

feat: support saving and loading chat sessions.

tags/v0.4.0
Yaohui Liu 2 years ago
parent
commit
a3b8186f20
No known key found for this signature in database. GPG Key ID: E86D01E1809BD23E
5 changed files with 70 additions and 6 deletions
  1. +12
    -1
      LLama.Examples/Program.cs
  2. +53
    -0
      LLama/ChatSession.cs
  3. +3
    -3
      LLama/LLamaExecutorBase.cs
  4. +1
    -1
      LLama/LLamaInstructExecutor.cs
  5. +1
    -1
      LLama/LLamaInteractExecutor.cs

+ 12
- 1
LLama.Examples/Program.cs View File

@@ -28,7 +28,7 @@ if(version == 1)
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
//string prompt = " Qeustion: how to do binary search for an array in C#? Answer: ";

InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337)));
InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));

ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }));

@@ -39,6 +39,17 @@ if(version == 1)
Console.Write(text);
}
prompt = Console.ReadLine();
if(prompt == "save")
{
session.SaveSession("./SessionState");
Console.WriteLine("Saved session!");
ex.Model.Dispose();
ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)));
session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }));
session.LoadSession("./SessionState");
Console.WriteLine("Loaded session!");
prompt = Console.ReadLine();
}
}

ex.Model.Dispose();


+ 53
- 0
LLama/ChatSession.cs View File

@@ -1,5 +1,6 @@
using LLama.Common;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
@@ -10,6 +11,8 @@ namespace LLama
{
private ILLamaExecutor _executor;
private ChatHistory _history;
private static readonly string _executorStateFilename = "ExecutorState.json";
private static readonly string _modelStateFilename = "ModelState.st";
public ILLamaExecutor Executor => _executor;
public ChatHistory History => _history;
public SessionParams Params { get; set; }
@@ -42,6 +45,56 @@ namespace LLama
return this;
}

/// <summary>
/// Saves the session to the specified directory: the model state is always written,
/// and a stateful executor additionally persists its own state. A
/// <see cref="StatelessExecutor"/> carries no extra state, so nothing more is saved for it.
/// </summary>
/// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
/// <exception cref="System.NotImplementedException">Thrown when the executor is a custom type this base implementation does not know how to persist.</exception>
public virtual void SaveSession(string path)
{
if (!Directory.Exists(path))
{
Directory.CreateDirectory(path);
}
// The model state is saved regardless of executor kind.
_executor.Model.SaveState(Path.Combine(path, _modelStateFilename));
if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
}
else if (Executor is not StatelessExecutor)
{
// Unknown executor type: we cannot know what state it holds.
throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
}
}

/// <summary>
/// Loads a session previously written by <see cref="SaveSession(string)"/> from the
/// specified directory: the model state is always restored, and a stateful executor
/// additionally restores its own state. A <see cref="StatelessExecutor"/> has no extra
/// state, so nothing more is loaded for it.
/// </summary>
/// <param name="path">The directory name to load the session.</param>
/// <exception cref="DirectoryNotFoundException">Thrown when <paramref name="path"/> does not exist.</exception>
/// <exception cref="System.NotImplementedException">Thrown when the executor is a custom type this base implementation does not know how to restore.</exception>
public virtual void LoadSession(string path)
{
if (!Directory.Exists(path))
{
// A missing directory is a DirectoryNotFoundException, not FileNotFoundException;
// both derive from IOException so existing IOException handlers still work.
throw new DirectoryNotFoundException($"Directory {path} does not exist.");
}
_executor.Model.LoadState(Path.Combine(path, _modelStateFilename));
if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
}
else if (Executor is not StatelessExecutor)
{
// Unknown executor type: we cannot know what state it expects.
throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
}
}

/// <summary>
/// Get the response from the LLama model with chat histories.
/// </summary>


+ 3
- 3
LLama/LLamaExecutorBase.cs View File

@@ -12,7 +12,7 @@ using System.Threading;
namespace LLama
{
using llama_token = Int32;
public abstract class ChatExecutorBase : ILLamaExecutor
public abstract class StatefulExecutorBase : ILLamaExecutor
{
protected readonly LLamaModel _model;
protected ILLamaLogger? _logger;
@@ -26,7 +26,7 @@ namespace LLama
protected List<llama_token> _session_tokens = new();
protected FixedSizeQuene<llama_token> _last_n_tokens;
public LLamaModel Model => _model;
protected ChatExecutorBase(LLamaModel model, ILLamaLogger? logger = null)
protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null)
{
_model = model;
_logger = logger;
@@ -38,7 +38,7 @@ namespace LLama
_last_n_tokens = new FixedSizeQuene<llama_token>(_model.ContextSize).FillWith(0);
}

public unsafe ChatExecutorBase WithSessionFile(string filename)
public unsafe StatefulExecutorBase WithSessionFile(string filename)
{
_pathSession = filename;
if (string.IsNullOrEmpty(filename))


+ 1
- 1
LLama/LLamaInstructExecutor.cs View File

@@ -11,7 +11,7 @@ using System.Text.Json.Serialization;
namespace LLama
{
using llama_token = Int32;
public class InstructExecutor : ChatExecutorBase
public class InstructExecutor : StatefulExecutorBase
{
bool _is_prompt_run = true;
llama_token[] _inp_pfx;


+ 1
- 1
LLama/LLamaInteractExecutor.cs View File

@@ -14,7 +14,7 @@ using System.Threading.Tasks;
namespace LLama
{
using llama_token = Int32;
public class InteractiveExecutor : ChatExecutorBase
public class InteractiveExecutor : StatefulExecutorBase
{
bool _is_prompt_run = true;
llama_token[] _llama_token_newline;


Loading…
Cancel
Save