
ChatSession.cs 8.1 kB

using LLama.Common;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;

namespace LLama
{
    public class ChatSession
    {
        private ILLamaExecutor _executor;
        private ChatHistory _history;
        private static readonly string _executorStateFilename = "ExecutorState.json";
        private static readonly string _modelStateFilename = "ModelState.st";

        public ILLamaExecutor Executor => _executor;
        public ChatHistory History => _history;
        public SessionParams Params { get; set; }
        public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
        public List<ITextTransform> InputTransformPipeline { get; set; } = new();
        public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

        public ChatSession(ILLamaExecutor executor, SessionParams? sessionParams = null)
        {
            _executor = executor;
            _history = new ChatHistory();
            Params = sessionParams ?? new SessionParams();
        }

        public ChatSession WithHistoryTransform(IHistoryTransform transform)
        {
            HistoryTransform = transform;
            return this;
        }

        public ChatSession AddInputTransform(ITextTransform transform)
        {
            InputTransformPipeline.Add(transform);
            return this;
        }

        public ChatSession WithOutputTransform(ITextStreamTransform transform)
        {
            OutputTransform = transform;
            return this;
        }

        /// <summary>
        /// Save the session to a directory, including the model state and, for stateful executors, the executor state.
        /// </summary>
        /// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
        public virtual void SaveSession(string path)
        {
            if (!Directory.Exists(path))
            {
                Directory.CreateDirectory(path);
            }
            _executor.Model.SaveState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so only the model state is saved.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
            }
        }

        /// <summary>
        /// Load a previously saved session from a directory.
        /// </summary>
        /// <param name="path">The directory name to load the session from.</param>
        public virtual void LoadSession(string path)
        {
            if (!Directory.Exists(path))
            {
                throw new FileNotFoundException($"Directory {path} does not exist.");
            }
            _executor.Model.LoadState(Path.Combine(path, _modelStateFilename));
            if (Executor is StatelessExecutor)
            {
                // A stateless executor keeps no state of its own, so only the model state is loaded.
            }
            else if (Executor is StatefulExecutorBase statefulExecutor)
            {
                statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
            }
            else
            {
                throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
            }
        }

        /// <summary>
        /// Get the response from the LLama model with chat histories.
        /// </summary>
        /// <param name="history">The chat history to generate the prompt from.</param>
        /// <param name="inferenceParams"></param>
        /// <returns></returns>
        public IEnumerable<string> Chat(ChatHistory history, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model. Note that the prompt is not limited to preset words;
        /// it can also be the question you want to ask.
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="inferenceParams"></param>
        /// <returns></returns>
        public IEnumerable<string> Chat(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model with chat histories, asynchronously.
        /// </summary>
        /// <param name="history">The chat history to generate the prompt from.</param>
        /// <param name="inferenceParams"></param>
        /// <returns></returns>
        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var prompt = HistoryTransform.HistoryToText(history);
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        /// <summary>
        /// Get the response from the LLama model with the given prompt, asynchronously.
        /// </summary>
        public async IAsyncEnumerable<string> ChatAsync(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            foreach (var inputTransform in InputTransformPipeline)
            {
                prompt = inputTransform.Transform(prompt);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
            StringBuilder sb = new();
            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
            {
                yield return result;
                sb.Append(result);
            }
            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
        }

        private IEnumerable<string> ChatInternal(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
        {
            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
            return OutputTransform.Transform(results);
        }

        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
            await foreach (var item in OutputTransform.TransformAsync(results))
            {
                yield return item;
            }
        }
    }
}
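
For reference, a minimal usage sketch of this class follows. The LLamaModel, ModelParams, InteractiveExecutor and InferenceParams types are assumed to come from the surrounding LLamaSharp library, and the model path is a placeholder; treat this as an illustration, not as part of ChatSession.cs.

// Minimal usage sketch. Assumptions: LLamaModel, ModelParams, InteractiveExecutor and
// InferenceParams are provided by the surrounding library; "path/to/model.bin" is a placeholder.
using System;
using LLama;
using LLama.Common;

var model = new LLamaModel(new ModelParams("path/to/model.bin"));
var session = new ChatSession(new InteractiveExecutor(model));

// Stream a reply; ChatSession records both the user prompt and the model's answer in History.
foreach (var token in session.Chat("Hello, who are you?", new InferenceParams { Temperature = 0.6f }))
{
    Console.Write(token);
}

// Persist the model state (and executor state, for stateful executors) so the
// conversation can be resumed later with LoadSession.
session.SaveSession("./chat-session");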

An easy-to-use, high-performance LLM inference framework for C#/.NET, supporting the LLaMA and LLaVA series of models.