You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ChatSession.cs 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Runtime.CompilerServices;
  6. using System.Text;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. namespace LLama
  10. {
  /// <summary>
  /// The main chat session class. It wraps an <see cref="ILLamaExecutor"/>, keeps the
  /// running <see cref="ChatHistory"/>, and applies the configured input, history and
  /// output transforms around every inference call.
  /// </summary>
  public class ChatSession
  {
      private readonly ILLamaExecutor _executor;
      private readonly ChatHistory _history;
      private const string _executorStateFilename = "ExecutorState.json";
      private const string _modelStateFilename = "ModelState.st";

      /// <summary>
      /// The executor for this session.
      /// </summary>
      public ILLamaExecutor Executor => _executor;

      /// <summary>
      /// The chat history for this session.
      /// </summary>
      public ChatHistory History => _history;

      /// <summary>
      /// The history transform used in this session.
      /// </summary>
      public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

      /// <summary>
      /// The input transform pipeline used in this session. Transforms are applied to the
      /// prompt in list order.
      /// </summary>
      public List<ITextTransform> InputTransformPipeline { get; set; } = new();

      /// <summary>
      /// The output transform used in this session.
      /// </summary>
      // Intentionally left as a public field so existing callers keep compiling unchanged.
      public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

      /// <summary>
      /// Create a new chat session with an empty history that runs on the given executor.
      /// </summary>
      /// <param name="executor">The executor for this session</param>
      /// <exception cref="System.ArgumentNullException"><paramref name="executor"/> is null.</exception>
      public ChatSession(ILLamaExecutor executor)
      {
          // Fail at construction time rather than with a NullReferenceException on first use.
          _executor = executor ?? throw new System.ArgumentNullException(nameof(executor));
          _history = new ChatHistory();
      }

      /// <summary>
      /// Use a custom history transform.
      /// </summary>
      /// <param name="transform">The transform that replaces <see cref="HistoryTransform"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession WithHistoryTransform(IHistoryTransform transform)
      {
          HistoryTransform = transform;
          return this;
      }

      /// <summary>
      /// Add a text transform to the input transform pipeline.
      /// </summary>
      /// <param name="transform">The transform appended to <see cref="InputTransformPipeline"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession AddInputTransform(ITextTransform transform)
      {
          InputTransformPipeline.Add(transform);
          return this;
      }

      /// <summary>
      /// Use a custom output transform.
      /// </summary>
      /// <param name="transform">The transform that replaces <see cref="OutputTransform"/>.</param>
      /// <returns>This session, to allow call chaining.</returns>
      public ChatSession WithOutputTransform(ITextStreamTransform transform)
      {
          OutputTransform = transform;
          return this;
      }

      /// <summary>
      /// Save the session into a directory: the context state is written to "ModelState.st"
      /// and, for a stateful executor, the executor state is written to "ExecutorState.json".
      /// </summary>
      /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
      /// <exception cref="System.NotImplementedException">
      /// The executor is neither a <c>StatelessExecutor</c> nor a <c>StatefulExecutorBase</c>.
      /// </exception>
      public virtual void SaveSession(string path)
      {
          // Classify the executor first so an unsupported executor fails before
          // anything is created or written on disk.
          var statefulExecutor = GetStatefulExecutorOrNull();

          // CreateDirectory does nothing when the directory already exists,
          // so no separate existence check is needed.
          Directory.CreateDirectory(path);

          _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
          statefulExecutor?.SaveState(Path.Combine(path, _executorStateFilename));
      }

      /// <summary>
      /// Load a session from a directory previously written by <see cref="SaveSession"/>:
      /// the context state is read from "ModelState.st" and, for a stateful executor,
      /// the executor state is read from "ExecutorState.json".
      /// </summary>
      /// <param name="path">The directory name to load the session.</param>
      /// <exception cref="FileNotFoundException">The directory does not exist.</exception>
      /// <exception cref="System.NotImplementedException">
      /// The executor is neither a <c>StatelessExecutor</c> nor a <c>StatefulExecutorBase</c>.
      /// </exception>
      public virtual void LoadSession(string path)
      {
          if (!Directory.Exists(path))
          {
              throw new FileNotFoundException($"Directory {path} does not exist.");
          }

          // Classify the executor before touching the context so an unsupported
          // executor does not leave the context half-restored.
          var statefulExecutor = GetStatefulExecutorOrNull();

          _executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
          statefulExecutor?.LoadState(Path.Combine(path, _executorStateFilename));
      }

      /// <summary>
      /// Get the response from the LLama model. Note that prompt could not only be the preset words,
      /// but also the question you want to ask.
      /// </summary>
      /// <remarks>
      /// The prompt is passed through <see cref="InputTransformPipeline"/> and recorded in
      /// <see cref="History"/> as a user message before inference starts. The concatenated
      /// response is recorded as an assistant message only after the stream has been fully
      /// enumerated.
      /// </remarks>
      /// <param name="prompt">The text to send to the model.</param>
      /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the inference.</param>
      /// <returns>The response text, streamed piece by piece.</returns>
      public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          foreach (var inputTransform in InputTransformPipeline)
          {
              prompt = inputTransform.Transform(prompt);
          }

          AppendToHistory(AuthorRole.User, prompt);

          StringBuilder response = new();
          await foreach (var piece in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
          {
              yield return piece;
              response.Append(piece);
          }

          AppendToHistory(AuthorRole.Assistant, response.ToString());
      }

      /// <summary>
      /// Get the response from the LLama model with chat histories.
      /// </summary>
      /// <remarks>
      /// The given history is converted to prompt text with <see cref="HistoryTransform"/>;
      /// that text is recorded in <see cref="History"/> as a user message before inference
      /// starts. The concatenated response is recorded as an assistant message only after
      /// the stream has been fully enumerated. The input transform pipeline is not applied.
      /// </remarks>
      /// <param name="history">The chat history to turn into the prompt.</param>
      /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
      /// <param name="cancellationToken">Token used to cancel the inference.</param>
      /// <returns>The response text, streamed piece by piece.</returns>
      public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var prompt = HistoryTransform.HistoryToText(history);
          AppendToHistory(AuthorRole.User, prompt);

          StringBuilder response = new();
          await foreach (var piece in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
          {
              yield return piece;
              response.Append(piece);
          }

          AppendToHistory(AuthorRole.Assistant, response.ToString());
      }

      /// <summary>
      /// Run inference on the executor and stream the result through <see cref="OutputTransform"/>.
      /// </summary>
      private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
          await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
          {
              yield return item;
          }
      }

      /// <summary>
      /// Convert text to history messages with the current history transform and append them to the session history.
      /// </summary>
      private void AppendToHistory(AuthorRole role, string text)
      {
          History.Messages.AddRange(HistoryTransform.TextToHistory(role, text).Messages);
      }

      /// <summary>
      /// Classify the executor for state persistence: returns the executor when it is stateful,
      /// null when it is stateless, and throws for any other executor type.
      /// </summary>
      private StatefulExecutorBase? GetStatefulExecutorOrNull()
      {
          // Same test order as before: stateless is checked first, then stateful.
          if (Executor is StatelessExecutor)
          {
              return null;
          }

          if (Executor is StatefulExecutorBase statefulExecutor)
          {
              return statefulExecutor;
          }

          throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
      }
  }
  174. }