You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaTransforms.cs 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using LLama.Abstractions;
using LLama.Common;
  6. namespace LLama
  7. {
  8. /// <summary>
  9. /// A class that contains all the transforms provided internally by LLama.
  10. /// </summary>
  11. public class LLamaTransforms
  12. {
  /// <summary>
  /// The default history transform.
  /// Uses plain text, one line per message, with the following format:
  /// [Author]: [Message]
  /// </summary>
  public class DefaultHistoryTransform : IHistoryTransform
  {
      private const string DefaultUserName = "User";
      private const string DefaultAssistantName = "Assistant";
      private const string DefaultSystemName = "System";
      private const string DefaultUnknownName = "??";

      // Trailing marker stripped from assistant text when instruct mode is enabled.
      private const string InstructPromptSuffix = "\n> ";

      private readonly string _userName;
      private readonly string _assistantName;
      private readonly string _systemName;
      private readonly string _unknownName;
      private readonly bool _isInstructMode;

      /// <summary>
      /// Creates a transform that labels each message with a display name for its author role.
      /// </summary>
      /// <param name="userName">Name used for user messages; <c>null</c> selects "User".</param>
      /// <param name="assistantName">Name used for assistant messages; <c>null</c> selects "Assistant".</param>
      /// <param name="systemName">Name used for system messages; <c>null</c> selects "System".</param>
      /// <param name="unknownName">Name used for messages with an unknown role; <c>null</c> selects "??".</param>
      /// <param name="isInstructMode">
      /// When true, <see cref="TrimNamesFromText"/> additionally strips a trailing "\n&gt; " from assistant text.
      /// </param>
      public DefaultHistoryTransform(string? userName = null, string? assistantName = null,
          string? systemName = null, string? unknownName = null, bool isInstructMode = false)
      {
          _userName = userName ?? DefaultUserName;
          _assistantName = assistantName ?? DefaultAssistantName;
          _systemName = systemName ?? DefaultSystemName;
          _unknownName = unknownName ?? DefaultUnknownName;
          _isInstructMode = isInstructMode;
      }

      /// <inheritdoc />
      public virtual string HistoryToText(ChatHistory history)
      {
          StringBuilder sb = new();
          foreach (var message in history.Messages)
          {
              // Messages whose role has no configured name are skipped.
              string? name = GetDisplayName(message.AuthorRole);
              if (name is null)
              {
                  continue;
              }
              sb.AppendLine($"{name}: {message.Content}");
          }
          return sb.ToString();
      }

      /// <inheritdoc />
      public virtual ChatHistory TextToHistory(AuthorRole role, string text)
      {
          ChatHistory history = new ChatHistory();
          history.AddMessage(role, TrimNamesFromText(text, role));
          return history;
      }

      /// <summary>
      /// Removes the author labels this transform adds from model/user text:
      /// a leading "{userName}:" for user text and a trailing "{assistantName}:" for assistant text.
      /// In instruct mode, a trailing "\n&gt; " is also removed from assistant text.
      /// </summary>
      /// <param name="text">The text to clean up.</param>
      /// <param name="role">The author role the text belongs to.</param>
      /// <returns>The text with the matching labels removed and the adjacent whitespace trimmed.</returns>
      public virtual string TrimNamesFromText(string text, AuthorRole role)
      {
          string userPrefix = $"{_userName}:";
          string assistantSuffix = $"{_assistantName}:";

          // Ordinal comparison: these are literal prompt markers, and culture-sensitive matching
          // (e.g. ICU treating "\r\n" as one unit) can miss them or disagree with the Substring length.
          if (role == AuthorRole.User && text.StartsWith(userPrefix, StringComparison.Ordinal))
          {
              text = text.Substring(userPrefix.Length).TrimStart();
          }
          else if (role == AuthorRole.Assistant && text.EndsWith(assistantSuffix, StringComparison.Ordinal))
          {
              text = text.Substring(0, text.Length - assistantSuffix.Length).TrimEnd();
          }

          if (_isInstructMode && role == AuthorRole.Assistant
              && text.EndsWith(InstructPromptSuffix, StringComparison.Ordinal))
          {
              text = text.Substring(0, text.Length - InstructPromptSuffix.Length).TrimEnd();
          }
          return text;
      }

      /// <summary>
      /// Maps an author role to its configured display name, or <c>null</c> for a role this transform does not label.
      /// </summary>
      private string? GetDisplayName(AuthorRole role)
      {
          return role == AuthorRole.User ? _userName
              : role == AuthorRole.System ? _systemName
              : role == AuthorRole.Unknown ? _unknownName
              : role == AuthorRole.Assistant ? _assistantName
              : null;
      }
  }
  /// <summary>
  /// Input transform that only strips leading and trailing white-space and leaves everything else untouched.
  /// </summary>
  public class NaiveTextInputTransform : ITextTransform
  {
      /// <summary>
      /// Creates the transform. It carries no state, so one instance can be reused freely.
      /// </summary>
      public NaiveTextInputTransform()
      {
      }

      /// <inheritdoc />
      public string Transform(string text) => text.Trim();
  }
  /// <summary>
  /// Pass-through output stream transform: both methods hand back the very stream they were given.
  /// </summary>
  public class EmptyTextOutputStreamTransform : ITextStreamTransform
  {
      /// <inheritdoc />
      public IEnumerable<string> Transform(IEnumerable<string> tokens) => tokens;

      /// <inheritdoc />
      public IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) => tokens;
  }
  /// <summary>
  /// A text output transform that removes the keywords from the response.
  /// </summary>
  /// <remarks>
  /// Tokens are buffered in a window. When the concatenated window contains a keyword, the window is dropped
  /// and (unless <c>removeAllMatchedTokens</c> is set) re-emitted as one string with that keyword removed.
  /// When no keyword matches and the window reaches the longest keyword length plus <c>redundancyLength</c>,
  /// the buffered tokens are flushed unchanged. A keyword split across a flush boundary is not detected.
  /// </remarks>
  public class KeywordTextOutputStreamTransform : ITextStreamTransform
  {
      private readonly HashSet<string> _keywords;
      private readonly int _maxKeywordLength;
      private readonly bool _removeAllMatchedTokens;

      /// <summary>
      /// Creates a transform that strips the given keywords from a token stream.
      /// </summary>
      /// <param name="keywords">Keywords that you want to remove from the response. Null or empty entries are ignored.</param>
      /// <param name="redundancyLength">
      /// Extra window length, beyond the longest keyword, before buffered tokens are flushed.
      /// For example, with the single keyword "highlight" (9 characters), if "\r\n" arrives before the keyword,
      /// a 9-character window can fill up with "\r\nhighlig" and be flushed before "ht" arrives, so the keyword is missed.
      /// With redundancyLength &gt;= 2 the window can hold "\r\nhighlight" (11 characters) and the keyword is matched.
      /// Larger values hold output back longer; in the original author's experience values &lt;= 5 have little performance impact.
      /// </param>
      /// <param name="removeAllMatchedTokens">
      /// If true, the buffered tokens that formed a match are discarded entirely.
      /// Otherwise they are emitted as a single string with the matched keyword removed.
      /// </param>
      public KeywordTextOutputStreamTransform(IEnumerable<string> keywords, int redundancyLength = 3, bool removeAllMatchedTokens = false)
      {
          // An empty keyword would match every window and make string.Replace throw; a null one would
          // break the length computation. Drop both.
          _keywords = new HashSet<string>(keywords.Where(k => !string.IsNullOrEmpty(k)));
          // With no usable keywords nothing can match, so a zero-length window passes tokens straight through.
          _maxKeywordLength = _keywords.Count == 0 ? 0 : _keywords.Max(k => k.Length) + redundancyLength;
          _removeAllMatchedTokens = removeAllMatchedTokens;
      }

      /// <inheritdoc />
      public IEnumerable<string> Transform(IEnumerable<string> tokens)
      {
          var window = new Queue<string>();
          var output = new List<string>();
          foreach (var token in tokens)
          {
              output.Clear();
              ProcessToken(window, token, output);
              foreach (var piece in output)
              {
                  yield return piece;
              }
          }
          // End of stream: release whatever is still buffered.
          while (window.Count > 0)
          {
              yield return window.Dequeue();
          }
      }

      /// <inheritdoc />
      public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
      {
          var window = new Queue<string>();
          var output = new List<string>();
          await foreach (var token in tokens)
          {
              output.Clear();
              ProcessToken(window, token, output);
              foreach (var piece in output)
              {
                  yield return piece;
              }
          }
          // End of stream: release whatever is still buffered.
          while (window.Count > 0)
          {
              yield return window.Dequeue();
          }
      }

      /// <summary>
      /// Adds <paramref name="token"/> to the window and appends to <paramref name="output"/> whatever must be emitted now.
      /// That is the keyword-stripped window on a match, the buffered tokens once the window is full, or nothing.
      /// </summary>
      private void ProcessToken(Queue<string> window, string token, List<string> output)
      {
          window.Enqueue(token);
          var current = string.Concat(window);

          var matchedKeyword = _keywords.FirstOrDefault(k => current.Contains(k));
          if (matchedKeyword is not null)
          {
              // Only the first matching keyword found is stripped from the joined text.
              window.Clear();
              if (!_removeAllMatchedTokens)
              {
                  output.Add(current.Replace(matchedKeyword, ""));
              }
              // The window is already consumed; falling through to the length check would emit it a second time.
              return;
          }

          if (current.Length >= _maxKeywordLength)
          {
              output.AddRange(window);
              window.Clear();
          }
      }
  }
  247. }
  248. }