
ChatSession.cs 9.2 kB

using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using static LLama.InteractiveExecutor;

namespace LLama
{
    /// <summary>
    /// The main chat session class.
    /// </summary>
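    /// <example>
    /// A minimal usage sketch. It assumes an <see cref="ILLamaExecutor"/> named <c>executor</c> has already
    /// been constructed (for example an <see cref="InteractiveExecutor"/> over a loaded model), and the
    /// "User:" anti-prompt is only an illustration that depends on your prompt format.
    /// <code>
    /// var session = new ChatSession(executor)
    ///     .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform());
    ///
    /// var inferenceParams = new InferenceParams { AntiPrompts = new[] { "User:" } };
    /// await foreach (var token in session.ChatAsync("Hello, who are you?", inferenceParams))
    /// {
    ///     Console.Write(token);
    /// }
    /// </code>
    /// </example>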
    public class ChatSession
    {
        private readonly ILLamaExecutor _executor;
        private readonly ChatHistory _history;

        private const string _executorStateFilename = "ExecutorState.json";
        private const string _modelStateFilename = "ModelState.st";

        /// <summary>
        /// The executor for this session.
        /// </summary>
        public ILLamaExecutor Executor => _executor;

        /// <summary>
        /// The chat history for this session.
        /// </summary>
        public ChatHistory History => _history;

        /// <summary>
        /// The history transform used in this session.
        /// </summary>
        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

        /// <summary>
        /// The input transform pipeline used in this session.
        /// </summary>
        public List<ITextTransform> InputTransformPipeline { get; set; } = new();

        /// <summary>
        /// The output transform used in this session.
        /// </summary>
        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

        /// <summary>
        /// Create a new chat session.
        /// </summary>
        /// <param name="executor">The executor for this session</param>
        public ChatSession(ILLamaExecutor executor)
        {
            _executor = executor;
            _history = new ChatHistory();
        }

        /// <summary>
        /// Use a custom history transform.
        /// </summary>
        /// <param name="transform">The history transform to use.</param>
        /// <returns>The current <see cref="ChatSession"/>, for fluent chaining.</returns>
        public ChatSession WithHistoryTransform(IHistoryTransform transform)
        {
            HistoryTransform = transform;
            return this;
        }

        /// <summary>
        /// Add a text transform to the input transform pipeline.
        /// </summary>
        /// <param name="transform">The input transform to append to the pipeline.</param>
        /// <returns>The current <see cref="ChatSession"/>, for fluent chaining.</returns>
        public ChatSession AddInputTransform(ITextTransform transform)
        {
            InputTransformPipeline.Add(transform);
            return this;
        }

        /// <summary>
        /// Use a custom output transform.
        /// </summary>
        /// <param name="transform">The output stream transform to use.</param>
        /// <returns>The current <see cref="ChatSession"/>, for fluent chaining.</returns>
        public ChatSession WithOutputTransform(ITextStreamTransform transform)
        {
            OutputTransform = transform;
            return this;
        }

        /// <summary>
        /// Save the session state to the given directory.
        /// </summary>
        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
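        /// <example>
        /// A short sketch; the directory name is an arbitrary placeholder:
        /// <code>
        /// // Persists the model/context state and, for stateful executors, the executor state.
        /// session.SaveSession("./my-chat-session");
        /// </code>
        /// </example>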
        public virtual void SaveSession(string path)
        {
            if (!Directory.Exists(path))
            {
                Directory.CreateDirectory(path);
            }
            _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so there is nothing extra to save.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new NotImplementedException("You're using a customized executor. Please inherit ChatSession and override this method.");
            }
        }

        /// <summary>
        /// Load the session state from the given directory.
        /// </summary>
        /// <param name="path">The directory name to load the session from.</param>
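        /// <example>
        /// A short sketch; it assumes the directory was previously written by <see cref="SaveSession"/>
        /// and that the session wraps the same model the state was saved from:
        /// <code>
        /// var session = new ChatSession(executor);
        /// session.LoadSession("./my-chat-session");
        /// </code>
        /// </example>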
        public virtual void LoadSession(string path)
        {
            if (!Directory.Exists(path))
            {
                throw new DirectoryNotFoundException($"Directory {path} does not exist.");
            }
            _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so there is nothing extra to load.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new NotImplementedException("You're using a customized executor. Please inherit ChatSession and override this method.");
            }
        }

        /// <summary>
        /// Generates a response for a given user prompt and manages the history state for the user.
        /// The whole history is passed to the model, so pass only the new user prompt here; it will be
        /// appended to the history of the current session. If more control is needed, use the other
        /// overload of this method that accepts a ChatHistory object.
        /// </summary>
        /// <param name="prompt">The new user prompt.</param>
        /// <param name="inferenceParams">The inference parameters to use, or <c>null</c> for the defaults.</param>
        /// <param name="cancellationToken">A token that cancels the generation.</param>
        /// <returns>Returns the generated text of the assistant message.</returns>
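        /// <example>
        /// A minimal streaming sketch with a cancellation timeout; the prompt text is only an illustration:
        /// <code>
        /// using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
        /// await foreach (var token in session.ChatAsync("Tell me about llamas.", cancellationToken: cts.Token))
        /// {
        ///     Console.Write(token);
        /// }
        /// </code>
        /// </example>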
        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
                prompt = inputTransform.Transform(prompt);

            History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));

            if (_executor is InteractiveExecutor executor)
            {
                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
                prompt = state.IsPromptRun
                    ? HistoryTransform.HistoryToText(History)
                    : prompt;
            }

            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }

            string assistantMessage = sb.ToString();

            // Remove end tokens from the assistant message
            // if defined in inferenceParams.AntiPrompts.
            // We only want the response that was generated and not tokens
            // that are delimiting the beginning or end of the response.
            if (inferenceParams?.AntiPrompts != null)
            {
                foreach (var stopToken in inferenceParams.AntiPrompts)
                {
                    assistantMessage = assistantMessage.Replace(stopToken, "");
                }
            }

            History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
        }

        /// <summary>
        /// Generates a response for a given chat history. This method does not manage the history state for the user.
        /// If you want to, for example, truncate the history of a session to fit into the model's context window,
        /// use this method and pass the truncated history to it. If you don't need this control, use the other
        /// overload of this method that accepts a user prompt instead.
        /// </summary>
        /// <param name="history">The chat history to generate a response for.</param>
        /// <param name="inferenceParams">The inference parameters to use, or <c>null</c> for the defaults.</param>
        /// <param name="cancellationToken">A token that cancels the generation.</param>
        /// <returns>Returns the generated text of the assistant message.</returns>
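        /// <example>
        /// A sketch of driving the session with an explicitly managed history; the truncation step is only
        /// indicated by a comment because the strategy is application specific:
        /// <code>
        /// var history = new ChatHistory();
        /// history.Messages.Add(new ChatHistory.Message(AuthorRole.System, "You are a helpful assistant."));
        /// history.Messages.Add(new ChatHistory.Message(AuthorRole.User, "Summarize our conversation so far."));
        /// // ... truncate `history` here if it would not fit the context window ...
        /// await foreach (var token in session.ChatAsync(history))
        /// {
        ///     Console.Write(token);
        /// }
        /// </code>
        /// </example>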
        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            if (history.Messages.Count == 0)
            {
                throw new ArgumentException("History must contain at least one message.");
            }

            string prompt;
            if (_executor is InteractiveExecutor executor)
            {
                InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
                prompt = state.IsPromptRun
                    ? HistoryTransform.HistoryToText(History)
                    : history.Messages.Last().Content;
            }
            else
            {
                prompt = history.Messages.Last().Content;
            }

            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
            }
        }

        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
            await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
            {
                yield return item;
            }
        }
    }
}