You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ChatSession.cs 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Runtime.CompilerServices;
  8. using System.Text;
  9. using System.Threading;
  10. using System.Threading.Tasks;
  11. using static LLama.InteractiveExecutor;
  12. namespace LLama
  13. {
/// <summary>
/// The main chat session class. Wraps an <see cref="ILLamaExecutor"/> together with
/// a <see cref="ChatHistory"/> and input/output text transforms, and supports
/// saving/loading the whole session state to a directory.
/// </summary>
public class ChatSession
{
    private readonly ILLamaExecutor _executor;
    private readonly ChatHistory _history;
    // File names used inside the session directory by SaveSession/LoadSession.
    private const string _executorStateFilename = "ExecutorState.json";
    private const string _modelStateFilename = "ModelState.st";
    /// <summary>
    /// The executor for this session.
    /// </summary>
    public ILLamaExecutor Executor => _executor;
    /// <summary>
    /// The chat history for this session.
    /// </summary>
    public ChatHistory History => _history;
    /// <summary>
    /// The history transform used in this session (converts a ChatHistory into model input text).
    /// </summary>
    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
    /// <summary>
    /// The input transform pipeline used in this session. Transforms are applied
    /// to the user prompt in list order before inference.
    /// </summary>
    public List<ITextTransform> InputTransformPipeline { get; set; } = new();
    /// <summary>
    /// The output transform used in this session.
    /// </summary>
    // NOTE(review): public mutable field rather than a property — kept as-is for
    // compatibility with existing callers; consider a property in a breaking release.
    public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
    /// <summary>
    /// Create a new chat session with an empty history.
    /// </summary>
    /// <param name="executor">The executor for this session</param>
    public ChatSession(ILLamaExecutor executor)
    {
        _executor = executor;
        _history = new ChatHistory();
    }
  52. /// <summary>
  53. /// Use a custom history transform.
  54. /// </summary>
  55. /// <param name="transform"></param>
  56. /// <returns></returns>
  57. public ChatSession WithHistoryTransform(IHistoryTransform transform)
  58. {
  59. HistoryTransform = transform;
  60. return this;
  61. }
  62. /// <summary>
  63. /// Add a text transform to the input transform pipeline.
  64. /// </summary>
  65. /// <param name="transform"></param>
  66. /// <returns></returns>
  67. public ChatSession AddInputTransform(ITextTransform transform)
  68. {
  69. InputTransformPipeline.Add(transform);
  70. return this;
  71. }
  72. /// <summary>
  73. /// Use a custom output transform.
  74. /// </summary>
  75. /// <param name="transform"></param>
  76. /// <returns></returns>
  77. public ChatSession WithOutputTransform(ITextStreamTransform transform)
  78. {
  79. OutputTransform = transform;
  80. return this;
  81. }
  82. /// <summary>
  83. ///
  84. /// </summary>
  85. /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
  86. public virtual void SaveSession(string path)
  87. {
  88. if (!Directory.Exists(path))
  89. {
  90. Directory.CreateDirectory(path);
  91. }
  92. _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
  93. if (Executor is StatelessExecutor)
  94. {
  95. }
  96. else if (Executor is StatefulExecutorBase statefulExecutor)
  97. {
  98. statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
  99. }
  100. else
  101. {
  102. throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
  103. }
  104. }
  105. /// <summary>
  106. ///
  107. /// </summary>
  108. /// <param name="path">The directory name to load the session.</param>
  109. public virtual void LoadSession(string path)
  110. {
  111. if (!Directory.Exists(path))
  112. {
  113. throw new FileNotFoundException($"Directory {path} does not exist.");
  114. }
  115. _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
  116. if (Executor is StatelessExecutor)
  117. {
  118. }
  119. else if (Executor is StatefulExecutorBase statefulExecutor)
  120. {
  121. statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
  122. }
  123. else
  124. {
  125. throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
  126. }
  127. }
/// <summary>
/// Generates a response for a given user prompt and manages history state for the user.
/// This will always pass the whole history to the model. Don't pass a whole history
/// to this method as the user prompt will be appended to the history of the current session.
/// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
/// </summary>
/// <param name="prompt">The user's input text; input transforms are applied to it first.</param>
/// <param name="inferenceParams">Inference parameters; any AntiPrompts are stripped from the recorded assistant message.</param>
/// <param name="cancellationToken">Token used to cancel the streaming inference.</param>
/// <returns>Returns generated text of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Apply the input transform pipeline in list order before anything else.
    foreach (var inputTransform in InputTransformPipeline)
        prompt = inputTransform.Transform(prompt);
    // TODO: need to be refactored.
    if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun)
    {
        // First (prompt) run of an interactive executor: record the prompt as a
        // System message, then send the full transformed history text instead of
        // just the prompt.
        History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt));
        var converted_prompt = HistoryTransform.HistoryToText(History);
        // Avoid missing anti-prompt.
        // NOTE(review): the "\r\n" check is redundant (a string ending in "\r\n"
        // also ends in "\n") — kept byte-identical here; behavior is unchanged.
        if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n"))
        {
            prompt = converted_prompt.Trim();
        }
        else
        {
            prompt = converted_prompt;
        }
    }
    else
    {
        // Subsequent turns: record as a User message; only the new prompt text
        // is passed on to ChatAsyncInternal below.
        History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
    }
    // Stream tokens to the caller while accumulating the complete response.
    StringBuilder sb = new();
    await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
    {
        yield return result;
        sb.Append(result);
    }
    string assistantMessage = sb.ToString();
    // Remove end tokens from the assistant message
    // if defined in inferenceParams.AntiPrompts.
    // We only want the response that was generated and not tokens
    // that are delimiting the beginning or end of the response.
    if (inferenceParams?.AntiPrompts != null)
    {
        foreach (var stopToken in inferenceParams.AntiPrompts)
        {
            assistantMessage = assistantMessage.Replace(stopToken, "");
        }
    }
    History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
}
  181. /// <summary>
  182. /// Generates a response for a given chat history. This method does not manage history state for the user.
  183. /// If you want to e.g. truncate the history of a session to fit into the model's context window,
  184. /// use this method and pass the truncated history to it. If you don't need this control, use the other
  185. /// overload of this method that accepts a user prompt instead.
  186. /// </summary>
  187. /// <param name="history"></param>
  188. /// <param name="inferenceParams"></param>
  189. /// <param name="cancellationToken"></param>
  190. /// <returns>Returns generated text of the assistant message.</returns>
  191. public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
  192. {
  193. if (history.Messages.Count == 0)
  194. {
  195. throw new ArgumentException("History must contain at least one message.");
  196. }
  197. string prompt;
  198. if (_executor is InteractiveExecutor executor)
  199. {
  200. InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
  201. prompt = state.IsPromptRun
  202. ? HistoryTransform.HistoryToText(History)
  203. : history.Messages.Last().Content;
  204. }
  205. else
  206. {
  207. prompt = history.Messages.Last().Content;
  208. }
  209. await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
  210. {
  211. yield return result;
  212. }
  213. }
  214. private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
  215. {
  216. var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
  217. await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
  218. {
  219. yield return item;
  220. }
  221. }
  222. }
  223. }