You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ChatSession.cs 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Runtime.CompilerServices;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. using LLama.Abstractions;
  9. using LLama.Common;
  10. using static LLama.InteractiveExecutor;
  11. namespace LLama;
  /// <summary>
  /// The main chat session class. It couples an executor with a <see cref="ChatHistory"/>,
  /// enforces the ordering of system / user / assistant messages, streams model replies,
  /// and can save itself to, or load itself from, a directory on disk.
  /// </summary>
  public class ChatSession
  {
      // File names used inside the session directory by SaveSession / LoadSession.
      private const string _modelStateFilename = "ModelState.st";
      private const string _executorStateFilename = "ExecutorState.json";
      private const string _historyFilename = "ChatHistory.json";

      /// <summary>
      /// The executor for this session.
      /// </summary>
      public ILLamaExecutor Executor { get; private set; }

      /// <summary>
      /// The chat history for this session.
      /// </summary>
      public ChatHistory History { get; private set; } = new();

      /// <summary>
      /// The history transform used in this session to convert between
      /// <see cref="ChatHistory"/> and prompt text.
      /// </summary>
      public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

      /// <summary>
      /// The input transform pipeline used in this session. Transforms are applied
      /// to user message content in list order.
      /// </summary>
      public List<ITextTransform> InputTransformPipeline { get; set; } = new();

      /// <summary>
      /// The output transform applied to the token stream returned by the executor.
      /// </summary>
      public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

      /// <summary>
      /// Create a new chat session.
      /// </summary>
      /// <param name="executor">The executor for this session. Must derive from <c>StatefulExecutorBase</c>.</param>
      /// <exception cref="ArgumentException">
      /// <paramref name="executor"/> is null or is not a <c>StatefulExecutorBase</c>.
      /// </exception>
      public ChatSession(ILLamaExecutor executor)
      {
          // SaveSession, LoadSession and ChatAsync all cast the executor to StatefulExecutorBase,
          // so anything else is rejected up front. The pattern check also rejects null.
          if (executor is not StatefulExecutorBase)
          {
              throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
          }

          Executor = executor;
      }

      /// <summary>
      /// Create a new chat session with a custom history.
      /// </summary>
      /// <param name="executor">The executor for this session. Must derive from <c>StatefulExecutorBase</c>.</param>
      /// <param name="history">The history object the session starts with (used as-is, not copied).</param>
      /// <exception cref="ArgumentException">
      /// <paramref name="executor"/> is null or is not a <c>StatefulExecutorBase</c>.
      /// </exception>
      public ChatSession(ILLamaExecutor executor, ChatHistory history)
          : this(executor)
      {
          History = history;
      }

      /// <summary>
      /// Use a custom history transform.
      /// </summary>
      /// <param name="transform">The transform that replaces <see cref="HistoryTransform"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession WithHistoryTransform(IHistoryTransform transform)
      {
          HistoryTransform = transform;
          return this;
      }

      /// <summary>
      /// Add a text transform to the end of the input transform pipeline.
      /// </summary>
      /// <param name="transform">The transform to append to <see cref="InputTransformPipeline"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession AddInputTransform(ITextTransform transform)
      {
          InputTransformPipeline.Add(transform);
          return this;
      }

      /// <summary>
      /// Use a custom output transform.
      /// </summary>
      /// <param name="transform">The transform that replaces <see cref="OutputTransform"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession WithOutputTransform(ITextStreamTransform transform)
      {
          OutputTransform = transform;
          return this;
      }

      /// <summary>
      /// Save the session to a directory.
      /// </summary>
      /// <remarks>
      /// If <paramref name="path"/> already exists it is deleted recursively before the
      /// model state, executor state and chat history files are written into a fresh directory.
      /// </remarks>
      /// <param name="path">The directory to write the session files into.</param>
      /// <exception cref="ArgumentException"><paramref name="path"/> is null, empty or whitespace.</exception>
      public void SaveSession(string path)
      {
          if (string.IsNullOrWhiteSpace(path))
          {
              throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
          }

          // Start from an empty directory so no stale files from an earlier save survive.
          if (Directory.Exists(path))
          {
              Directory.Delete(path, recursive: true);
          }

          Directory.CreateDirectory(path);

          string modelStateFilePath = Path.Combine(path, _modelStateFilename);
          Executor.Context.SaveState(modelStateFilePath);

          string executorStateFilepath = Path.Combine(path, _executorStateFilename);
          ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);

          string historyFilepath = Path.Combine(path, _historyFilename);
          File.WriteAllText(historyFilepath, History.ToJson());
      }

      /// <summary>
      /// Load a session from a directory previously written by <see cref="SaveSession"/>.
      /// </summary>
      /// <remarks>
      /// Restores the model state and executor state, then replaces <see cref="History"/>
      /// with the history read from disk.
      /// </remarks>
      /// <param name="path">The directory containing the session files.</param>
      /// <exception cref="ArgumentException">
      /// <paramref name="path"/> is null, empty or whitespace; the directory does not exist;
      /// or the history file could not be converted into a <see cref="ChatHistory"/>.
      /// </exception>
      public void LoadSession(string path)
      {
          if (string.IsNullOrWhiteSpace(path))
          {
              throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
          }

          if (!Directory.Exists(path))
          {
              throw new ArgumentException("Directory does not exist", nameof(path));
          }

          string modelStateFilePath = Path.Combine(path, _modelStateFilename);
          Executor.Context.LoadState(modelStateFilePath);

          string executorStateFilepath = Path.Combine(path, _executorStateFilename);
          ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);

          string historyFilepath = Path.Combine(path, _historyFilename);
          string historyJson = File.ReadAllText(historyFilepath);
          History = ChatHistory.FromJson(historyJson)
              ?? throw new ArgumentException("History file is invalid", nameof(path));
      }

      /// <summary>
      /// Add a message to the chat history after validating its position.
      /// </summary>
      /// <remarks>
      /// Ordering rules enforced here:
      /// a system message is only allowed while the history is empty;
      /// a user message may not directly follow another user message;
      /// an assistant message must directly follow a user message.
      /// Only the role and content of <paramref name="message"/> are stored.
      /// </remarks>
      /// <param name="message">The message to append.</param>
      /// <returns>This session, to allow call chaining.</returns>
      /// <exception cref="ArgumentException">The message violates one of the ordering rules.</exception>
      public ChatSession AddMessage(ChatHistory.Message message)
      {
          // A system message is only valid as the very first message.
          if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
          {
              throw new ArgumentException("Cannot add a system message after another message", nameof(message));
          }

          // A user message may start the history or follow a system/assistant message,
          // but never another user message.
          if (message.AuthorRole == AuthorRole.User)
          {
              ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
              if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
              {
                  throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
              }
          }

          // An assistant message is always a reply, so a user message must come right before it.
          if (message.AuthorRole == AuthorRole.Assistant)
          {
              ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
              if (lastMessage is null
                  || lastMessage.AuthorRole != AuthorRole.User)
              {
                  throw new ArgumentException("Assistant message must be preceeded with a user message", nameof(message));
              }
          }

          History.AddMessage(message.AuthorRole, message.Content);
          return this;
      }

      /// <summary>
      /// Add a system message to the chat history.
      /// </summary>
      /// <param name="content">The text of the system message.</param>
      /// <returns>This session, to allow call chaining.</returns>
      /// <exception cref="ArgumentException">The history is not empty.</exception>
      public ChatSession AddSystemMessage(string content)
          => AddMessage(new ChatHistory.Message(AuthorRole.System, content));

      /// <summary>
      /// Add an assistant message to the chat history.
      /// </summary>
      /// <param name="content">The text of the assistant message.</param>
      /// <returns>This session, to allow call chaining.</returns>
      /// <exception cref="ArgumentException">The last message in the history is not a user message.</exception>
      public ChatSession AddAssistantMessage(string content)
          => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

      /// <summary>
      /// Add a user message to the chat history.
      /// </summary>
      /// <param name="content">The text of the user message.</param>
      /// <returns>This session, to allow call chaining.</returns>
      /// <exception cref="ArgumentException">The last message in the history is already a user message.</exception>
      public ChatSession AddUserMessage(string content)
          => AddMessage(new ChatHistory.Message(AuthorRole.User, content));

      /// <summary>
      /// Remove the last message from the chat history.
      /// </summary>
      /// <remarks>
      /// No emptiness check is performed: on an empty history the removal is attempted at
      /// index -1 and the underlying collection is expected to reject it.
      /// </remarks>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession RemoveLastMessage()
      {
          History.Messages.RemoveAt(History.Messages.Count - 1);
          return this;
      }

      /// <summary>
      /// Replace a user message with a new message and remove all messages after the new message.
      /// This is useful when the user wants to edit a message and regenerate the response.
      /// </summary>
      /// <param name="oldMessage">The user message currently in the history (located via <c>IndexOf</c>).</param>
      /// <param name="newMessage">The user message that takes its place.</param>
      /// <returns>This session, to allow call chaining.</returns>
      /// <exception cref="ArgumentException">
      /// Either message is not a user message, or <paramref name="oldMessage"/> is not in the history.
      /// </exception>
      public ChatSession ReplaceUserMessage(
          ChatHistory.Message oldMessage,
          ChatHistory.Message newMessage)
      {
          if (oldMessage.AuthorRole != AuthorRole.User)
          {
              throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
          }

          if (newMessage.AuthorRole != AuthorRole.User)
          {
              throw new ArgumentException("New message must be a user message", nameof(newMessage));
          }

          int index = History.Messages.IndexOf(oldMessage);
          if (index == -1)
          {
              throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
          }

          History.Messages[index] = newMessage;

          // Everything after the edited message is now stale, so drop it.
          History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);

          return this;
      }

      /// <summary>
      /// Chat with the model: add a user message, stream the reply, then record the reply
      /// as an assistant message.
      /// </summary>
      /// <remarks>
      /// The caller's <paramref name="message"/> object is not modified; input transforms are
      /// applied to a copy of its content. The user message is added to <see cref="History"/>
      /// before inference starts and the assistant message only after the stream has been
      /// fully consumed, so an abandoned or failed enumeration leaves the user message in the
      /// history without a reply.
      /// </remarks>
      /// <param name="message">The user message to send.</param>
      /// <param name="applyInputTransformPipeline">
      /// Whether to run <see cref="InputTransformPipeline"/> over the message content first.
      /// </param>
      /// <param name="inferenceParams">Optional inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
      /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
      /// <exception cref="ArgumentException">
      /// <paramref name="message"/> is not a user message, or it cannot follow the current last message.
      /// </exception>
      public async IAsyncEnumerable<string> ChatAsync(
          ChatHistory.Message message,
          bool applyInputTransformPipeline,
          IInferenceParams? inferenceParams = null,
          [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // The message must be a user message.
          if (message.AuthorRole != AuthorRole.User)
          {
              throw new ArgumentException("Message must be a user message", nameof(message));
          }

          // Transform a local copy so the argument passed by the caller stays untouched.
          string userContent = message.Content;
          if (applyInputTransformPipeline)
          {
              foreach (ITextTransform inputTransform in InputTransformPipeline)
              {
                  userContent = inputTransform.Transform(userContent);
              }
          }

          // Add the user's message to the history.
          AddUserMessage(userContent);

          // NOTE(review): the state is cast to InteractiveExecutorState, so an executor whose
          // state data is of another type would fail here - confirm which executors are supported.
          InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

          string prompt;
          if (state.IsPromptRun)
          {
              // The executor is still on its prompt run: send the complete history (system
              // message and any manually added messages included), rendered with the
              // template implemented by HistoryTransform.
              prompt = HistoryTransform.HistoryToText(History);
          }
          else
          {
              // The executor is past its prompt run: render only the current message
              // with the same template.
              ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, userContent);
              prompt = HistoryTransform.HistoryToText(singleMessageHistory);
          }

          // Collect the streamed pieces without re-copying the whole reply on every token.
          var assistantMessage = new System.Text.StringBuilder();

          await foreach (
              string textToken
              in ChatAsyncInternal(
                  prompt,
                  inferenceParams,
                  cancellationToken))
          {
              assistantMessage.Append(textToken);
              yield return textToken;
          }

          // Add the assistant message to the history.
          AddAssistantMessage(assistantMessage.ToString());
      }

      /// <summary>
      /// Chat with the model, applying the input transform pipeline to the message.
      /// </summary>
      /// <param name="message">The user message to send.</param>
      /// <param name="inferenceParams">Optional inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
      /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
      public IAsyncEnumerable<string> ChatAsync(
          ChatHistory.Message message,
          IInferenceParams? inferenceParams = null,
          CancellationToken cancellationToken = default)
      {
          return ChatAsync(
              message,
              applyInputTransformPipeline: true,
              inferenceParams,
              cancellationToken);
      }

      /// <summary>
      /// Chat with the model using a whole history: every message except the last is added to
      /// the session history, and the last one is sent as the user message.
      /// </summary>
      /// <remarks>
      /// The messages in <paramref name="history"/> are not modified; transformed copies are
      /// added to the session. The leading messages are added immediately when this method is
      /// called, while the final message is processed when the returned stream is enumerated.
      /// </remarks>
      /// <param name="history">The messages to feed into the session; must not be empty.</param>
      /// <param name="applyInputTransformPipeline">
      /// Whether to run <see cref="InputTransformPipeline"/> over the content of user messages.
      /// </param>
      /// <param name="inferenceParams">Optional inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
      /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
      /// <exception cref="ArgumentException">
      /// <paramref name="history"/> is empty, or one of its messages violates the ordering rules
      /// of <see cref="AddMessage"/>.
      /// </exception>
      public IAsyncEnumerable<string> ChatAsync(
          ChatHistory history,
          bool applyInputTransformPipeline,
          IInferenceParams? inferenceParams = null,
          CancellationToken cancellationToken = default)
      {
          ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
              ?? throw new ArgumentException("History must contain at least one message", nameof(history));

          foreach (
              ChatHistory.Message message
              in history.Messages.Take(history.Messages.Count - 1))
          {
              // Transform a copy of the content so that passing the same history object
              // again does not apply the pipeline on top of already-transformed text.
              string content = message.Content;
              if (applyInputTransformPipeline
                  && message.AuthorRole == AuthorRole.User)
              {
                  foreach (ITextTransform inputTransform in InputTransformPipeline)
                  {
                      content = inputTransform.Transform(content);
                  }
              }

              AddMessage(new ChatHistory.Message(message.AuthorRole, content));
          }

          return ChatAsync(
              lastMessage,
              applyInputTransformPipeline,
              inferenceParams,
              cancellationToken);
      }

      /// <summary>
      /// Chat with the model using a whole history, applying the input transform pipeline
      /// to user messages.
      /// </summary>
      /// <param name="history">The messages to feed into the session; must not be empty.</param>
      /// <param name="inferenceParams">Optional inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
      /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
      public IAsyncEnumerable<string> ChatAsync(
          ChatHistory history,
          IInferenceParams? inferenceParams = null,
          CancellationToken cancellationToken = default)
      {
          return ChatAsync(
              history,
              applyInputTransformPipeline: true,
              inferenceParams,
              cancellationToken);
      }

      /// <summary>
      /// Regenerate the last assistant message by removing the final user/assistant pair
      /// from the history and sending that user message again.
      /// </summary>
      /// <remarks>
      /// Both preconditions are verified before anything is removed, so a failed check leaves
      /// the history unchanged. The user message is re-sent without the input transform
      /// pipeline because the content stored in the history is reused as-is.
      /// </remarks>
      /// <param name="inferenceParams">Optional inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
      /// <returns>The regenerated reply text, streamed piece by piece.</returns>
      /// <exception cref="InvalidOperationException">
      /// The last message is not an assistant message, or the message before it is not a user message.
      /// </exception>
      public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
          InferenceParams? inferenceParams = null,
          [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var messages = History.Messages;

          // Make sure the last message is an assistant message (response from the LLM).
          ChatHistory.Message? lastAssistantMessage = messages.LastOrDefault();
          if (lastAssistantMessage is null
              || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
          {
              throw new InvalidOperationException("Last message must be an assistant message");
          }

          // Make sure the message right before it is the user message it answered.
          ChatHistory.Message? lastUserMessage = messages.Count >= 2
              ? messages[messages.Count - 2]
              : null;
          if (lastUserMessage is null
              || lastUserMessage.AuthorRole != AuthorRole.User)
          {
              throw new InvalidOperationException("Last message must be a user message");
          }

          // Only mutate the history once both checks have passed:
          // drop the assistant reply, then the user message (ChatAsync adds it back).
          RemoveLastMessage();
          RemoveLastMessage();

          // Regenerate the assistant message.
          await foreach (
              string textToken
              in ChatAsync(
                  lastUserMessage,
                  applyInputTransformPipeline: false,
                  inferenceParams,
                  cancellationToken))
          {
              yield return textToken;
          }
      }

      /// <summary>
      /// Run inference for a ready-made prompt and pass the executor's output stream
      /// through <see cref="OutputTransform"/>.
      /// </summary>
      /// <param name="prompt">The prompt text handed to the executor.</param>
      /// <param name="inferenceParams">Optional inference parameters passed to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
      /// <returns>The transformed output text, streamed piece by piece.</returns>
      private async IAsyncEnumerable<string> ChatAsyncInternal(
          string prompt,
          IInferenceParams? inferenceParams = null,
          [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);

          await foreach (
              string textToken
              in OutputTransform
                  .TransformAsync(results)
                  .WithCancellation(cancellationToken))
          {
              yield return textToken;
          }
      }
  }