You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ChatSession.cs 29 kB

1 year ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Runtime.CompilerServices;
  6. using System.Text.Json;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. using LLama.Abstractions;
  10. using LLama.Common;
  11. using static LLama.InteractiveExecutor;
  12. using static LLama.LLamaContext;
  13. using static LLama.StatefulExecutorBase;
  14. namespace LLama;
  15. /// <summary>
  16. /// The main chat session class.
  17. /// </summary>
  18. public class ChatSession
  19. {
  /// <summary>
  /// File name used inside a session directory for the serialized model state (KV cache, etc).
  /// </summary>
  public const string MODEL_STATE_FILENAME = "ModelState.st";

  /// <summary>
  /// File name used inside a session directory for the serialized executor state.
  /// </summary>
  public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";

  /// <summary>
  /// File name used inside a session directory for the serialized chat history.
  /// </summary>
  public const string HISTORY_STATE_FILENAME = "ChatHistory.json";

  /// <summary>
  /// File name used inside a session directory for the serialized input transform pipeline.
  /// </summary>
  public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";

  /// <summary>
  /// File name used inside a session directory for the serialized output transform.
  /// </summary>
  public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";

  /// <summary>
  /// File name used inside a session directory for the serialized history transform.
  /// </summary>
  public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";

  /// <summary>
  /// The executor that drives inference for this session. Assigned by the constructor.
  /// </summary>
  public ILLamaExecutor Executor { get; private set; }

  /// <summary>
  /// The messages exchanged so far in this session. Starts out empty unless a history
  /// is supplied to the constructor.
  /// </summary>
  public ChatHistory History { get; private set; } = new ChatHistory();

  /// <summary>
  /// Converts between <see cref="ChatHistory"/> and prompt text for this session.
  /// </summary>
  public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

  /// <summary>
  /// Text transforms applied, in list order, to input message content.
  /// </summary>
  public List<ITextTransform> InputTransformPipeline { get; set; } = new List<ITextTransform>();

  /// <summary>
  /// Transform applied to the stream of text produced by the executor.
  /// </summary>
  public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
  /// <summary>
  /// Build a chat session around an existing history and feed that history,
  /// rendered as text by the session's history transform, to the executor's
  /// <c>PrefillPromptAsync</c>.
  /// </summary>
  /// <param name="executor">The executor for the session; must derive from <see cref="StatefulExecutorBase"/>.</param>
  /// <param name="history">The history the session starts with.</param>
  /// <param name="transform">Optional history transform; when null the session keeps its default transform.</param>
  /// <returns>The newly created chat session.</returns>
  /// <exception cref="ArgumentException">Thrown when <paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/>.</exception>
  public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
      ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
  {
      var stateful = executor as StatefulExecutorBase
          ?? throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));

      var session = new ChatSession(executor, history);
      if (transform is not null)
      {
          // WithHistoryTransform returns the same instance, so the result can be ignored.
          session.WithHistoryTransform(transform);
      }

      var prompt = session.HistoryTransform.HistoryToText(history);
      await stateful.PrefillPromptAsync(prompt);
      return session;
  }
  /// <summary>
  /// Create a new chat session with an empty history.
  /// </summary>
  /// <param name="executor">The executor for this session; must derive from <see cref="StatefulExecutorBase"/>.</param>
  /// <exception cref="ArgumentException">Thrown when <paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/>.</exception>
  public ChatSession(ILLamaExecutor executor)
  {
      // Only executors built on StatefulExecutorBase are accepted.
      Executor = executor is StatefulExecutorBase
          ? executor
          : throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
  }
  /// <summary>
  /// Create a new chat session that starts from the supplied history.
  /// </summary>
  /// <param name="executor">The executor for this session; validated by the single-argument constructor.</param>
  /// <param name="history">The history object the session will use (stored by reference).</param>
  public ChatSession(ILLamaExecutor executor, ChatHistory history)
      : this(executor) => History = history;
  /// <summary>
  /// Replace the session's history transform.
  /// </summary>
  /// <param name="transform">The history transform to use from now on.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession WithHistoryTransform(IHistoryTransform transform)
  {
      this.HistoryTransform = transform;

      return this;
  }

  /// <summary>
  /// Append a text transform to the end of the input transform pipeline.
  /// </summary>
  /// <param name="transform">The transform to append.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddInputTransform(ITextTransform transform)
  {
      this.InputTransformPipeline.Add(transform);

      return this;
  }

  /// <summary>
  /// Replace the session's output transform.
  /// </summary>
  /// <param name="transform">The output stream transform to use from now on.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession WithOutputTransform(ITextStreamTransform transform)
  {
      this.OutputTransform = transform;

      return this;
  }
  /// <summary>
  /// Save the session to a directory by capturing its current state with
  /// <see cref="GetSessionState"/> and calling <see cref="SessionState.Save"/>.
  /// </summary>
  /// <param name="path">The directory to write the session state into.</param>
  /// <exception cref="ArgumentException">Thrown by <see cref="SessionState.Save"/> when <paramref name="path"/> is null or whitespace.</exception>
  public void SaveSession(string path)
  {
      SessionState snapshot = GetSessionState();
      snapshot.Save(path);
  }
  /// <summary>
  /// Capture the current state of the session as an in-memory object.
  /// </summary>
  /// <returns>
  /// A <see cref="SessionState"/> holding the executor state, the chat history, the
  /// transforms, and the context state. The context state is null when the executor
  /// reports no past tokens.
  /// </returns>
  public SessionState GetSessionState()
  {
      var stateful = (StatefulExecutorBase)Executor;
      var executorState = stateful.GetStateData();

      // Only snapshot the context when the executor has processed at least one token.
      State? contextState = null;
      if (executorState.PastTokensCount > 0)
      {
          contextState = Executor.Context.GetState();
      }

      return new SessionState(
          contextState,
          executorState,
          History,
          InputTransformPipeline,
          OutputTransform,
          HistoryTransform);
  }
  /// <summary>
  /// Restore the session from an in-memory session state.
  /// </summary>
  /// <param name="state">The state to restore from.</param>
  /// <param name="loadTransforms">
  /// When true, the input, output and history transforms are replaced with clones of
  /// the ones stored in <paramref name="state"/>; when false the current transforms are kept.
  /// </param>
  public void LoadSession(SessionState state, bool loadTransforms = true)
  {
      // Executor state is restored only when the snapshot actually carries one.
      if (Executor is StatefulExecutorBase stateful && state.ExecutorState is not null)
      {
          stateful.LoadState(state.ExecutorState);
      }

      if (state.ContextState is not null)
      {
          Executor.Context.LoadState(state.ContextState);
      }
      else
      {
          // No saved context state: clear the KV cache instead.
          Executor.Context.NativeHandle.KvCacheClear();
      }

      History = new ChatHistory(state.History);

      if (!loadTransforms)
      {
          return;
      }

      // Clone every transform so the session does not share instances with the snapshot.
      var pipeline = new List<ITextTransform>();
      foreach (var savedTransform in state.InputTransformPipeline)
      {
          pipeline.Add(savedTransform.Clone());
      }

      InputTransformPipeline = pipeline;
      OutputTransform = state.OutputTransform.Clone();
      HistoryTransform = state.HistoryTransform.Clone();
  }
  /// <summary>
  /// Restore the session from a directory previously written by <see cref="SaveSession"/>.
  /// </summary>
  /// <param name="path">The directory containing the saved session files.</param>
  /// <param name="loadTransforms">When true, the transforms saved with the session are loaded as well.</param>
  /// <exception cref="ArgumentException">Thrown by <see cref="SessionState.Load"/> when the path or its contents are invalid.</exception>
  public void LoadSession(string path, bool loadTransforms = true)
  {
      SessionState loaded = SessionState.Load(path);

      if (loaded.ExecutorState is null)
      {
          // The executor state could not be materialised from JSON (non-polymorphic
          // serialization), so let the executor read its state file directly.
          string executorStateFile = Path.Combine(path, EXECUTOR_STATE_FILENAME);
          var stateful = (StatefulExecutorBase)Executor;
          stateful.LoadState(filename: executorStateFile);
      }

      LoadSession(loaded, loadTransforms);
  }
  /// <summary>
  /// Append a message to the chat history after checking that it respects the
  /// expected ordering of roles.
  /// </summary>
  /// <param name="message">The message whose role and content are appended.</param>
  /// <returns>This session, to allow call chaining.</returns>
  /// <exception cref="ArgumentException">
  /// Thrown when a system message is added to a non-empty history, when a user message
  /// directly follows another user message, or when an assistant message does not
  /// directly follow a user message.
  /// </exception>
  public ChatSession AddMessage(ChatHistory.Message message)
  {
      var role = message.AuthorRole;
      var previous = History.Messages.LastOrDefault();

      if (role == AuthorRole.System)
      {
          // A system message is only allowed as the very first message.
          if (History.Messages.Count > 0)
          {
              throw new ArgumentException("Cannot add a system message after another message", nameof(message));
          }
      }
      else if (role == AuthorRole.User)
      {
          // A user message may open the history or follow a system/assistant message.
          if (previous?.AuthorRole == AuthorRole.User)
          {
              throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
          }
      }
      else if (role == AuthorRole.Assistant)
      {
          // An assistant message must answer a user message.
          if (previous?.AuthorRole != AuthorRole.User)
          {
              throw new ArgumentException("Assistant message must be preceded with a user message", nameof(message));
          }
      }

      History.AddMessage(role, message.Content);
      return this;
  }
  /// <summary>
  /// Append a system message to the chat history via <see cref="AddMessage"/>.
  /// </summary>
  /// <param name="content">The text of the system message.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddSystemMessage(string content)
  {
      var systemMessage = new ChatHistory.Message(AuthorRole.System, content);
      return AddMessage(systemMessage);
  }

  /// <summary>
  /// Append an assistant message to the chat history via <see cref="AddMessage"/>.
  /// </summary>
  /// <param name="content">The text of the assistant message.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddAssistantMessage(string content)
  {
      var assistantMessage = new ChatHistory.Message(AuthorRole.Assistant, content);
      return AddMessage(assistantMessage);
  }

  /// <summary>
  /// Append a user message to the chat history via <see cref="AddMessage"/>.
  /// </summary>
  /// <param name="content">The text of the user message.</param>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession AddUserMessage(string content)
  {
      var userMessage = new ChatHistory.Message(AuthorRole.User, content);
      return AddMessage(userMessage);
  }
  /// <summary>
  /// Remove the most recent message from the chat history.
  /// </summary>
  /// <remarks>
  /// The removal is performed at index <c>Count - 1</c> without an emptiness check,
  /// so callers should ensure the history contains at least one message.
  /// </remarks>
  /// <returns>This session, to allow call chaining.</returns>
  public ChatSession RemoveLastMessage()
  {
      var messages = History.Messages;
      int lastIndex = messages.Count - 1;
      messages.RemoveAt(lastIndex);
      return this;
  }
  /// <summary>
  /// Add a message to the chat history and pass its content to the executor's
  /// <c>PrefillPromptAsync</c> (described by the original author as computing the
  /// KV cache for the message).
  /// </summary>
  /// <param name="message">The message to add and process.</param>
  /// <returns>This session, to allow call chaining.</returns>
  /// <exception cref="InvalidOperationException">Thrown when the executor is not a <see cref="StatefulExecutorBase"/>.</exception>
  /// <exception cref="ArgumentException">Thrown by <see cref="AddMessage"/> when the message violates the role ordering rules.</exception>
  public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
  {
      var stateful = Executor as StatefulExecutorBase
          ?? throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");

      AddMessage(message);

      var content = message.Content;

      // Assistant text bypasses the input pipeline; every other role is run through it in order.
      if (message.AuthorRole != AuthorRole.Assistant)
      {
          content = InputTransformPipeline.Aggregate(
              content,
              (text, transform) => transform.Transform(text));
      }

      await stateful.PrefillPromptAsync(content);
      return this;
  }
  /// <summary>
  /// Add a system message to the chat history and process it with <see cref="AddAndProcessMessage"/>.
  /// </summary>
  /// <param name="content">The text of the system message.</param>
  /// <returns>A task producing this session.</returns>
  public Task<ChatSession> AddAndProcessSystemMessage(string content)
  {
      var systemMessage = new ChatHistory.Message(AuthorRole.System, content);
      return AddAndProcessMessage(systemMessage);
  }

  /// <summary>
  /// Add a user message to the chat history and process it with <see cref="AddAndProcessMessage"/>.
  /// </summary>
  /// <param name="content">The text of the user message.</param>
  /// <returns>A task producing this session.</returns>
  public Task<ChatSession> AddAndProcessUserMessage(string content)
  {
      var userMessage = new ChatHistory.Message(AuthorRole.User, content);
      return AddAndProcessMessage(userMessage);
  }

  /// <summary>
  /// Add an assistant message to the chat history and process it with <see cref="AddAndProcessMessage"/>.
  /// </summary>
  /// <param name="content">The text of the assistant message.</param>
  /// <returns>A task producing this session.</returns>
  public Task<ChatSession> AddAndProcessAssistantMessage(string content)
  {
      var assistantMessage = new ChatHistory.Message(AuthorRole.Assistant, content);
      return AddAndProcessMessage(assistantMessage);
  }
  /// <summary>
  /// Swap a user message already in the history for a new one and drop every message
  /// that came after it. Intended for "edit and regenerate" scenarios.
  /// </summary>
  /// <param name="oldMessage">The user message currently in the history.</param>
  /// <param name="newMessage">The user message that takes its place.</param>
  /// <returns>This session, to allow call chaining.</returns>
  /// <exception cref="ArgumentException">
  /// Thrown when either message is not a user message, or when
  /// <paramref name="oldMessage"/> is not found in the history.
  /// </exception>
  public ChatSession ReplaceUserMessage(
      ChatHistory.Message oldMessage,
      ChatHistory.Message newMessage)
  {
      if (oldMessage.AuthorRole != AuthorRole.User)
      {
          throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
      }

      if (newMessage.AuthorRole != AuthorRole.User)
      {
          throw new ArgumentException("New message must be a user message", nameof(newMessage));
      }

      var messages = History.Messages;
      int position = messages.IndexOf(oldMessage);
      if (position < 0)
      {
          throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
      }

      messages[position] = newMessage;

      // Everything after the replaced message is discarded.
      int tailStart = position + 1;
      messages.RemoveRange(tailStart, messages.Count - tailStart);

      return this;
  }
  /// <summary>
  /// Send a user message to the model and stream back the response.
  /// The (optionally transformed) user message is added to the history before inference,
  /// and the complete assistant reply is added once the stream has been fully consumed.
  /// </summary>
  /// <remarks>
  /// The supplied <paramref name="message"/> is not modified: input transforms are applied
  /// to a local copy of its content. The executor state is cast to
  /// <see cref="InteractiveExecutorState"/>, so an executor with a different state type
  /// causes an <see cref="InvalidCastException"/>.
  /// </remarks>
  /// <param name="message">The user message to send.</param>
  /// <param name="applyInputTransformPipeline">When true, the message content is run through <see cref="InputTransformPipeline"/> first.</param>
  /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
  /// <returns>An asynchronous stream of response text pieces.</returns>
  /// <exception cref="ArgumentException">Thrown when <paramref name="message"/> is not a user message, or when it violates the history ordering rules.</exception>
  public async IAsyncEnumerable<string> ChatAsync(
      ChatHistory.Message message,
      bool applyInputTransformPipeline,
      IInferenceParams? inferenceParams = null,
      [EnumeratorCancellation] CancellationToken cancellationToken = default)
  {
      // The message must be a user message
      if (message.AuthorRole != AuthorRole.User)
      {
          throw new ArgumentException("Message must be a user message", nameof(message));
      }

      // Transform a local copy so the caller's message object is left untouched.
      string content = message.Content;
      if (applyInputTransformPipeline)
      {
          foreach (var inputTransform in InputTransformPipeline)
          {
              content = inputTransform.Transform(content);
          }
      }

      // Add the user's message to the history
      AddUserMessage(content);

      // Check whether the executor is still on its first (prompt) run.
      InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

      string prompt;
      if (state.IsPromptRun)
      {
          // Newly started session: render the complete history (system message and any
          // manually added messages included) through the history transform.
          prompt = HistoryTransform.HistoryToText(History);
      }
      else
      {
          // The executor has already processed a prompt: render only the current
          // message through the history transform.
          ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, content);
          prompt = HistoryTransform.HistoryToText(singleMessageHistory);
      }

      // Accumulate the reply in a builder instead of re-allocating a string per token.
      var assistantMessage = new System.Text.StringBuilder();
      await foreach (
          string textToken
          in ChatAsyncInternal(
              prompt,
              inferenceParams,
              cancellationToken))
      {
          assistantMessage.Append(textToken);
          yield return textToken;
      }

      // Add the assistant message to the history
      AddAssistantMessage(assistantMessage.ToString());
  }
  /// <summary>
  /// Send a user message to the model and stream back the response, with the input
  /// transform pipeline enabled.
  /// </summary>
  /// <param name="message">The user message to send.</param>
  /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
  /// <returns>An asynchronous stream of response text pieces.</returns>
  public IAsyncEnumerable<string> ChatAsync(
      ChatHistory.Message message,
      IInferenceParams? inferenceParams = null,
      CancellationToken cancellationToken = default)
      => ChatAsync(
          message,
          applyInputTransformPipeline: true,
          inferenceParams,
          cancellationToken);
  /// <summary>
  /// Add every message of <paramref name="history"/> except the last to the session
  /// history, then chat with the model using the last message.
  /// </summary>
  /// <remarks>
  /// The leading messages are added immediately when this method is called, not when the
  /// returned stream is enumerated. The messages in <paramref name="history"/> are not
  /// modified by this method: input transforms are applied to a copy of each user
  /// message's content.
  /// </remarks>
  /// <param name="history">The messages to add; must contain at least one message.</param>
  /// <param name="applyInputTransformPipeline">When true, user message content is run through <see cref="InputTransformPipeline"/>.</param>
  /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
  /// <returns>An asynchronous stream of response text pieces.</returns>
  /// <exception cref="ArgumentException">Thrown when <paramref name="history"/> is empty, or when a message violates the history ordering rules.</exception>
  public IAsyncEnumerable<string> ChatAsync(
      ChatHistory history,
      bool applyInputTransformPipeline,
      IInferenceParams? inferenceParams = null,
      CancellationToken cancellationToken = default)
  {
      ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
          ?? throw new ArgumentException("History must contain at least one message", nameof(history));

      foreach (
          ChatHistory.Message message
          in history.Messages.Take(history.Messages.Count - 1))
      {
          // Transform a local copy so the caller's history is left untouched and
          // passing the same history again does not apply the transforms twice.
          string content = message.Content;
          if (applyInputTransformPipeline
              && message.AuthorRole == AuthorRole.User)
          {
              foreach (var inputTransform in InputTransformPipeline)
              {
                  content = inputTransform.Transform(content);
              }
          }

          AddMessage(new ChatHistory.Message(message.AuthorRole, content));
      }

      return ChatAsync(
          lastMessage,
          applyInputTransformPipeline,
          inferenceParams,
          cancellationToken);
  }
  /// <summary>
  /// Add a whole history to the session and chat with the model using its last message,
  /// with the input transform pipeline enabled.
  /// </summary>
  /// <param name="history">The messages to add; the last one is sent to the model.</param>
  /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
  /// <returns>An asynchronous stream of response text pieces.</returns>
  public IAsyncEnumerable<string> ChatAsync(
      ChatHistory history,
      IInferenceParams? inferenceParams = null,
      CancellationToken cancellationToken = default)
      => ChatAsync(
          history,
          applyInputTransformPipeline: true,
          inferenceParams,
          cancellationToken);
  /// <summary>
  /// Regenerate the last assistant message: the trailing assistant message and the user
  /// message before it are removed from the history, and the user message is sent to the
  /// model again (without re-applying the input transform pipeline).
  /// </summary>
  /// <remarks>
  /// Both trailing messages are validated before anything is removed, so the history is
  /// left unchanged when an <see cref="InvalidOperationException"/> is thrown.
  /// </remarks>
  /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
  /// <returns>An asynchronous stream of response text pieces.</returns>
  /// <exception cref="InvalidOperationException">
  /// Thrown when the last message is not an assistant message, or when it is not
  /// directly preceded by a user message.
  /// </exception>
  public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
      InferenceParams? inferenceParams = null,
      [EnumeratorCancellation] CancellationToken cancellationToken = default)
  {
      var messages = History.Messages;

      // Make sure the last message is an assistant message (response from the LLM).
      ChatHistory.Message? lastAssistantMessage = messages.LastOrDefault();
      if (lastAssistantMessage is null
          || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
      {
          throw new InvalidOperationException("Last message must be an assistant message");
      }

      // Check the preceding user message BEFORE mutating the history, so a failed
      // validation cannot leave the history with the assistant message already dropped.
      ChatHistory.Message? lastUserMessage = messages.Count >= 2
          ? messages[messages.Count - 2]
          : null;
      if (lastUserMessage is null
          || lastUserMessage.AuthorRole != AuthorRole.User)
      {
          throw new InvalidOperationException("Last message must be a user message");
      }

      // Remove the assistant message, then the user message; ChatAsync re-adds both.
      RemoveLastMessage();
      RemoveLastMessage();

      // Regenerate the assistant message. The stored user content was already
      // transformed when it was first added, so the pipeline is skipped here.
      await foreach (
          string textToken
          in ChatAsync(
              lastUserMessage,
              applyInputTransformPipeline: false,
              inferenceParams,
              cancellationToken))
      {
          yield return textToken;
      }
  }
  /// <summary>
  /// Run inference for a ready-made prompt and stream the executor output through
  /// <see cref="OutputTransform"/>.
  /// </summary>
  /// <param name="prompt">The prompt text handed to the executor.</param>
  /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
  /// <param name="cancellationToken">Token passed to the executor and to the enumeration of the transformed stream.</param>
  /// <returns>An asynchronous stream of transformed text pieces.</returns>
  private async IAsyncEnumerable<string> ChatAsyncInternal(
      string prompt,
      IInferenceParams? inferenceParams = null,
      [EnumeratorCancellation] CancellationToken cancellationToken = default)
  {
      var rawTokens = Executor.InferAsync(prompt, inferenceParams, cancellationToken);
      var transformedTokens = OutputTransform.TransformAsync(rawTokens);

      await foreach (string piece in transformedTokens.WithCancellation(cancellationToken))
      {
          yield return piece;
      }
  }
  545. }
  546. /// <summary>
  547. /// The state of a chat session in-memory.
  548. /// </summary>
  549. public record SessionState
  550. {
  /// <summary>
  /// The saved executor state for the session. May be null when it could not be
  /// loaded from storage.
  /// </summary>
  public ExecutorBaseState? ExecutorState { get; set; }

  /// <summary>
  /// The saved context state (KV cache) for the session, or null when none was captured.
  /// </summary>
  public State? ContextState { get; set; }

  /// <summary>
  /// The input transform pipeline captured with the session. Empty by default.
  /// </summary>
  public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();

  /// <summary>
  /// The output transform captured with the session.
  /// </summary>
  public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();

  /// <summary>
  /// The history transform captured with the session.
  /// </summary>
  public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

  /// <summary>
  /// The chat history messages captured with the session. Empty by default.
  /// </summary>
  public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();
  /// <summary>
  /// Create a session state snapshot. The history messages are copied into a new array
  /// and every transform is cloned, so the snapshot holds its own transform instances.
  /// </summary>
  /// <param name="contextState">The context state (KV cache), or null when there is none.</param>
  /// <param name="executorState">The executor state.</param>
  /// <param name="history">The chat history whose messages are captured.</param>
  /// <param name="inputTransformPipeline">The input transforms to clone, in order.</param>
  /// <param name="outputTransform">The output transform to clone.</param>
  /// <param name="historyTransform">The history transform to clone.</param>
  public SessionState(
      State? contextState, ExecutorBaseState executorState,
      ChatHistory history, List<ITextTransform> inputTransformPipeline,
      ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
  {
      (ContextState, ExecutorState) = (contextState, executorState);

      History = history.Messages.ToArray();

      // Clone each input transform into a fresh array, preserving pipeline order.
      var clonedPipeline = new ITextTransform[inputTransformPipeline.Count];
      for (int i = 0; i < clonedPipeline.Length; i++)
      {
          clonedPipeline[i] = inputTransformPipeline[i].Clone();
      }
      InputTransformPipeline = clonedPipeline;

      OutputTransform = outputTransform.Clone();
      HistoryTransform = historyTransform.Clone();
  }
  /// <summary>
  /// Write the session state to a directory, one file per component.
  /// </summary>
  /// <remarks>
  /// If the directory already exists it is deleted recursively and recreated, so any
  /// unrelated content in it is lost. The model state file is only written when a
  /// context state is present.
  /// </remarks>
  /// <param name="path">The directory to write to.</param>
  /// <exception cref="ArgumentException">Thrown when <paramref name="path"/> is null or whitespace.</exception>
  public void Save(string path)
  {
      if (string.IsNullOrWhiteSpace(path))
      {
          throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
      }

      // Start from an empty directory so stale files from an earlier save cannot linger.
      if (Directory.Exists(path))
      {
          Directory.Delete(path, recursive: true);
      }
      Directory.CreateDirectory(path);

      // Resolves a session file name against the target directory.
      string InSessionDirectory(string fileName) => Path.Combine(path, fileName);

      if (ContextState?.ToByteArray() is { } contextBytes)
      {
          File.WriteAllBytes(InSessionDirectory(ChatSession.MODEL_STATE_FILENAME), contextBytes);
      }

      File.WriteAllText(
          InSessionDirectory(ChatSession.EXECUTOR_STATE_FILENAME),
          JsonSerializer.Serialize(ExecutorState));

      File.WriteAllText(
          InSessionDirectory(ChatSession.HISTORY_STATE_FILENAME),
          new ChatHistory(History).ToJson());

      File.WriteAllText(
          InSessionDirectory(ChatSession.INPUT_TRANSFORM_FILENAME),
          JsonSerializer.Serialize(InputTransformPipeline));

      File.WriteAllText(
          InSessionDirectory(ChatSession.OUTPUT_TRANSFORM_FILENAME),
          JsonSerializer.Serialize(OutputTransform));

      File.WriteAllText(
          InSessionDirectory(ChatSession.HISTORY_TRANSFORM_FILENAME),
          JsonSerializer.Serialize(HistoryTransform));
  }
  /// <summary>
  /// Load a session state from a directory written by <see cref="Save"/>.
  /// </summary>
  /// <remarks>
  /// The model state file and the three transform files are optional: a missing model
  /// state yields a null context state, and a missing transform file yields the default
  /// (empty pipeline, <c>EmptyTextOutputStreamTransform</c>, <c>DefaultHistoryTransform</c>).
  /// The executor state and history files are read unconditionally.
  /// </remarks>
  /// <param name="path">The directory to read from.</param>
  /// <returns>The loaded session state.</returns>
  /// <exception cref="ArgumentException">
  /// Thrown when <paramref name="path"/> is null or whitespace, when the directory does not
  /// exist, or when the history or a transform file is invalid. For malformed transform
  /// JSON the originating <see cref="JsonException"/> is available as the inner exception.
  /// </exception>
  public static SessionState Load(string path)
  {
      if (string.IsNullOrWhiteSpace(path))
      {
          throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
      }

      if (!Directory.Exists(path))
      {
          throw new ArgumentException("Directory does not exist", nameof(path));
      }

      string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
      var contextState = File.Exists(modelStateFilePath) ?
          State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
          : null;

      // The result may be null; ChatSession.LoadSession(string, bool) handles that case
      // by letting the executor read the state file itself.
      string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
      var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));

      string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
      string historyJson = File.ReadAllText(historyFilepath);
      var history = ChatHistory.FromJson(historyJson)
          ?? throw new ArgumentException("History file is invalid", nameof(path));

      ITextTransform[] inputTransforms = LoadJsonOrDefault<ITextTransform[]>(
          path,
          ChatSession.INPUT_TRANSFORM_FILENAME,
          "Input transform file is invalid",
          () => Array.Empty<ITextTransform>());

      ITextStreamTransform outputTransform = LoadJsonOrDefault<ITextStreamTransform>(
          path,
          ChatSession.OUTPUT_TRANSFORM_FILENAME,
          "Output transform file is invalid",
          () => new LLamaTransforms.EmptyTextOutputStreamTransform());

      IHistoryTransform historyTransform = LoadJsonOrDefault<IHistoryTransform>(
          path,
          ChatSession.HISTORY_TRANSFORM_FILENAME,
          "History transform file is invalid",
          () => new LLamaTransforms.DefaultHistoryTransform());

      return new SessionState(
          contextState,
          executorState,
          history,
          inputTransforms.ToList(),
          outputTransform,
          historyTransform);
  }

  /// <summary>
  /// Deserialize an optional JSON file from the session directory, falling back to a
  /// default value when the file does not exist.
  /// </summary>
  /// <typeparam name="T">The type to deserialize.</typeparam>
  /// <param name="path">The session directory.</param>
  /// <param name="fileName">The file name inside the session directory.</param>
  /// <param name="invalidMessage">The message used when the file content is invalid.</param>
  /// <param name="createDefault">Factory for the value returned when the file is missing.</param>
  /// <returns>The deserialized value, or the default when the file is missing.</returns>
  /// <exception cref="ArgumentException">Thrown when the file deserializes to null or contains malformed JSON.</exception>
  private static T LoadJsonOrDefault<T>(
      string path, string fileName, string invalidMessage, Func<T> createDefault)
      where T : class
  {
      string filePath = Path.Combine(path, fileName);
      if (!File.Exists(filePath))
      {
          return createDefault();
      }

      try
      {
          return JsonSerializer.Deserialize<T>(File.ReadAllText(filePath))
              ?? throw new ArgumentException(invalidMessage, nameof(path));
      }
      catch (JsonException ex)
      {
          // Keep the parser error as the inner exception so the cause is not lost.
          throw new ArgumentException(invalidMessage, nameof(path), ex);
      }
  }
  701. }