You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ChatSession.cs 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Runtime.CompilerServices;
  6. using System.Text;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. namespace LLama
  10. {
  /// <summary>
  /// The main chat session class. It pairs an <see cref="ILLamaExecutor"/> with a
  /// <see cref="ChatHistory"/> and the transforms applied to the history, the user input
  /// and the generated output.
  /// </summary>
  public class ChatSession
  {
      private readonly ILLamaExecutor _executor;
      private readonly ChatHistory _history;

      // File names used inside the session directory by SaveSession/LoadSession.
      private const string _executorStateFilename = "ExecutorState.json";
      private const string _modelStateFilename = "ModelState.st";

      /// <summary>
      /// The executor for this session.
      /// </summary>
      public ILLamaExecutor Executor => _executor;

      /// <summary>
      /// The chat history for this session.
      /// </summary>
      public ChatHistory History => _history;

      /// <summary>
      /// The history transform used in this session. It converts a <see cref="ChatHistory"/>
      /// into the prompt text that is handed to the executor.
      /// </summary>
      public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

      /// <summary>
      /// The input transform pipeline used in this session. The transforms are applied, in list order,
      /// to the user prompt passed to <see cref="ChatAsync(string, IInferenceParams, CancellationToken)"/>.
      /// </summary>
      public List<ITextTransform> InputTransformPipeline { get; set; } = new();

      /// <summary>
      /// The output transform used in this session. It is applied to the stream of text
      /// produced by the executor before the text is yielded to the caller.
      /// </summary>
      // Kept as a public field (not a property) so existing callers are unaffected.
      public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

      /// <summary>
      /// Creates a new chat session with an empty history.
      /// </summary>
      /// <param name="executor">The executor for this session</param>
      /// <exception cref="System.ArgumentNullException"><paramref name="executor"/> is null.</exception>
      public ChatSession(ILLamaExecutor executor)
      {
          // Fail fast: a null executor would otherwise only surface later as a NullReferenceException.
          _executor = executor ?? throw new System.ArgumentNullException(nameof(executor));
          _history = new ChatHistory();
      }

      /// <summary>
      /// Use a custom history transform.
      /// </summary>
      /// <param name="transform">The transform that replaces <see cref="HistoryTransform"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession WithHistoryTransform(IHistoryTransform transform)
      {
          HistoryTransform = transform;
          return this;
      }

      /// <summary>
      /// Add a text transform to the input transform pipeline.
      /// </summary>
      /// <param name="transform">The transform appended to <see cref="InputTransformPipeline"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession AddInputTransform(ITextTransform transform)
      {
          InputTransformPipeline.Add(transform);
          return this;
      }

      /// <summary>
      /// Use a custom output transform.
      /// </summary>
      /// <param name="transform">The transform that replaces <see cref="OutputTransform"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession WithOutputTransform(ITextStreamTransform transform)
      {
          OutputTransform = transform;
          return this;
      }

      /// <summary>
      /// Saves the session into a directory: the context state goes to <c>ModelState.st</c> and,
      /// for executors derived from <see cref="StatefulExecutorBase"/>, the executor state goes to
      /// <c>ExecutorState.json</c>. Nothing further is written for a <see cref="StatelessExecutor"/>.
      /// </summary>
      /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
      /// <exception cref="System.NotImplementedException">
      /// The executor is neither a <see cref="StatelessExecutor"/> nor a <see cref="StatefulExecutorBase"/>.
      /// The context state has already been written when this is thrown.
      /// </exception>
      public virtual void SaveSession(string path)
      {
          // CreateDirectory does nothing when the directory already exists, so a separate
          // Directory.Exists check is redundant (and racy).
          Directory.CreateDirectory(path);

          _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));

          if (Executor is StatelessExecutor)
          {
              // No executor state file is written for the stateless executor.
              return;
          }

          if (Executor is StatefulExecutorBase statefulExecutor)
          {
              statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
              return;
          }

          throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
      }

      /// <summary>
      /// Loads a session from a directory previously written by <see cref="SaveSession(string)"/>:
      /// the context state is read from <c>ModelState.st</c> and, for executors derived from
      /// <see cref="StatefulExecutorBase"/>, the executor state is read from <c>ExecutorState.json</c>.
      /// </summary>
      /// <param name="path">The directory name to load the session.</param>
      /// <exception cref="FileNotFoundException">The directory <paramref name="path"/> does not exist.</exception>
      /// <exception cref="System.NotImplementedException">
      /// The executor is neither a <see cref="StatelessExecutor"/> nor a <see cref="StatefulExecutorBase"/>.
      /// The context state has already been loaded when this is thrown.
      /// </exception>
      public virtual void LoadSession(string path)
      {
          if (!Directory.Exists(path))
          {
              // FileNotFoundException is kept (rather than DirectoryNotFoundException)
              // so existing catch blocks keep working.
              throw new FileNotFoundException($"Directory {path} does not exist.");
          }

          _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));

          if (Executor is StatelessExecutor)
          {
              // No executor state file is read for the stateless executor.
              return;
          }

          if (Executor is StatefulExecutorBase statefulExecutor)
          {
              statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
              return;
          }

          throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
      }

      /// <summary>
      /// Generates a response for a given user prompt and manages history state for the user.
      /// This will always pass the whole history to the model. Don't pass a whole history
      /// to this method as the user prompt will be appended to the history of the current session.
      /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
      /// </summary>
      /// <remarks>
      /// The user message is added to <see cref="History"/> before inference starts. The assistant
      /// message is added only after the response stream has been enumerated to completion.
      /// </remarks>
      /// <param name="prompt">The user prompt; it is passed through <see cref="InputTransformPipeline"/> first.</param>
      /// <param name="inferenceParams">Optional inference parameters. Its anti-prompts, if any, are stripped from the recorded assistant message.</param>
      /// <param name="cancellationToken">Token forwarded to the executor and to the output stream enumeration.</param>
      /// <returns>Returns generated tokens of the assistant message.</returns>
      public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          foreach (var inputTransform in InputTransformPipeline)
          {
              prompt = inputTransform.Transform(prompt);
          }

          History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));

          string internalPrompt = HistoryTransform.HistoryToText(History);

          StringBuilder sb = new();

          await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
          {
              yield return result;
              sb.Append(result);
          }

          string assistantMessage = sb.ToString();

          // Remove end tokens from the assistant message
          // if defined in inferenceParams.AntiPrompts.
          // We only want the response that was generated and not tokens
          // that are delimiting the beginning or end of the response.
          if (inferenceParams?.AntiPrompts != null)
          {
              foreach (var stopToken in inferenceParams.AntiPrompts)
              {
                  // string.Replace throws on a null or empty search string; such entries
                  // cannot match anything meaningful, so skip them instead of failing
                  // after the whole response has already been generated.
                  if (string.IsNullOrEmpty(stopToken))
                  {
                      continue;
                  }

                  assistantMessage = assistantMessage.Replace(stopToken, "");
              }
          }

          History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
      }

      /// <summary>
      /// Generates a response for a given chat history. This method does not manage history state for the user.
      /// If you want to e.g. truncate the history of a session to fit into the model's context window,
      /// use this method and pass the truncated history to it. If you don't need this control, use the other
      /// overload of this method that accepts a user prompt instead.
      /// </summary>
      /// <param name="history">The history that is converted to the prompt by <see cref="HistoryTransform"/>.</param>
      /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
      /// <param name="cancellationToken">Token forwarded to the executor and to the output stream enumeration.</param>
      /// <returns>The generated text pieces, after the output transform has been applied.</returns>
      public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var prompt = HistoryTransform.HistoryToText(history);

          // The response is streamed straight through; nothing is recorded in the session history here.
          await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
          {
              yield return result;
          }
      }

      /// <summary>
      /// Runs inference on the executor for the given prompt text and streams the result
      /// through <see cref="OutputTransform"/>.
      /// </summary>
      /// <param name="prompt">The prompt text passed to the executor.</param>
      /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
      /// <param name="cancellationToken">Token forwarded to the executor and to the transformed stream.</param>
      /// <returns>The transformed output pieces.</returns>
      private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);

          await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
          {
              yield return item;
          }
      }
  }
  190. }