You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ChatSession.cs 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Runtime.CompilerServices;
  6. using System.Text.Json;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. using LLama.Abstractions;
  10. using LLama.Common;
  11. using static LLama.InteractiveExecutor;
  12. using static LLama.LLamaContext;
  13. using static LLama.StatefulExecutorBase;
  14. namespace LLama;
  15. /// <summary>
  16. /// The main chat session class.
  17. /// </summary>
  18. public class ChatSession
  19. {
  /// <summary>
  /// File name used for the serialized model state (KV cache, etc.) inside a session directory.
  /// </summary>
  public const string MODEL_STATE_FILENAME = "ModelState.st";

  /// <summary>
  /// File name used for the serialized executor state inside a session directory.
  /// </summary>
  public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";

  /// <summary>
  /// File name used for the serialized chat history inside a session directory.
  /// </summary>
  public const string HISTORY_STATE_FILENAME = "ChatHistory.json";

  /// <summary>
  /// File name used for the serialized input transform pipeline inside a session directory.
  /// </summary>
  public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";

  /// <summary>
  /// File name used for the serialized output transform inside a session directory.
  /// </summary>
  public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";

  /// <summary>
  /// File name used for the serialized history transform inside a session directory.
  /// </summary>
  public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";

  /// <summary>
  /// Executor that runs inference for this session. Assigned by the constructor.
  /// </summary>
  public ILLamaExecutor Executor { get; private set; }

  /// <summary>
  /// Messages exchanged so far in this session. Starts out empty.
  /// </summary>
  public ChatHistory History { get; private set; } = new ChatHistory();

  /// <summary>
  /// Converts between <see cref="ChatHistory"/> and prompt text for this session.
  /// </summary>
  public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

  /// <summary>
  /// Text transforms applied, in list order, to incoming message content.
  /// </summary>
  public List<ITextTransform> InputTransformPipeline { get; set; } = new List<ITextTransform>();

  /// <summary>
  /// Transform applied to the stream of text produced by the executor.
  /// </summary>
  public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
  /// <summary>
  /// Build a session around an existing history and prefill the executor with
  /// the prompt text produced from that history by the session's history transform.
  /// </summary>
  /// <param name="executor">Executor for the session; must derive from <see cref="StatefulExecutorBase"/>.</param>
  /// <param name="history">History the new session starts with.</param>
  /// <returns>The newly created session, after the prefill has completed.</returns>
  /// <exception cref="ArgumentException">The executor is not a <see cref="StatefulExecutorBase"/>.</exception>
  public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
      ILLamaExecutor executor, ChatHistory history)
  {
      var stateful = executor as StatefulExecutorBase
          ?? throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));

      var created = new ChatSession(executor, history);

      // Render the full history to prompt text, then let the executor process it up front.
      string prefillText = created.HistoryTransform.HistoryToText(history);
      await stateful.PrefillPromptAsync(prefillText);

      return created;
  }
  /// <summary>
  /// Create a chat session with an empty history.
  /// </summary>
  /// <param name="executor">Executor for the session; must derive from <see cref="StatefulExecutorBase"/>.</param>
  /// <exception cref="ArgumentException">The executor is not a <see cref="StatefulExecutorBase"/>.</exception>
  public ChatSession(ILLamaExecutor executor)
  {
      // Other members of this class cast Executor to StatefulExecutorBase, so reject anything else here.
      bool isStateful = executor is StatefulExecutorBase;
      if (!isStateful)
      {
          throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
      }

      Executor = executor;
  }
  /// <summary>
  /// Create a chat session that starts from the supplied history.
  /// The history object is stored as-is (no copy is made).
  /// </summary>
  /// <param name="executor">Executor for the session; validated by the single-argument constructor.</param>
  /// <param name="history">History the session starts with.</param>
  public ChatSession(ILLamaExecutor executor, ChatHistory history)
      : this(executor)
      => History = history;
  /// <summary>
  /// Replace the session's history transform.
  /// </summary>
  /// <param name="transform">Transform to use from now on.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession WithHistoryTransform(IHistoryTransform transform)
  {
      this.HistoryTransform = transform;

      return this;
  }
  /// <summary>
  /// Append a text transform to the end of the input transform pipeline.
  /// </summary>
  /// <param name="transform">Transform to append.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddInputTransform(ITextTransform transform)
  {
      List<ITextTransform> pipeline = InputTransformPipeline;
      pipeline.Add(transform);

      return this;
  }
  /// <summary>
  /// Replace the session's output transform.
  /// </summary>
  /// <param name="transform">Transform to use from now on.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession WithOutputTransform(ITextStreamTransform transform)
  {
      this.OutputTransform = transform;

      return this;
  }
  /// <summary>
  /// Save the session to a directory by capturing its current state with
  /// <see cref="GetSessionState"/> and calling <see cref="SessionState.Save"/> on it.
  /// </summary>
  /// <param name="path">Directory to write the session files into.</param>
  /// <exception cref="ArgumentException">The path is null or whitespace.</exception>
  public void SaveSession(string path)
  {
      SessionState snapshot = GetSessionState();
      snapshot.Save(path);
  }
  /// <summary>
  /// Capture the session's current state in memory: context state, executor state,
  /// history and the three transforms.
  /// </summary>
  /// <returns>A new <see cref="SessionState"/> built from the current session.</returns>
  public SessionState GetSessionState()
  {
      var contextState = Executor.Context.GetState();

      // The constructor guarantees the executor derives from StatefulExecutorBase.
      var executorState = ((StatefulExecutorBase)Executor).GetStateData();

      return new SessionState(
          contextState,
          executorState,
          History,
          InputTransformPipeline,
          OutputTransform,
          HistoryTransform);
  }
  /// <summary>
  /// Restore the session from an in-memory session state. The executor state is loaded
  /// first, then the context state; history and transforms are replaced with copies.
  /// </summary>
  /// <param name="state">State to restore from.</param>
  /// <exception cref="ArgumentException">The session's executor is not a <see cref="StatefulExecutorBase"/>.</exception>
  public void LoadSession(SessionState state)
  {
      if (Executor is not StatefulExecutorBase stateful)
      {
          throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state));
      }

      stateful.LoadState(state.ExecutorState);
      Executor.Context.LoadState(state.ContextState);

      History = new ChatHistory(state.History);

      // Clone every transform so the session does not share instances with the state object.
      var clonedInputs = new List<ITextTransform>();
      foreach (ITextTransform inputTransform in state.InputTransformPipeline)
      {
          clonedInputs.Add(inputTransform.Clone());
      }
      InputTransformPipeline = clonedInputs;

      OutputTransform = state.OutputTransform.Clone();
      HistoryTransform = state.HistoryTransform.Clone();
  }
  /// <summary>
  /// Restore the session from a directory previously written by <see cref="SaveSession"/>.
  /// </summary>
  /// <param name="path">Directory containing the session files.</param>
  /// <exception cref="ArgumentException">
  /// The path is null/whitespace, the directory does not exist, or a session file is invalid.
  /// </exception>
  public void LoadSession(string path)
  {
      SessionState loaded = SessionState.Load(path);

      // Handle non-polymorphic serialization of executor state: let the executor
      // read its own state file directly.
      // NOTE(review): the property is already typed ExecutorBaseState, so this check is
      // effectively "not null" and LoadSession(loaded) below loads executor state again.
      // Confirm whether an exact-type or null check was intended.
      if (loaded.ExecutorState is not null)
      {
          string executorStateFile = Path.Combine(path, EXECUTOR_STATE_FILENAME);
          var stateful = (StatefulExecutorBase)Executor;
          stateful.LoadState(filename: executorStateFile);
      }

      LoadSession(loaded);
  }
  /// <summary>
  /// Append a message to the chat history after checking that it keeps the
  /// conversation order valid:
  /// a system message is only allowed when the history is empty,
  /// a user message may not directly follow another user message,
  /// and an assistant message must directly follow a user message.
  /// </summary>
  /// <param name="message">Message whose role and content are appended.</param>
  /// <returns>This session, to allow call chaining.</returns>
  /// <exception cref="ArgumentException">The message would violate the ordering rules above.</exception>
  public ChatSession AddMessage(ChatHistory.Message message)
  {
      AuthorRole role = message.AuthorRole;

      if (role == AuthorRole.System)
      {
          if (History.Messages.Count > 0)
          {
              throw new ArgumentException("Cannot add a system message after another message", nameof(message));
          }
      }
      else if (role == AuthorRole.User)
      {
          // An empty history, or a preceding system/assistant message, is fine.
          ChatHistory.Message? beforeUser = History.Messages.LastOrDefault();
          bool followsUser = beforeUser is not null && beforeUser.AuthorRole == AuthorRole.User;
          if (followsUser)
          {
              throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
          }
      }
      else if (role == AuthorRole.Assistant)
      {
          ChatHistory.Message? beforeAssistant = History.Messages.LastOrDefault();
          bool followsUser = beforeAssistant is not null && beforeAssistant.AuthorRole == AuthorRole.User;
          if (!followsUser)
          {
              throw new ArgumentException("Assistant message must be preceded with a user message", nameof(message));
          }
      }

      History.AddMessage(role, message.Content);
      return this;
  }
  /// <summary>
  /// Run the system message through the input transform pipeline, prefill the executor
  /// with the result, and record the transformed text as a system message in the history.
  /// </summary>
  /// <param name="content">System message text.</param>
  /// <returns>This session, once the prefill has completed.</returns>
  /// <exception cref="InvalidOperationException">The executor is not a <see cref="StatefulExecutorBase"/>.</exception>
  /// <exception cref="ArgumentException">The history already contains messages.</exception>
  public async Task<ChatSession> ProcessSystemMessage(string content)
  {
      var stateful = Executor as StatefulExecutorBase;
      if (stateful is null)
      {
          throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
      }

      bool historyIsEmpty = History.Messages.Count == 0;
      if (!historyIsEmpty)
      {
          throw new ArgumentException("Cannot add a system message after another message", nameof(content));
      }

      // Feed the text through each transform in pipeline order.
      content = InputTransformPipeline.Aggregate(
          content,
          (text, transform) => transform.Transform(text));

      await stateful.PrefillPromptAsync(content);
      History.AddMessage(AuthorRole.System, content);

      return this;
  }
  /// <summary>
  /// Append a system message to the chat history via <see cref="AddMessage"/>.
  /// </summary>
  /// <param name="content">Message text.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddSystemMessage(string content)
  {
      var systemMessage = new ChatHistory.Message(AuthorRole.System, content);
      return AddMessage(systemMessage);
  }
  /// <summary>
  /// Append an assistant message to the chat history via <see cref="AddMessage"/>.
  /// </summary>
  /// <param name="content">Message text.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddAssistantMessage(string content)
  {
      var assistantMessage = new ChatHistory.Message(AuthorRole.Assistant, content);
      return AddMessage(assistantMessage);
  }
  /// <summary>
  /// Append a user message to the chat history via <see cref="AddMessage"/>.
  /// </summary>
  /// <param name="content">Message text.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddUserMessage(string content)
  {
      var userMessage = new ChatHistory.Message(AuthorRole.User, content);
      return AddMessage(userMessage);
  }
  /// <summary>
  /// Remove the most recent message from the chat history.
  /// </summary>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession RemoveLastMessage()
  {
      var messages = History.Messages;
      int lastIndex = messages.Count - 1;
      messages.RemoveAt(lastIndex);

      return this;
  }
  /// <summary>
  /// Swap a user message in the history for a new one and drop every message that came
  /// after it. Intended for the "edit a message and regenerate the response" flow.
  /// </summary>
  /// <param name="oldMessage">User message currently in the history.</param>
  /// <param name="newMessage">User message that takes its place.</param>
  /// <returns>This session, to allow call chaining.</returns>
  /// <exception cref="ArgumentException">
  /// Either message is not a user message, or <paramref name="oldMessage"/> is not in the history.
  /// </exception>
  public ChatSession ReplaceUserMessage(
      ChatHistory.Message oldMessage,
      ChatHistory.Message newMessage)
  {
      bool oldIsUser = oldMessage.AuthorRole == AuthorRole.User;
      if (!oldIsUser)
      {
          throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
      }

      bool newIsUser = newMessage.AuthorRole == AuthorRole.User;
      if (!newIsUser)
      {
          throw new ArgumentException("New message must be a user message", nameof(newMessage));
      }

      var messages = History.Messages;
      int position = messages.IndexOf(oldMessage);
      if (position < 0)
      {
          throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
      }

      messages[position] = newMessage;

      // Everything after the edited message is stale, so cut the tail off.
      int firstStale = position + 1;
      messages.RemoveRange(firstStale, messages.Count - firstStale);

      return this;
  }
  /// <summary>
  /// Send a user message to the model and stream back the reply.
  /// The user message is added to the history before inference starts; the complete
  /// assistant reply is added once the stream has been fully consumed.
  /// </summary>
  /// <param name="message">
  /// The user message. When <paramref name="applyInputTransformPipeline"/> is true its
  /// <c>Content</c> is overwritten in place with the transformed text.
  /// </param>
  /// <param name="applyInputTransformPipeline">Whether to run the input transform pipeline over the message content.</param>
  /// <param name="inferenceParams">Inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the generation.</param>
  /// <returns>The pieces of text produced for the assistant reply, in order.</returns>
  /// <exception cref="ArgumentException">The message is not a user message, or it may not follow the current history.</exception>
  public async IAsyncEnumerable<string> ChatAsync(
      ChatHistory.Message message,
      bool applyInputTransformPipeline,
      IInferenceParams? inferenceParams = null,
      [EnumeratorCancellation] CancellationToken cancellationToken = default)
  {
      // The message must be a user message
      if (message.AuthorRole != AuthorRole.User)
      {
          throw new ArgumentException("Message must be a user message", nameof(message));
      }

      // Apply input transform pipeline
      if (applyInputTransformPipeline)
      {
          foreach (var inputTransform in InputTransformPipeline)
          {
              message.Content = inputTransform.Transform(message.Content);
          }
      }

      // Add the user's message to the history
      AddUserMessage(message.Content);

      // NOTE(review): this cast requires the executor state to be an InteractiveExecutorState;
      // other StatefulExecutorBase implementations would fail here - confirm that is intended.
      InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

      string prompt;
      if (state.IsPromptRun)
      {
          // The executor is still on its prompt run (the session was newly started):
          // render the complete history, including the system message and any manually
          // added messages, through the history transform's prompt template.
          prompt = HistoryTransform.HistoryToText(History);
      }
      else
      {
          // Otherwise (for example, a restored session) only the current message is
          // rendered through the history transform's prompt template.
          ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
          prompt = HistoryTransform.HistoryToText(singleMessageHistory);
      }

      // Accumulate the reply in a StringBuilder: concatenating strings once per token
      // copies the whole reply each time and is quadratic in its length.
      var assistantMessage = new System.Text.StringBuilder();
      await foreach (
          string textToken
          in ChatAsyncInternal(
              prompt,
              inferenceParams,
              cancellationToken))
      {
          assistantMessage.Append(textToken);
          yield return textToken;
      }

      // Add the assistant message to the history
      AddAssistantMessage(assistantMessage.ToString());
  }
  /// <summary>
  /// Send a user message to the model and stream back the reply, with the input
  /// transform pipeline applied to the message.
  /// </summary>
  /// <param name="message">The user message.</param>
  /// <param name="inferenceParams">Inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the generation.</param>
  /// <returns>The pieces of text produced for the assistant reply, in order.</returns>
  public IAsyncEnumerable<string> ChatAsync(
      ChatHistory.Message message,
      IInferenceParams? inferenceParams = null,
      CancellationToken cancellationToken = default)
  {
      const bool applyTransforms = true;

      return ChatAsync(message, applyTransforms, inferenceParams, cancellationToken);
  }
  /// <summary>
  /// Add every message of <paramref name="history"/> except the last to the session
  /// history, then chat with the model using the last message.
  /// The earlier messages are added immediately when this method is called, not when
  /// the returned stream is enumerated.
  /// </summary>
  /// <param name="history">Messages to add; the last one is sent to the model.</param>
  /// <param name="applyInputTransformPipeline">
  /// Whether to run the input transform pipeline over user messages. Transformed content
  /// is written back into the message objects of <paramref name="history"/>.
  /// </param>
  /// <param name="inferenceParams">Inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the generation.</param>
  /// <returns>The pieces of text produced for the assistant reply, in order.</returns>
  /// <exception cref="ArgumentException">The history is empty, or a message violates the ordering rules of <see cref="AddMessage"/>.</exception>
  public IAsyncEnumerable<string> ChatAsync(
      ChatHistory history,
      bool applyInputTransformPipeline,
      IInferenceParams? inferenceParams = null,
      CancellationToken cancellationToken = default)
  {
      ChatHistory.Message? finalMessage = history.Messages.LastOrDefault();
      if (finalMessage is null)
      {
          throw new ArgumentException("History must contain at least one message", nameof(history));
      }

      var earlierMessages = history.Messages.Take(history.Messages.Count - 1);
      foreach (ChatHistory.Message earlier in earlierMessages)
      {
          // Only user messages go through the input transform pipeline.
          bool shouldTransform = applyInputTransformPipeline && earlier.AuthorRole == AuthorRole.User;
          if (shouldTransform)
          {
              foreach (var step in InputTransformPipeline)
              {
                  earlier.Content = step.Transform(earlier.Content);
              }
          }

          AddMessage(earlier);
      }

      return ChatAsync(finalMessage, applyInputTransformPipeline, inferenceParams, cancellationToken);
  }
  /// <summary>
  /// Add every message of <paramref name="history"/> except the last to the session
  /// history, then chat with the model using the last message, with the input
  /// transform pipeline applied.
  /// </summary>
  /// <param name="history">Messages to add; the last one is sent to the model.</param>
  /// <param name="inferenceParams">Inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the generation.</param>
  /// <returns>The pieces of text produced for the assistant reply, in order.</returns>
  public IAsyncEnumerable<string> ChatAsync(
      ChatHistory history,
      IInferenceParams? inferenceParams = null,
      CancellationToken cancellationToken = default)
  {
      const bool applyTransforms = true;

      return ChatAsync(history, applyTransforms, inferenceParams, cancellationToken);
  }
  /// <summary>
  /// Regenerate the last assistant message: the trailing assistant message and the user
  /// message before it are removed from the history, and the user message is sent to the
  /// model again without re-applying the input transform pipeline.
  /// </summary>
  /// <param name="inferenceParams">Inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the generation.</param>
  /// <returns>The pieces of text produced for the new assistant reply, in order.</returns>
  /// <exception cref="InvalidOperationException">
  /// The history does not end with a user message followed by an assistant message.
  /// The history is left unmodified in that case.
  /// </exception>
  public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
      InferenceParams? inferenceParams = null,
      [EnumeratorCancellation] CancellationToken cancellationToken = default)
  {
      var messages = History.Messages;
      int count = messages.Count;

      // Validate BOTH trailing messages before removing anything, so that a failed
      // check never leaves the history with the assistant message already dropped.
      ChatHistory.Message? lastAssistantMessage = count >= 1 ? messages[count - 1] : null;
      if (lastAssistantMessage is null
          || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
      {
          throw new InvalidOperationException("Last message must be an assistant message");
      }

      ChatHistory.Message? lastUserMessage = count >= 2 ? messages[count - 2] : null;
      if (lastUserMessage is null
          || lastUserMessage.AuthorRole != AuthorRole.User)
      {
          throw new InvalidOperationException("Last message must be a user message");
      }

      // Remove the assistant message, then the user message; ChatAsync re-adds the
      // user message before generating the new reply.
      RemoveLastMessage();
      RemoveLastMessage();

      // The stored user message was already transformed when it was first sent,
      // so the input transform pipeline is skipped here.
      await foreach (
          string textToken
          in ChatAsync(
              lastUserMessage,
              applyInputTransformPipeline: false,
              inferenceParams,
              cancellationToken))
      {
          yield return textToken;
      }
  }
  /// <summary>
  /// Run inference on the given prompt and stream the executor's output through the
  /// session's output transform.
  /// </summary>
  /// <param name="prompt">Prompt text passed to the executor.</param>
  /// <param name="inferenceParams">Inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the generation.</param>
  /// <returns>The transformed pieces of output text, in order.</returns>
  private async IAsyncEnumerable<string> ChatAsyncInternal(
      string prompt,
      IInferenceParams? inferenceParams = null,
      [EnumeratorCancellation] CancellationToken cancellationToken = default)
  {
      var rawOutput = Executor.InferAsync(prompt, inferenceParams, cancellationToken);
      var transformedOutput = OutputTransform.TransformAsync(rawOutput);

      await foreach (string piece in transformedOutput.WithCancellation(cancellationToken))
      {
          yield return piece;
      }
  }
  512. }
  513. /// <summary>
  514. /// The state of a chat session in-memory.
  515. /// </summary>
  516. public record SessionState
  517. {
  /// <summary>
  /// Executor state captured for the session.
  /// </summary>
  public ExecutorBaseState ExecutorState { get; set; }

  /// <summary>
  /// Context state (KV cache) captured for the session.
  /// </summary>
  public State ContextState { get; set; }

  /// <summary>
  /// Input transforms captured for the session, in pipeline order. Empty by default.
  /// </summary>
  public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();

  /// <summary>
  /// Output transform captured for the session.
  /// </summary>
  public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();

  /// <summary>
  /// History transform captured for the session.
  /// </summary>
  public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

  /// <summary>
  /// The chat history messages captured for the session. Empty by default.
  /// </summary>
  public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();
  /// <summary>
  /// Create a session state. The history messages are copied into a new array and
  /// every transform is cloned, so later changes to the supplied list or transform
  /// instances are not reflected here.
  /// </summary>
  /// <param name="contextState">Context state (KV cache) to store.</param>
  /// <param name="executorState">Executor state to store.</param>
  /// <param name="history">History whose messages are copied.</param>
  /// <param name="inputTransformPipeline">Input transforms to clone, in order.</param>
  /// <param name="outputTransform">Output transform to clone.</param>
  /// <param name="historyTransform">History transform to clone.</param>
  public SessionState(
      State contextState, ExecutorBaseState executorState,
      ChatHistory history, List<ITextTransform> inputTransformPipeline,
      ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
  {
      ContextState = contextState;
      ExecutorState = executorState;
      History = history.Messages.ToArray();

      var clonedInputs = new List<ITextTransform>();
      foreach (ITextTransform inputTransform in inputTransformPipeline)
      {
          clonedInputs.Add(inputTransform.Clone());
      }
      InputTransformPipeline = clonedInputs.ToArray();

      OutputTransform = outputTransform.Clone();
      HistoryTransform = historyTransform.Clone();
  }
  /// <summary>
  /// Write the session state to a directory, one file per component.
  /// WARNING: if the directory already exists it is deleted recursively first,
  /// including any unrelated files it contains.
  /// </summary>
  /// <param name="path">Directory to write into; created if it does not exist.</param>
  /// <exception cref="ArgumentException">The path is null or whitespace.</exception>
  public void Save(string path)
  {
      if (string.IsNullOrWhiteSpace(path))
      {
          throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
      }

      // Start from a clean directory so no files from an earlier save survive.
      if (Directory.Exists(path))
      {
          Directory.Delete(path, recursive: true);
      }
      Directory.CreateDirectory(path);

      string InSession(string fileName) => Path.Combine(path, fileName);

      // Model state (KV cache) is stored as raw bytes.
      File.WriteAllBytes(InSession(ChatSession.MODEL_STATE_FILENAME), ContextState.ToByteArray());

      // Everything else is stored as JSON text.
      File.WriteAllText(InSession(ChatSession.EXECUTOR_STATE_FILENAME), JsonSerializer.Serialize(ExecutorState));
      File.WriteAllText(InSession(ChatSession.HISTORY_STATE_FILENAME), new ChatHistory(History).ToJson());
      File.WriteAllText(InSession(ChatSession.INPUT_TRANSFORM_FILENAME), JsonSerializer.Serialize(InputTransformPipeline));
      File.WriteAllText(InSession(ChatSession.OUTPUT_TRANSFORM_FILENAME), JsonSerializer.Serialize(OutputTransform));
      File.WriteAllText(InSession(ChatSession.HISTORY_TRANSFORM_FILENAME), JsonSerializer.Serialize(HistoryTransform));
  }
  /// <summary>
  /// Read a session state from a directory written by <see cref="Save"/>.
  /// The model state, executor state and history files are required; the three
  /// transform files are optional and fall back to defaults when absent.
  /// </summary>
  /// <param name="path">Directory containing the session files.</param>
  /// <returns>The session state read from the directory.</returns>
  /// <exception cref="ArgumentException">
  /// The path is null/whitespace, the directory does not exist, or a session file
  /// is invalid. For a malformed transform file the original
  /// <see cref="JsonException"/> is available as the inner exception.
  /// </exception>
  public static SessionState Load(string path)
  {
      if (string.IsNullOrWhiteSpace(path))
      {
          throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
      }
      if (!Directory.Exists(path))
      {
          throw new ArgumentException("Directory does not exist", nameof(path));
      }

      string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
      var contextState = State.FromByteArray(File.ReadAllBytes(modelStateFilePath));

      // NOTE: unlike the transform files below, a malformed executor state file
      // surfaces as a JsonException rather than an ArgumentException.
      string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
      var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath))
          ?? throw new ArgumentException("Executor state file is invalid", nameof(path));

      string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
      string historyJson = File.ReadAllText(historyFilepath);
      var history = ChatHistory.FromJson(historyJson)
          ?? throw new ArgumentException("History file is invalid", nameof(path));

      ITextTransform[] inputTransforms = LoadOptionalJson(
          Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME),
          () => Array.Empty<ITextTransform>(),
          "Input transform file is invalid",
          nameof(path));

      ITextStreamTransform outputTransform = LoadOptionalJson<ITextStreamTransform>(
          Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME),
          () => new LLamaTransforms.EmptyTextOutputStreamTransform(),
          "Output transform file is invalid",
          nameof(path));

      IHistoryTransform historyTransform = LoadOptionalJson<IHistoryTransform>(
          Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME),
          () => new LLamaTransforms.DefaultHistoryTransform(),
          "History transform file is invalid",
          nameof(path));

      return new SessionState(
          contextState,
          executorState,
          history,
          inputTransforms.ToList(),
          outputTransform,
          historyTransform);
  }

  /// <summary>
  /// Deserialize an optional JSON session file, or produce a fallback value when the
  /// file does not exist.
  /// </summary>
  /// <typeparam name="T">Type to deserialize the file content into.</typeparam>
  /// <param name="filePath">Full path of the file to read.</param>
  /// <param name="fallback">Creates the value to use when the file is absent.</param>
  /// <param name="invalidMessage">Exception message used when the file content is invalid.</param>
  /// <param name="paramName">Parameter name reported on the thrown exception.</param>
  /// <returns>The deserialized value, or the fallback value.</returns>
  /// <exception cref="ArgumentException">The file exists but is malformed or deserializes to null.</exception>
  private static T LoadOptionalJson<T>(
      string filePath, Func<T> fallback, string invalidMessage, string paramName)
      where T : class
  {
      if (!File.Exists(filePath))
      {
          return fallback();
      }

      try
      {
          return JsonSerializer.Deserialize<T>(File.ReadAllText(filePath))
              ?? throw new ArgumentException(invalidMessage, paramName);
      }
      catch (JsonException ex)
      {
          // Keep the parser error as the inner exception so the cause is not lost.
          throw new ArgumentException(invalidMessage, paramName, ex);
      }
  }
  664. }