You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ChatSession.cs 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  using System;
  using System.Collections.Generic;
  using System.IO;
  using System.Linq;
  using System.Runtime.CompilerServices;
  using System.Text;
  using System.Text.Json;
  using System.Threading;
  using System.Threading.Tasks;
  using LLama.Abstractions;
  using LLama.Common;
  using static LLama.InteractiveExecutor;
  using static LLama.LLamaContext;
  using static LLama.StatefulExecutorBase;
  14. namespace LLama;
  /// <summary>
  /// The main chat session class.
  /// </summary>
  /// <remarks>
  /// A session pairs an executor (which must derive from <see cref="StatefulExecutorBase"/>, enforced by the
  /// constructors) with a <see cref="ChatHistory"/>, an input transform pipeline applied to user text, a history
  /// transform that renders history into prompts, and an output transform applied to the generated token stream.
  /// </remarks>
  public class ChatSession
  {
      /// <summary>
      /// The filename for the serialized model state (KV cache, etc).
      /// </summary>
      public const string MODEL_STATE_FILENAME = "ModelState.st";

      /// <summary>
      /// The filename for the serialized executor state.
      /// </summary>
      public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";

      /// <summary>
      /// The filename for the serialized chat history.
      /// </summary>
      public const string HISTORY_STATE_FILENAME = "ChatHistory.json";

      /// <summary>
      /// The filename for the serialized input transform pipeline.
      /// </summary>
      public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";

      /// <summary>
      /// The filename for the serialized output transform.
      /// </summary>
      public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";

      /// <summary>
      /// The filename for the serialized history transform.
      /// </summary>
      public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";

      /// <summary>
      /// The executor for this session. The constructors only accept a <see cref="StatefulExecutorBase"/>.
      /// </summary>
      public ILLamaExecutor Executor { get; private set; }

      /// <summary>
      /// The chat history for this session.
      /// </summary>
      public ChatHistory History { get; private set; } = new();

      /// <summary>
      /// The history transform used in this session (renders history into prompt text).
      /// </summary>
      public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

      /// <summary>
      /// The input transform pipeline used in this session. Transforms run in list order.
      /// </summary>
      public List<ITextTransform> InputTransformPipeline { get; set; } = new();

      /// <summary>
      /// The output transform used in this session (applied to the executor's token stream).
      /// </summary>
      public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

      /// <summary>
      /// Create a new chat session and prefill the executor with the given history.
      /// </summary>
      /// <param name="executor">The executor for this session; must derive from <see cref="StatefulExecutorBase"/>.</param>
      /// <param name="history">History for this session. It is rendered with the new session's
      /// <see cref="HistoryTransform"/> and passed to the executor's <c>PrefillPromptAsync</c>.</param>
      /// <returns>The initialized session.</returns>
      /// <exception cref="ArgumentException"><paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/>.</exception>
      public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
          ILLamaExecutor executor, ChatHistory history)
      {
          if (executor is not StatefulExecutorBase statefulExecutor)
          {
              throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
          }

          var session = new ChatSession(executor, history);
          await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
          return session;
      }

      /// <summary>
      /// Create a new chat session with an empty history.
      /// </summary>
      /// <param name="executor">The executor for this session; must derive from <see cref="StatefulExecutorBase"/>.</param>
      /// <exception cref="ArgumentException"><paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/>.</exception>
      public ChatSession(ILLamaExecutor executor)
      {
          // Everything else in this class casts Executor to StatefulExecutorBase, so reject other kinds up front.
          if (executor is not StatefulExecutorBase)
          {
              throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
          }

          Executor = executor;
      }

      /// <summary>
      /// Create a new chat session with a custom history.
      /// </summary>
      /// <param name="executor">The executor for this session; must derive from <see cref="StatefulExecutorBase"/>.</param>
      /// <param name="history">The history instance to use (stored by reference, not copied).</param>
      public ChatSession(ILLamaExecutor executor, ChatHistory history)
          : this(executor)
      {
          History = history;
      }

      /// <summary>
      /// Use a custom history transform.
      /// </summary>
      /// <param name="transform">The history transform to use.</param>
      /// <returns>This session, for chaining.</returns>
      public ChatSession WithHistoryTransform(IHistoryTransform transform)
      {
          HistoryTransform = transform;
          return this;
      }

      /// <summary>
      /// Append a text transform to the end of the input transform pipeline.
      /// </summary>
      /// <param name="transform">The transform to append.</param>
      /// <returns>This session, for chaining.</returns>
      public ChatSession AddInputTransform(ITextTransform transform)
      {
          InputTransformPipeline.Add(transform);
          return this;
      }

      /// <summary>
      /// Use a custom output transform.
      /// </summary>
      /// <param name="transform">The output transform to use.</param>
      /// <returns>This session, for chaining.</returns>
      public ChatSession WithOutputTransform(ITextStreamTransform transform)
      {
          OutputTransform = transform;
          return this;
      }

      /// <summary>
      /// Save the session to a directory.
      /// </summary>
      /// <param name="path">Target directory.</param>
      /// <exception cref="ArgumentException"><paramref name="path"/> is null or whitespace.</exception>
      public void SaveSession(string path)
      {
          GetSessionState().Save(path);
      }

      /// <summary>
      /// Get the session state.
      /// </summary>
      /// <remarks>
      /// The context (KV cache) state is only captured when the executor has processed at least one token
      /// (<c>PastTokensCount &gt; 0</c>); otherwise <see cref="SessionState.ContextState"/> is null.
      /// </remarks>
      /// <returns>SessionState object representing session state in-memory</returns>
      public SessionState GetSessionState()
      {
          var executorState = ((StatefulExecutorBase)Executor).GetStateData();
          return new SessionState(
              executorState.PastTokensCount > 0
                  ? Executor.Context.GetState() : null,
              executorState,
              History,
              InputTransformPipeline,
              OutputTransform,
              HistoryTransform);
      }

      /// <summary>
      /// Load a session from a session state.
      /// </summary>
      /// <remarks>
      /// When the state has no context state, the KV cache is cleared instead of restored.
      /// Transforms are cloned so the session does not share instances with <paramref name="state"/>.
      /// </remarks>
      /// <param name="state">The state to restore.</param>
      /// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
      public void LoadSession(SessionState state, bool loadTransforms = true)
      {
          if (Executor is StatefulExecutorBase statefulExecutor && state.ExecutorState is not null)
          {
              statefulExecutor.LoadState(state.ExecutorState);
          }

          if (state.ContextState is null)
          {
              Executor.Context.NativeHandle.KvCacheClear();
          }
          else
          {
              Executor.Context.LoadState(state.ContextState);
          }

          History = new ChatHistory(state.History);

          if (loadTransforms)
          {
              InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
              OutputTransform = state.OutputTransform.Clone();
              HistoryTransform = state.HistoryTransform.Clone();
          }
      }

      /// <summary>
      /// Load a session from a directory.
      /// </summary>
      /// <param name="path">Directory previously written by <see cref="SaveSession"/>.</param>
      /// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
      /// <exception cref="ArgumentException">The directory is missing or the session files are invalid.</exception>
      public void LoadSession(string path, bool loadTransforms = true)
      {
          var state = SessionState.Load(path);

          // Handle non-polymorphic serialization of executor state: when the generic deserialization
          // produced nothing, let the executor read its own state file directly.
          if (state.ExecutorState is null)
          {
              var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
              ((StatefulExecutorBase)Executor).LoadState(filename: executorPath);
          }

          LoadSession(state, loadTransforms);
      }

      /// <summary>
      /// Add a message to the chat history, enforcing the turn order.
      /// </summary>
      /// <param name="message">The message to add. Its role and content are copied into the history.</param>
      /// <returns>This session, for chaining.</returns>
      /// <exception cref="ArgumentException">
      /// A system message is added to a non-empty history, a user message follows another user message,
      /// or an assistant message does not follow a user message.
      /// </exception>
      public ChatSession AddMessage(ChatHistory.Message message)
      {
          // If current message is a system message, only allow the history to be empty
          if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
          {
              throw new ArgumentException("Cannot add a system message after another message", nameof(message));
          }

          // If current message is a user message, the previous message (if any) must not be a user message.
          if (message.AuthorRole == AuthorRole.User)
          {
              ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
              if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
              {
                  throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
              }
          }

          // If the current message is an assistant message, the previous message must be a user message.
          if (message.AuthorRole == AuthorRole.Assistant)
          {
              ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
              if (lastMessage is null
                  || lastMessage.AuthorRole != AuthorRole.User)
              {
                  throw new ArgumentException("Assistant message must be preceded with a user message", nameof(message));
              }
          }

          History.AddMessage(message.AuthorRole, message.Content);
          return this;
      }

      /// <summary>
      /// Add a system message to the chat history.
      /// </summary>
      /// <param name="content">Message text.</param>
      /// <returns>This session, for chaining.</returns>
      public ChatSession AddSystemMessage(string content)
          => AddMessage(new ChatHistory.Message(AuthorRole.System, content));

      /// <summary>
      /// Add an assistant message to the chat history.
      /// </summary>
      /// <param name="content">Message text.</param>
      /// <returns>This session, for chaining.</returns>
      public ChatSession AddAssistantMessage(string content)
          => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

      /// <summary>
      /// Add a user message to the chat history.
      /// </summary>
      /// <param name="content">Message text.</param>
      /// <returns>This session, for chaining.</returns>
      public ChatSession AddUserMessage(string content)
          => AddMessage(new ChatHistory.Message(AuthorRole.User, content));

      /// <summary>
      /// Remove the last message from the chat history.
      /// </summary>
      /// <remarks>Throws if the history is empty (index -1 is passed to <c>RemoveAt</c>).</remarks>
      /// <returns>This session, for chaining.</returns>
      public ChatSession RemoveLastMessage()
      {
          History.Messages.RemoveAt(History.Messages.Count - 1);
          return this;
      }

      /// <summary>
      /// Compute KV cache for the message and add it to the chat history.
      /// </summary>
      /// <remarks>
      /// The history stores the original content. The text sent to the executor has the input transform
      /// pipeline applied, except for assistant messages.
      /// </remarks>
      /// <param name="message">The message to add and prefill.</param>
      /// <returns>This session, for chaining.</returns>
      /// <exception cref="InvalidOperationException">The executor is not a <see cref="StatefulExecutorBase"/>.</exception>
      public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
      {
          if (Executor is not StatefulExecutorBase statefulExecutor)
          {
              throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
          }

          AddMessage(message);

          string content = message.AuthorRole == AuthorRole.Assistant
              ? message.Content
              : ApplyInputTransforms(message.Content);

          await statefulExecutor.PrefillPromptAsync(content);
          return this;
      }

      /// <summary>
      /// Compute KV cache for the system message and add it to the chat history.
      /// </summary>
      public Task<ChatSession> AddAndProcessSystemMessage(string content)
          => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));

      /// <summary>
      /// Compute KV cache for the user message and add it to the chat history.
      /// </summary>
      public Task<ChatSession> AddAndProcessUserMessage(string content)
          => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));

      /// <summary>
      /// Compute KV cache for the assistant message and add it to the chat history.
      /// </summary>
      public Task<ChatSession> AddAndProcessAssistantMessage(string content)
          => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

      /// <summary>
      /// Replace a user message with a new message and remove all messages after the new message.
      /// This is useful when the user wants to edit a message and regenerate the response.
      /// </summary>
      /// <param name="oldMessage">An existing user message in <see cref="History"/>.</param>
      /// <param name="newMessage">The replacement user message.</param>
      /// <returns>This session, for chaining.</returns>
      /// <exception cref="ArgumentException">Either message is not a user message, or
      /// <paramref name="oldMessage"/> is not in the history.</exception>
      public ChatSession ReplaceUserMessage(
          ChatHistory.Message oldMessage,
          ChatHistory.Message newMessage)
      {
          if (oldMessage.AuthorRole != AuthorRole.User)
          {
              throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
          }

          if (newMessage.AuthorRole != AuthorRole.User)
          {
              throw new ArgumentException("New message must be a user message", nameof(newMessage));
          }

          int index = History.Messages.IndexOf(oldMessage);
          if (index == -1)
          {
              throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
          }

          History.Messages[index] = newMessage;

          // Remove all messages after the new message
          History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);
          return this;
      }

      /// <summary>
      /// Chat with the model.
      /// </summary>
      /// <remarks>
      /// The (optionally transformed) user message is added to the history when enumeration starts. The
      /// generated reply is added as an assistant message once the stream has been fully enumerated.
      /// The caller's <paramref name="message"/> object is not modified.
      /// </remarks>
      /// <param name="message">The user message to send.</param>
      /// <param name="applyInputTransformPipeline">Whether to run <see cref="InputTransformPipeline"/> on the content.</param>
      /// <param name="inferenceParams">Inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Cancellation token.</param>
      /// <returns>The generated text, token by token, after the output transform.</returns>
      /// <exception cref="ArgumentException"><paramref name="message"/> is not a user message.</exception>
      public async IAsyncEnumerable<string> ChatAsync(
          ChatHistory.Message message,
          bool applyInputTransformPipeline,
          IInferenceParams? inferenceParams = null,
          [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // The message must be a user message
          if (message.AuthorRole != AuthorRole.User)
          {
              throw new ArgumentException("Message must be a user message", nameof(message));
          }

          // Transform into a local so the caller's message object is left untouched.
          string content = applyInputTransformPipeline
              ? ApplyInputTransforms(message.Content)
              : message.Content;

          // Add the user's message to the history
          AddUserMessage(content);

          // Check if the session history was restored from a previous session
          // or added as part of new chat session history.
          // NOTE(review): this cast throws InvalidCastException when the executor's state data is not an
          // InteractiveExecutorState (i.e. for non-interactive executors).
          var state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

          string prompt;
          if (state.IsPromptRun)
          {
              // A newly started session: render the complete history (including the system message and any
              // manually added messages) with the HistoryTransform's prompt template.
              prompt = HistoryTransform.HistoryToText(History);
          }
          else
          {
              // A continued/restored session: the earlier turns are already in the executor's state, so only
              // the current message is rendered with the HistoryTransform's prompt template.
              ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, content);
              prompt = HistoryTransform.HistoryToText(singleMessageHistory);
          }

          // StringBuilder avoids quadratic re-copying when the reply consists of many tokens.
          var assistantMessage = new StringBuilder();
          await foreach (string textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
          {
              assistantMessage.Append(textToken);
              yield return textToken;
          }

          // Add the assistant message to the history
          AddAssistantMessage(assistantMessage.ToString());
      }

      /// <summary>
      /// Chat with the model, applying the input transform pipeline.
      /// </summary>
      /// <param name="message">The user message to send.</param>
      /// <param name="inferenceParams">Inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Cancellation token.</param>
      /// <returns>The generated text, token by token.</returns>
      public IAsyncEnumerable<string> ChatAsync(
          ChatHistory.Message message,
          IInferenceParams? inferenceParams = null,
          CancellationToken cancellationToken = default)
      {
          return ChatAsync(
              message,
              applyInputTransformPipeline: true,
              inferenceParams,
              cancellationToken);
      }

      /// <summary>
      /// Chat with the model.
      /// </summary>
      /// <remarks>
      /// All messages except the last are added to the session history immediately, on this call rather than
      /// on enumeration. User messages among them are transformed if requested. The last message is then sent
      /// via <see cref="ChatAsync(ChatHistory.Message, bool, IInferenceParams?, CancellationToken)"/>.
      /// The messages in <paramref name="history"/> are not modified.
      /// </remarks>
      /// <param name="history">Messages to add; the last one must be a user message.</param>
      /// <param name="applyInputTransformPipeline">Whether to run <see cref="InputTransformPipeline"/> on user messages.</param>
      /// <param name="inferenceParams">Inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Cancellation token.</param>
      /// <returns>The generated text, token by token.</returns>
      /// <exception cref="ArgumentException"><paramref name="history"/> is empty, or a message violates turn order.</exception>
      public IAsyncEnumerable<string> ChatAsync(
          ChatHistory history,
          bool applyInputTransformPipeline,
          IInferenceParams? inferenceParams = null,
          CancellationToken cancellationToken = default)
      {
          ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
              ?? throw new ArgumentException("History must contain at least one message", nameof(history));

          foreach (ChatHistory.Message message in history.Messages.Take(history.Messages.Count - 1))
          {
              // Apply input transform pipeline to a copy of the content, leaving the caller's history intact.
              string content = applyInputTransformPipeline && message.AuthorRole == AuthorRole.User
                  ? ApplyInputTransforms(message.Content)
                  : message.Content;

              AddMessage(new ChatHistory.Message(message.AuthorRole, content));
          }

          return ChatAsync(
              lastMessage,
              applyInputTransformPipeline,
              inferenceParams,
              cancellationToken);
      }

      /// <summary>
      /// Chat with the model, applying the input transform pipeline.
      /// </summary>
      /// <param name="history">Messages to add; the last one must be a user message.</param>
      /// <param name="inferenceParams">Inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Cancellation token.</param>
      /// <returns>The generated text, token by token.</returns>
      public IAsyncEnumerable<string> ChatAsync(
          ChatHistory history,
          IInferenceParams? inferenceParams = null,
          CancellationToken cancellationToken = default)
      {
          return ChatAsync(
              history,
              applyInputTransformPipeline: true,
              inferenceParams,
              cancellationToken);
      }

      /// <summary>
      /// Regenerate the last assistant message.
      /// </summary>
      /// <remarks>
      /// Requires the history to end with a user message followed by an assistant message. Both are removed,
      /// and the user message is re-sent without re-applying the input transform pipeline (its stored content
      /// is used as-is). The history is only modified after both checks pass.
      /// </remarks>
      /// <param name="inferenceParams">Inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Cancellation token.</param>
      /// <returns>The regenerated text, token by token.</returns>
      /// <exception cref="InvalidOperationException">The history does not end with a user/assistant pair.</exception>
      public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
          InferenceParams? inferenceParams = null,
          [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var messages = History.Messages;

          // Make sure the last message is an assistant message (response from the LLM).
          ChatHistory.Message? lastAssistantMessage = messages.LastOrDefault();
          if (lastAssistantMessage is null
              || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
          {
              throw new InvalidOperationException("Last message must be an assistant message");
          }

          // Check the preceding user message BEFORE touching the history, so a failure leaves it intact.
          ChatHistory.Message? lastUserMessage = messages.Count >= 2 ? messages[messages.Count - 2] : null;
          if (lastUserMessage is null
              || lastUserMessage.AuthorRole != AuthorRole.User)
          {
              throw new InvalidOperationException("Last message must be a user message");
          }

          // Remove the assistant message, then the user message; ChatAsync re-adds the user message.
          RemoveLastMessage();
          RemoveLastMessage();

          // Regenerate the assistant message.
          await foreach (
              string textToken
              in ChatAsync(
                  lastUserMessage,
                  applyInputTransformPipeline: false,
                  inferenceParams,
                  cancellationToken))
          {
              yield return textToken;
          }
      }

      /// <summary>
      /// Run the executor on <paramref name="prompt"/> and stream its output through <see cref="OutputTransform"/>.
      /// </summary>
      private async IAsyncEnumerable<string> ChatAsyncInternal(
          string prompt,
          IInferenceParams? inferenceParams = null,
          [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);

          await foreach (
              string textToken
              in OutputTransform
                  .TransformAsync(results)
                  .WithCancellation(cancellationToken))
          {
              yield return textToken;
          }
      }

      /// <summary>
      /// Run <paramref name="text"/> through every transform in <see cref="InputTransformPipeline"/>, in list order.
      /// </summary>
      private string ApplyInputTransforms(string text)
      {
          foreach (var inputTransform in InputTransformPipeline)
          {
              text = inputTransform.Transform(text);
          }

          return text;
      }
  }
  /// <summary>
  /// The state of a chat session in-memory.
  /// </summary>
  public record SessionState
  {
      /// <summary>
      /// Saved executor state for the session. <see cref="Save"/> serializes it to JSON; it may be null
      /// after <see cref="Load"/> if the JSON deserialization yielded nothing.
      /// </summary>
      public ExecutorBaseState? ExecutorState { get; set; }

      /// <summary>
      /// Saved context state (KV cache) for the session, or null when none was captured.
      /// </summary>
      public State? ContextState { get; set; }

      /// <summary>
      /// The input transform pipeline used in this session.
      /// </summary>
      public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();

      /// <summary>
      /// The output transform used in this session.
      /// </summary>
      public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();

      /// <summary>
      /// The history transform used in this session.
      /// </summary>
      public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

      /// <summary>
      /// The chat history messages for this session.
      /// </summary>
      public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();

      /// <summary>
      /// Create a new session state.
      /// </summary>
      /// <remarks>
      /// The history's message list is copied into a new array (the message objects themselves are shared).
      /// Every transform is copied via its <c>Clone()</c> method.
      /// </remarks>
      /// <param name="contextState">Context (KV cache) state, or null if none.</param>
      /// <param name="executorState">Executor state, or null if unavailable.</param>
      /// <param name="history">Chat history to snapshot.</param>
      /// <param name="inputTransformPipeline">Input transforms to clone.</param>
      /// <param name="outputTransform">Output transform to clone.</param>
      /// <param name="historyTransform">History transform to clone.</param>
      public SessionState(
          State? contextState, ExecutorBaseState? executorState,
          ChatHistory history, List<ITextTransform> inputTransformPipeline,
          ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
      {
          ContextState = contextState;
          ExecutorState = executorState;
          History = history.Messages.ToArray();
          InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray();
          OutputTransform = outputTransform.Clone();
          HistoryTransform = historyTransform.Clone();
      }

      /// <summary>
      /// Save the session state to a folder.
      /// </summary>
      /// <remarks>
      /// The directory is created if needed. Only the session's own files are written. Other files in the
      /// directory are left untouched, unlike earlier behavior that recursively deleted the whole directory
      /// first. The model state file is written only when <see cref="ContextState"/> is non-null; otherwise
      /// any stale one from a previous save is removed so <see cref="Load"/> cannot pick it up.
      /// </remarks>
      /// <param name="path">Target directory.</param>
      /// <exception cref="ArgumentException"><paramref name="path"/> is null or whitespace.</exception>
      public void Save(string path)
      {
          if (string.IsNullOrWhiteSpace(path))
          {
              throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
          }

          Directory.CreateDirectory(path);

          string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
          var bytes = ContextState?.ToByteArray();
          if (bytes is not null)
          {
              File.WriteAllBytes(modelStateFilePath, bytes);
          }
          else
          {
              // File.Delete is a no-op when the file does not exist.
              File.Delete(modelStateFilePath);
          }

          // The remaining files are always (over)written, so no stale copies can survive.
          string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
          File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));

          string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
          File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson());

          string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
          File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline));

          string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
          File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform));

          string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
          File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform));
      }

      /// <summary>
      /// Load the session state from a folder.
      /// </summary>
      /// <remarks>
      /// The model state and the three transform files are optional. A missing model state yields a null
      /// <see cref="ContextState"/>, and missing transform files fall back to the defaults. The executor state
      /// and history files are read unconditionally.
      /// </remarks>
      /// <param name="path">Directory previously written by <see cref="Save"/>.</param>
      /// <returns>The loaded session state.</returns>
      /// <exception cref="ArgumentException">Throws when session state is incorrect</exception>
      public static SessionState Load(string path)
      {
          if (string.IsNullOrWhiteSpace(path))
          {
              throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
          }

          if (!Directory.Exists(path))
          {
              throw new ArgumentException("Directory does not exist", nameof(path));
          }

          string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
          var contextState = File.Exists(modelStateFilePath)
              ? State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
              : null;

          string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
          var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));

          string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
          string historyJson = File.ReadAllText(historyFilepath);
          var history = ChatHistory.FromJson(historyJson)
              ?? throw new ArgumentException("History file is invalid", nameof(path));

          ITextTransform[] inputTransforms = DeserializeOptionalFile(
              Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME),
              () => Array.Empty<ITextTransform>(),
              "Input transform file is invalid",
              nameof(path));

          ITextStreamTransform outputTransform = DeserializeOptionalFile<ITextStreamTransform>(
              Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME),
              () => new LLamaTransforms.EmptyTextOutputStreamTransform(),
              "Output transform file is invalid",
              nameof(path));

          IHistoryTransform historyTransform = DeserializeOptionalFile<IHistoryTransform>(
              Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME),
              () => new LLamaTransforms.DefaultHistoryTransform(),
              "History transform file is invalid",
              nameof(path));

          return new SessionState(
              contextState,
              executorState,
              history,
              inputTransforms.ToList(),
              outputTransform,
              historyTransform);
      }

      /// <summary>
      /// Deserialize <paramref name="filePath"/> as JSON, or return <paramref name="fallback"/>() if the file does
      /// not exist. A JSON parse error or a null result is reported as an <see cref="ArgumentException"/>
      /// carrying <paramref name="errorMessage"/>.
      /// </summary>
      private static T DeserializeOptionalFile<T>(string filePath, Func<T> fallback, string errorMessage, string paramName)
          where T : class
      {
          if (!File.Exists(filePath))
          {
              return fallback();
          }

          try
          {
              return JsonSerializer.Deserialize<T>(File.ReadAllText(filePath))
                  ?? throw new ArgumentException(errorMessage, paramName);
          }
          catch (JsonException)
          {
              throw new ArgumentException(errorMessage, paramName);
          }
      }
  }