|
|
|
@@ -8,7 +8,6 @@ using System.Threading; |
|
|
|
using System.Threading.Tasks; |
|
|
|
using LLama.Abstractions; |
|
|
|
using LLama.Common; |
|
|
|
using static LLama.Common.ChatHistory; |
|
|
|
using static LLama.InteractiveExecutor; |
|
|
|
using static LLama.LLamaContext; |
|
|
|
using static LLama.StatefulExecutorBase; |
|
|
|
@@ -20,9 +19,30 @@ namespace LLama; |
|
|
|
/// </summary> |
|
|
|
public class ChatSession |
|
|
|
{ |
|
|
|
private const string _modelStateFilename = "ModelState.st"; |
|
|
|
private const string _executorStateFilename = "ExecutorState.json"; |
|
|
|
private const string _hsitoryFilename = "ChatHistory.json"; |
|
|
|
/// <summary> |
|
|
|
/// The filename for the serialized model state (KV cache, etc). |
|
|
|
/// </summary> |
|
|
|
public const string MODEL_STATE_FILENAME = "ModelState.st"; |
|
|
|
/// <summary> |
|
|
|
/// The filename for the serialized executor state. |
|
|
|
/// </summary> |
|
|
|
public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json"; |
|
|
|
/// <summary> |
|
|
|
/// The filename for the serialized chat history. |
|
|
|
/// </summary> |
|
|
|
public const string HISTORY_STATE_FILENAME = "ChatHistory.json"; |
|
|
|
/// <summary> |
|
|
|
/// The filename for the serialized input transform pipeline. |
|
|
|
/// </summary> |
|
|
|
public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json"; |
|
|
|
/// <summary> |
|
|
|
/// The filename for the serialized output transform. |
|
|
|
/// </summary> |
|
|
|
public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json"; |
|
|
|
/// <summary> |
|
|
|
/// The filename for the serialized history transform. |
|
|
|
/// </summary> |
|
|
|
public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json"; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// The executor for this session. |
|
|
|
@@ -134,26 +154,7 @@ public class ChatSession |
|
|
|
/// <exception cref="ArgumentException"></exception> |
|
|
|
public void SaveSession(string path) |
|
|
|
{ |
|
|
|
if (string.IsNullOrWhiteSpace(path)) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
if (Directory.Exists(path)) |
|
|
|
{ |
|
|
|
Directory.Delete(path, recursive: true); |
|
|
|
} |
|
|
|
|
|
|
|
Directory.CreateDirectory(path); |
|
|
|
|
|
|
|
string modelStateFilePath = Path.Combine(path, _modelStateFilename); |
|
|
|
Executor.Context.SaveState(modelStateFilePath); |
|
|
|
|
|
|
|
string executorStateFilepath = Path.Combine(path, _executorStateFilename); |
|
|
|
((StatefulExecutorBase)Executor).SaveState(executorStateFilepath); |
|
|
|
|
|
|
|
string historyFilepath = Path.Combine(path, _hsitoryFilename); |
|
|
|
File.WriteAllText(historyFilepath, History.ToJson()); |
|
|
|
GetSessionState().Save(path); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
@@ -202,26 +203,14 @@ public class ChatSession |
|
|
|
/// <exception cref="ArgumentException"></exception> |
|
|
|
public void LoadSession(string path) |
|
|
|
{ |
|
|
|
if (string.IsNullOrWhiteSpace(path)) |
|
|
|
var state = SessionState.Load(path); |
|
|
|
// Handle non-polymorphic serialization of executor state |
|
|
|
if (state.ExecutorState is ExecutorBaseState) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); |
|
|
|
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); |
|
|
|
((StatefulExecutorBase) Executor).LoadState(filename: executorPath); |
|
|
|
} |
|
|
|
|
|
|
|
if (!Directory.Exists(path)) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Directory does not exist", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
string modelStateFilePath = Path.Combine(path, _modelStateFilename); |
|
|
|
Executor.Context.LoadState(modelStateFilePath); |
|
|
|
|
|
|
|
string executorStateFilepath = Path.Combine(path, _executorStateFilename); |
|
|
|
((StatefulExecutorBase)Executor).LoadState(executorStateFilepath); |
|
|
|
|
|
|
|
string historyFilepath = Path.Combine(path, _hsitoryFilename); |
|
|
|
string historyJson = File.ReadAllText(historyFilepath); |
|
|
|
History = ChatHistory.FromJson(historyJson) |
|
|
|
?? throw new ArgumentException("History file is invalid", nameof(path)); |
|
|
|
LoadSession(state); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
@@ -615,7 +604,7 @@ public record SessionState |
|
|
|
/// <summary> |
|
|
|
/// The the chat history messages for this session. |
|
|
|
/// </summary> |
|
|
|
public Message[] History { get; set; } = Array.Empty<Message>(); |
|
|
|
public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>(); |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Create a new session state. |
|
|
|
@@ -638,4 +627,124 @@ public record SessionState |
|
|
|
OutputTransform = outputTransform.Clone(); |
|
|
|
HistoryTransform = historyTransform.Clone(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Save the session state to folder. |
|
|
|
/// </summary> |
|
|
|
/// <param name="path"></param> |
|
|
|
public void Save(string path) |
|
|
|
{ |
|
|
|
if (string.IsNullOrWhiteSpace(path)) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
if (Directory.Exists(path)) |
|
|
|
{ |
|
|
|
Directory.Delete(path, recursive: true); |
|
|
|
} |
|
|
|
|
|
|
|
Directory.CreateDirectory(path); |
|
|
|
|
|
|
|
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); |
|
|
|
var bytes = ContextState.ToByteArray(); |
|
|
|
File.WriteAllBytes(modelStateFilePath, bytes); |
|
|
|
|
|
|
|
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); |
|
|
|
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); |
|
|
|
|
|
|
|
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); |
|
|
|
File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson()); |
|
|
|
|
|
|
|
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); |
|
|
|
File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline)); |
|
|
|
|
|
|
|
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); |
|
|
|
File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform)); |
|
|
|
|
|
|
|
string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); |
|
|
|
File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform)); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Load the session state from folder. |
|
|
|
/// </summary> |
|
|
|
/// <param name="path"></param> |
|
|
|
/// <returns></returns> |
|
|
|
/// <exception cref="ArgumentException">Throws when session state is incorrect</exception> |
|
|
|
public static SessionState Load(string path) |
|
|
|
{ |
|
|
|
if (string.IsNullOrWhiteSpace(path)) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
if (!Directory.Exists(path)) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Directory does not exist", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); |
|
|
|
var contextState = State.FromByteArray(File.ReadAllBytes(modelStateFilePath)); |
|
|
|
|
|
|
|
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); |
|
|
|
var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath)) |
|
|
|
?? throw new ArgumentException("Executor state file is invalid", nameof(path)); |
|
|
|
|
|
|
|
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); |
|
|
|
string historyJson = File.ReadAllText(historyFilepath); |
|
|
|
var history = ChatHistory.FromJson(historyJson) |
|
|
|
?? throw new ArgumentException("History file is invalid", nameof(path)); |
|
|
|
|
|
|
|
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); |
|
|
|
ITextTransform[] inputTransforms; |
|
|
|
try |
|
|
|
{ |
|
|
|
inputTransforms = File.Exists(inputTransformFilepath) ? |
|
|
|
(JsonSerializer.Deserialize<ITextTransform[]>(File.ReadAllText(inputTransformFilepath)) |
|
|
|
?? throw new ArgumentException("Input transform file is invalid", nameof(path))) |
|
|
|
: Array.Empty<ITextTransform>(); |
|
|
|
} |
|
|
|
catch (JsonException) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Input transform file is invalid", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); |
|
|
|
|
|
|
|
ITextStreamTransform outputTransform; |
|
|
|
try |
|
|
|
{ |
|
|
|
outputTransform = File.Exists(outputTransformFilepath) ? |
|
|
|
(JsonSerializer.Deserialize<ITextStreamTransform>(File.ReadAllText(outputTransformFilepath)) |
|
|
|
?? throw new ArgumentException("Output transform file is invalid", nameof(path))) |
|
|
|
: new LLamaTransforms.EmptyTextOutputStreamTransform(); |
|
|
|
} |
|
|
|
catch (JsonException) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Output transform file is invalid", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); |
|
|
|
IHistoryTransform historyTransform; |
|
|
|
try |
|
|
|
{ |
|
|
|
historyTransform = File.Exists(historyTransformFilepath) ? |
|
|
|
(JsonSerializer.Deserialize<IHistoryTransform>(File.ReadAllText(historyTransformFilepath)) |
|
|
|
?? throw new ArgumentException("History transform file is invalid", nameof(path))) |
|
|
|
: new LLamaTransforms.DefaultHistoryTransform(); |
|
|
|
} |
|
|
|
catch (JsonException) |
|
|
|
{ |
|
|
|
throw new ArgumentException("History transform file is invalid", nameof(path)); |
|
|
|
} |
|
|
|
|
|
|
|
return new SessionState( |
|
|
|
contextState, |
|
|
|
executorState, |
|
|
|
history, |
|
|
|
inputTransforms.ToList(), |
|
|
|
outputTransform, |
|
|
|
historyTransform); |
|
|
|
} |
|
|
|
} |