You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ChatSession.cs 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Runtime.CompilerServices;
  6. using System.Text;
  7. using System.Threading;
  8. namespace LLama
  9. {
  10. /// <summary>
  11. /// The main chat session class.
  12. /// </summary>
  13. public class ChatSession
  14. {
  15. private ILLamaExecutor _executor;
  16. private ChatHistory _history;
  17. private static readonly string _executorStateFilename = "ExecutorState.json";
  18. private static readonly string _modelStateFilename = "ModelState.st";
  19. /// <summary>
  20. /// The executor for this session.
  21. /// </summary>
  22. public ILLamaExecutor Executor => _executor;
  23. /// <summary>
  24. /// The chat history for this session.
  25. /// </summary>
  26. public ChatHistory History => _history;
  27. /// <summary>
  28. /// The history transform used in this session.
  29. /// </summary>
  30. public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
  31. /// <summary>
  32. /// The input transform pipeline used in this session.
  33. /// </summary>
  34. public List<ITextTransform> InputTransformPipeline { get; set; } = new();
  35. /// <summary>
  36. /// The output transform used in this session.
  37. /// </summary>
  38. public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
  39. /// <summary>
  40. ///
  41. /// </summary>
  42. /// <param name="executor">The executor for this session</param>
  43. public ChatSession(ILLamaExecutor executor)
  44. {
  45. _executor = executor;
  46. _history = new ChatHistory();
  47. }
  48. /// <summary>
  49. /// Use a custom history transform.
  50. /// </summary>
  51. /// <param name="transform"></param>
  52. /// <returns></returns>
  53. public ChatSession WithHistoryTransform(IHistoryTransform transform)
  54. {
  55. HistoryTransform = transform;
  56. return this;
  57. }
  58. /// <summary>
  59. /// Add a text transform to the input transform pipeline.
  60. /// </summary>
  61. /// <param name="transform"></param>
  62. /// <returns></returns>
  63. public ChatSession AddInputTransform(ITextTransform transform)
  64. {
  65. InputTransformPipeline.Add(transform);
  66. return this;
  67. }
  68. /// <summary>
  69. /// Use a custom output transform.
  70. /// </summary>
  71. /// <param name="transform"></param>
  72. /// <returns></returns>
  73. public ChatSession WithOutputTransform(ITextStreamTransform transform)
  74. {
  75. OutputTransform = transform;
  76. return this;
  77. }
  78. /// <summary>
  79. ///
  80. /// </summary>
  81. /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
  82. public virtual void SaveSession(string path)
  83. {
  84. if (!Directory.Exists(path))
  85. {
  86. Directory.CreateDirectory(path);
  87. }
  88. _executor.Model.SaveState(Path.Combine(path, _modelStateFilename));
  89. if(Executor is StatelessExecutor)
  90. {
  91. }
  92. else if(Executor is StatefulExecutorBase statefulExecutor)
  93. {
  94. statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
  95. }
  96. else
  97. {
  98. throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
  99. }
  100. }
  101. /// <summary>
  102. ///
  103. /// </summary>
  104. /// <param name="path">The directory name to load the session.</param>
  105. public virtual void LoadSession(string path)
  106. {
  107. if (!Directory.Exists(path))
  108. {
  109. throw new FileNotFoundException($"Directory {path} does not exist.");
  110. }
  111. _executor.Model.LoadState(Path.Combine(path, _modelStateFilename));
  112. if (Executor is StatelessExecutor)
  113. {
  114. }
  115. else if (Executor is StatefulExecutorBase statefulExecutor)
  116. {
  117. statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
  118. }
  119. else
  120. {
  121. throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
  122. }
  123. }
  124. /// <summary>
  125. /// Get the response from the LLama model with chat histories.
  126. /// </summary>
  127. /// <param name="prompt"></param>
  128. /// <param name="inferenceParams"></param>
  129. /// <returns></returns>
  130. public IEnumerable<string> Chat(ChatHistory history, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
  131. {
  132. var prompt = HistoryTransform.HistoryToText(history);
  133. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
  134. StringBuilder sb = new();
  135. foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
  136. {
  137. yield return result;
  138. sb.Append(result);
  139. }
  140. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
  141. }
  142. /// <summary>
  143. /// Get the response from the LLama model. Note that prompt could not only be the preset words,
  144. /// but also the question you want to ask.
  145. /// </summary>
  146. /// <param name="prompt"></param>
  147. /// <param name="inferenceParams"></param>
  148. /// <returns></returns>
  149. public IEnumerable<string> Chat(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
  150. {
  151. foreach(var inputTransform in InputTransformPipeline)
  152. {
  153. prompt = inputTransform.Transform(prompt);
  154. }
  155. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
  156. StringBuilder sb = new();
  157. foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
  158. {
  159. yield return result;
  160. sb.Append(result);
  161. }
  162. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
  163. }
  164. /// <summary>
  165. /// Get the response from the LLama model with chat histories.
  166. /// </summary>
  167. /// <param name="prompt"></param>
  168. /// <param name="inferenceParams"></param>
  169. /// <returns></returns>
  170. public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
  171. {
  172. var prompt = HistoryTransform.HistoryToText(history);
  173. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
  174. StringBuilder sb = new();
  175. await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
  176. {
  177. yield return result;
  178. sb.Append(result);
  179. }
  180. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
  181. }
  182. /// <summary>
  183. /// Get the response from the LLama model with chat histories asynchronously.
  184. /// </summary>
  185. /// <param name="prompt"></param>
  186. /// <param name="inferenceParams"></param>
  187. /// <param name="cancellationToken"></param>
  188. /// <returns></returns>
  189. public async IAsyncEnumerable<string> ChatAsync(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
  190. {
  191. foreach (var inputTransform in InputTransformPipeline)
  192. {
  193. prompt = inputTransform.Transform(prompt);
  194. }
  195. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
  196. StringBuilder sb = new();
  197. await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
  198. {
  199. yield return result;
  200. sb.Append(result);
  201. }
  202. History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
  203. }
  204. private IEnumerable<string> ChatInternal(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
  205. {
  206. var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
  207. return OutputTransform.Transform(results);
  208. }
  209. private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
  210. {
  211. var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
  212. await foreach (var item in OutputTransform.TransformAsync(results))
  213. {
  214. yield return item;
  215. }
  216. }
  217. }
  218. }

C#/.NET上易用的LLM高性能推理框架,支持LLaMA和LLaVA系列模型。