
ChatSession.cs

using LLama.Abstractions;
using LLama.Common;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;

namespace LLama
{
    /// <summary>
    /// The main chat session class.
    /// </summary>
    public class ChatSession
    {
        private ILLamaExecutor _executor;
        private ChatHistory _history;
        private static readonly string _executorStateFilename = "ExecutorState.json";
        private static readonly string _modelStateFilename = "ModelState.st";

        /// <summary>
        /// The executor for this session.
        /// </summary>
        public ILLamaExecutor Executor => _executor;

        /// <summary>
        /// The chat history for this session.
        /// </summary>
        public ChatHistory History => _history;

        /// <summary>
        /// The history transform used in this session.
        /// </summary>
        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

        /// <summary>
        /// The input transform pipeline used in this session.
        /// </summary>
        public List<ITextTransform> InputTransformPipeline { get; set; } = new();

        /// <summary>
        /// The output transform used in this session.
        /// </summary>
        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

        /// <summary>
        /// Create a new chat session with the given executor.
        /// </summary>
        /// <param name="executor">The executor for this session</param>
        public ChatSession(ILLamaExecutor executor)
        {
            _executor = executor;
            _history = new ChatHistory();
        }

        /// <summary>
        /// Use a custom history transform.
        /// </summary>
        /// <param name="transform">The history transform to use.</param>
        /// <returns>This session, for chaining.</returns>
        public ChatSession WithHistoryTransform(IHistoryTransform transform)
        {
            HistoryTransform = transform;
            return this;
        }

        /// <summary>
        /// Add a text transform to the input transform pipeline.
        /// </summary>
        /// <param name="transform">The input transform to append.</param>
        /// <returns>This session, for chaining.</returns>
        public ChatSession AddInputTransform(ITextTransform transform)
        {
            InputTransformPipeline.Add(transform);
            return this;
        }

        /// <summary>
        /// Use a custom output transform.
        /// </summary>
        /// <param name="transform">The output stream transform to use.</param>
        /// <returns>This session, for chaining.</returns>
        public ChatSession WithOutputTransform(ITextStreamTransform transform)
        {
            OutputTransform = transform;
            return this;
        }
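
        // The With*/Add* configuration methods above return the session itself, so setup can be
        // chained. A minimal sketch, assuming `executor` is an already-constructed ILLamaExecutor
        // and `MyInputTransform` is a hypothetical ITextTransform implementation of your own:
        //
        //     var session = new ChatSession(executor)
        //         .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform())
        //         .AddInputTransform(new MyInputTransform())
        //         .WithOutputTransform(new LLamaTransforms.EmptyTextOutputStreamTransform());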

        /// <summary>
        /// Save the session state to the specified directory.
        /// </summary>
        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
        public virtual void SaveSession(string path)
        {
            if (!Directory.Exists(path))
            {
                Directory.CreateDirectory(path);
            }
            _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so there is nothing further to save.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and override this method.");
            }
        }

        /// <summary>
        /// Load the session state from the specified directory.
        /// </summary>
        /// <param name="path">The directory name to load the session from.</param>
        public virtual void LoadSession(string path)
        {
            if (!Directory.Exists(path))
            {
                throw new FileNotFoundException($"Directory {path} does not exist.");
            }
            _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so there is nothing further to load.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and override this method.");
            }
        }
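
        // Session persistence pairs SaveSession with LoadSession on the same directory.
        // A minimal sketch, assuming `session` wraps a stateful executor and "./chat-session"
        // is just an illustrative path:
        //
        //     session.SaveSession("./chat-session");
        //     // ... later, after recreating the session with an equivalent executor ...
        //     session.LoadSession("./chat-session");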

        /// <summary>
        /// Get the response from the LLama model with chat histories.
        /// </summary>
        /// <param name="history">The chat history to build the prompt from.</param>
        /// <param name="inferenceParams">The parameters used for inference.</param>
        /// <param name="cancellationToken">A token to cancel the generation.</param>
        /// <returns>The generated response, streamed piece by piece.</returns>
        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model. The prompt can be either the preset prompt text
        /// or the question you want to ask.
        /// </summary>
        /// <param name="prompt">The prompt to send to the model.</param>
        /// <param name="inferenceParams">The parameters used for inference.</param>
        /// <param name="cancellationToken">A token to cancel the generation.</param>
        /// <returns>The generated response, streamed piece by piece.</returns>
        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }
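
        // Chat streams the response as it is generated; consume it with foreach and the session
        // history is updated automatically once enumeration completes. A minimal sketch, assuming
        // `session` is a configured ChatSession and the prompt and parameter values are only
        // illustrative:
        //
        //     foreach (var token in session.Chat("User: Hello, who are you?\nAssistant: ",
        //                                        new InferenceParams { Temperature = 0.6f }))
        //     {
        //         Console.Write(token);
        //     }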

        /// <summary>
        /// Get the response from the LLama model with chat histories asynchronously.
        /// </summary>
        /// <param name="history">The chat history to build the prompt from.</param>
        /// <param name="inferenceParams">The parameters used for inference.</param>
        /// <param name="cancellationToken">A token to cancel the generation.</param>
        /// <returns>The generated response, streamed piece by piece.</returns>
        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model with the given prompt asynchronously.
        /// </summary>
        /// <param name="prompt">The prompt to send to the model.</param>
        /// <param name="inferenceParams">The parameters used for inference.</param>
        /// <param name="cancellationToken">A token to cancel the generation.</param>
        /// <returns>The generated response, streamed piece by piece.</returns>
        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }
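
        // The async overloads stream tokens through IAsyncEnumerable, so they are consumed with
        // await foreach. A minimal sketch, assuming `session` is a configured ChatSession and
        // `cts` is a CancellationTokenSource you control:
        //
        //     await foreach (var token in session.ChatAsync("User: Tell me a joke.\nAssistant: ",
        //                                                   cancellationToken: cts.Token))
        //     {
        //         Console.Write(token);
        //     }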

        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
            return OutputTransform.Transform(results);
        }

        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
            await foreach (var item in OutputTransform.TransformAsync(results))
            {
                yield return item;
            }
        }
    }
}