
ChatSession.cs

using LLama.Abstractions;
using LLama.Common;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace LLama
{
    /// <summary>
    /// The main chat session class.
    /// </summary>
    public class ChatSession
    {
        private readonly ILLamaExecutor _executor;
        private readonly ChatHistory _history;
        private const string _executorStateFilename = "ExecutorState.json";
        private const string _modelStateFilename = "ModelState.st";

        /// <summary>
        /// The executor for this session.
        /// </summary>
        public ILLamaExecutor Executor => _executor;

        /// <summary>
        /// The chat history for this session.
        /// </summary>
        public ChatHistory History => _history;

        /// <summary>
        /// The history transform used in this session.
        /// </summary>
        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

        /// <summary>
        /// The input transform pipeline used in this session.
        /// </summary>
        public List<ITextTransform> InputTransformPipeline { get; set; } = new();

        /// <summary>
        /// The output transform used in this session.
        /// </summary>
        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

        /// <summary>
        /// Create a new chat session.
        /// </summary>
        /// <param name="executor">The executor for this session</param>
        public ChatSession(ILLamaExecutor executor)
        {
            _executor = executor;
            _history = new ChatHistory();
        }

        /// <summary>
        /// Use a custom history transform.
        /// </summary>
        /// <param name="transform"></param>
        /// <returns></returns>
        public ChatSession WithHistoryTransform(IHistoryTransform transform)
        {
            HistoryTransform = transform;
            return this;
        }

        /// <summary>
        /// Add a text transform to the input transform pipeline.
        /// </summary>
        /// <param name="transform"></param>
        /// <returns></returns>
        public ChatSession AddInputTransform(ITextTransform transform)
        {
            InputTransformPipeline.Add(transform);
            return this;
        }

        /// <summary>
        /// Use a custom output transform.
        /// </summary>
        /// <param name="transform"></param>
        /// <returns></returns>
        public ChatSession WithOutputTransform(ITextStreamTransform transform)
        {
            OutputTransform = transform;
            return this;
        }
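
        // Illustrative sketch (not part of the original file): the helpers above are meant to be
        // chained when configuring a session. KeywordTextOutputStreamTransform is assumed to be
        // available in LLamaTransforms for the version in use; any ITextTransform can likewise be
        // added to the input pipeline via AddInputTransform.
        //
        //     var session = new ChatSession(executor)
        //         .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform())
        //         .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new[] { "User:" }));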

        /// <summary>
        /// Save the session to a directory, including the model state and (if any) the executor state.
        /// </summary>
        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
        public virtual void SaveSession(string path)
        {
            if (!Directory.Exists(path))
            {
                Directory.CreateDirectory(path);
            }
            _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so only the model state is saved.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and override this method.");
            }
        }

        /// <summary>
        /// Load a session from a directory previously written by <see cref="SaveSession(string)"/>.
        /// </summary>
        /// <param name="path">The directory name to load the session.</param>
        public virtual void LoadSession(string path)
        {
            if (!Directory.Exists(path))
            {
                throw new DirectoryNotFoundException($"Directory {path} does not exist.");
            }
            _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so only the model state is loaded.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and override this method.");
            }
        }
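
        // Illustrative sketch (not part of the original file): a save/load round trip. The directory
        // path is a placeholder; SaveSession creates it if it does not exist, and LoadSession expects
        // a session built on the same kind of executor.
        //
        //     session.SaveSession("./saved-session");
        //     // ... later ...
        //     session.LoadSession("./saved-session");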

        /// <summary>
        /// Get the response from the LLama model with chat histories.
        /// </summary>
        /// <param name="history"></param>
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns></returns>
        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model. Note that the prompt can be a preset prompt
        /// as well as the question you want to ask.
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns></returns>
        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model with chat histories asynchronously.
        /// </summary>
        /// <param name="history"></param>
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns></returns>
        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model asynchronously. Note that the prompt can be a preset prompt
        /// as well as the question you want to ask.
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns></returns>
        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        // Run inference synchronously and pass the raw output through the output transform.
        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
            return OutputTransform.Transform(results);
        }

        // Run inference asynchronously and pass the raw output through the output transform.
        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
            await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
            {
                yield return item;
            }
        }
    }
}
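
A minimal usage sketch (not part of ChatSession.cs), assuming an ILLamaExecutor such as an InteractiveExecutor has already been constructed from a loaded model. The prompt text, anti-prompt, and save directory are placeholders, and exact executor construction varies between LLamaSharp versions, so treat this as illustrative only.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama;
using LLama.Abstractions;
using LLama.Common;

public static class ChatSessionExample
{
    public static async Task RunAsync(ILLamaExecutor executor)
    {
        // Configure the session with the default history transform.
        var session = new ChatSession(executor)
            .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform());

        var inferenceParams = new InferenceParams
        {
            // "User:" is a placeholder anti-prompt; use whatever matches your prompt template.
            AntiPrompts = new List<string> { "User:" }
        };

        // Stream the response; ChatAsync also records both sides of the exchange in session.History.
        await foreach (var text in session.ChatAsync("User: Hello, who are you?\nAssistant:", inferenceParams))
        {
            Console.Write(text);
        }

        // Persist model and executor state so the conversation can be resumed later.
        session.SaveSession("./chat-session-state");
    }
}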