You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaTransforms.cs 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Text.Json.Serialization;
  7. namespace LLama
  8. {
  9. /// <summary>
  10. /// A class that contains all the transforms provided internally by LLama.
  11. /// </summary>
  12. public class LLamaTransforms
  13. {
  /// <summary>
  /// The default history transform.
  /// Uses plain text with the following format:
  /// [Author]: [Message]
  /// </summary>
  public class DefaultHistoryTransform : IHistoryTransform
  {
      private const string defaultUserName = "User";
      private const string defaultAssistantName = "Assistant";
      private const string defaultSystemName = "System";
      private const string defaultUnknownName = "??";

      private readonly string _userName;
      private readonly string _assistantName;
      private readonly string _systemName;
      private readonly string _unknownName;
      private readonly bool _isInstructMode;

      /// <summary>
      /// Name written in front of messages whose role is <c>AuthorRole.User</c>.
      /// </summary>
      public string UserName => _userName;

      /// <summary>
      /// Name written in front of messages whose role is <c>AuthorRole.Assistant</c>.
      /// </summary>
      public string AssistantName => _assistantName;

      /// <summary>
      /// Name written in front of messages whose role is <c>AuthorRole.System</c>.
      /// </summary>
      public string SystemName => _systemName;

      /// <summary>
      /// Name written in front of messages whose role is <c>AuthorRole.Unknown</c>.
      /// </summary>
      public string UnknownName => _unknownName;

      /// <summary>
      /// When true, a trailing "\n&gt; " is additionally stripped from assistant text
      /// by <see cref="TrimNamesFromText"/>.
      /// </summary>
      public bool IsInstructMode => _isInstructMode;

      /// <summary>
      /// Creates a history transform; every name left as null falls back to its default
      /// ("User", "Assistant", "System" and "??" respectively).
      /// </summary>
      /// <param name="userName">Name used for user messages, or null for "User".</param>
      /// <param name="assistantName">Name used for assistant messages, or null for "Assistant".</param>
      /// <param name="systemName">Name used for system messages, or null for "System".</param>
      /// <param name="unknownName">Name used for messages of unknown role, or null for "??".</param>
      /// <param name="isInstructMode">Whether a trailing "\n&gt; " is trimmed from assistant text.</param>
      public DefaultHistoryTransform(string? userName = null, string? assistantName = null,
          string? systemName = null, string? unknownName = null, bool isInstructMode = false)
      {
          _userName = userName ?? defaultUserName;
          _assistantName = assistantName ?? defaultAssistantName;
          _systemName = systemName ?? defaultSystemName;
          _unknownName = unknownName ?? defaultUnknownName;
          _isInstructMode = isInstructMode;
      }

      /// <inheritdoc />
      public IHistoryTransform Clone()
      {
          return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode);
      }

      /// <inheritdoc />
      public virtual string HistoryToText(ChatHistory history)
      {
          StringBuilder sb = new();
          foreach (var message in history.Messages)
          {
              var author = GetAuthorName(message.AuthorRole);

              // Messages with a role outside the four handled ones are left out of the text.
              if (author is not null)
              {
                  sb.AppendLine($"{author}: {message.Content}");
              }
          }
          return sb.ToString();
      }

      /// <inheritdoc />
      public virtual ChatHistory TextToHistory(AuthorRole role, string text)
      {
          ChatHistory history = new ChatHistory();
          history.AddMessage(role, TrimNamesFromText(text, role));
          return history;
      }

      /// <summary>
      /// Drop the name at the beginning and the end of the text.
      /// For user text a leading "[UserName]:" is removed; for assistant text a trailing
      /// "[AssistantName]:" is removed and, in instruct mode, a trailing "\n&gt; " as well.
      /// </summary>
      /// <param name="text">The text to clean up.</param>
      /// <param name="role">The role the text belongs to; it selects which markers are removed.</param>
      /// <returns>The text without the matched markers and without the whitespace next to them.</returns>
      public virtual string TrimNamesFromText(string text, AuthorRole role)
      {
          const string instructSuffix = "\n> ";
          var userPrefix = $"{_userName}:";
          var assistantSuffix = $"{_assistantName}:";

          // Ordinal comparison keeps the match consistent with the ordinal lengths used for
          // Substring below; the culture-sensitive default can match a different number of
          // characters and treats "\r\n" specially on ICU-based runtimes.
          if (role == AuthorRole.User && text.StartsWith(userPrefix, System.StringComparison.Ordinal))
          {
              text = text.Substring(userPrefix.Length).TrimStart();
          }
          else if (role == AuthorRole.Assistant && text.EndsWith(assistantSuffix, System.StringComparison.Ordinal))
          {
              text = text.Substring(0, text.Length - assistantSuffix.Length).TrimEnd();
          }

          if (_isInstructMode && role == AuthorRole.Assistant && text.EndsWith(instructSuffix, System.StringComparison.Ordinal))
          {
              text = text.Substring(0, text.Length - instructSuffix.Length).TrimEnd();
          }
          return text;
      }

      /// <summary>
      /// Maps a role to the configured author name, or null when the role is not one of
      /// User, System, Unknown or Assistant.
      /// </summary>
      private string? GetAuthorName(AuthorRole role)
      {
          if (role == AuthorRole.User)
          {
              return _userName;
          }
          if (role == AuthorRole.System)
          {
              return _systemName;
          }
          if (role == AuthorRole.Unknown)
          {
              return _unknownName;
          }
          if (role == AuthorRole.Assistant)
          {
              return _assistantName;
          }
          return null;
      }
  }
  /// <summary>
  /// Text input transform whose only effect is stripping leading and trailing whitespace.
  /// </summary>
  public class NaiveTextInputTransform : ITextTransform
  {
      /// <inheritdoc />
      public string Transform(string text) => text.Trim();

      /// <inheritdoc />
      public ITextTransform Clone() => new NaiveTextInputTransform();
  }
  /// <summary>
  /// Pass-through text output stream transform: the token stream is handed back unchanged.
  /// </summary>
  public class EmptyTextOutputStreamTransform : ITextStreamTransform
  {
      /// <inheritdoc />
      public IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) => tokens;

      /// <inheritdoc />
      public ITextStreamTransform Clone() => new EmptyTextOutputStreamTransform();
  }
  /// <summary>
  /// A text output transform that removes the keywords from the response.
  /// </summary>
  public class KeywordTextOutputStreamTransform : ITextStreamTransform
  {
      private readonly HashSet<string> _keywords;
      private readonly int _maxKeywordLength;
      private readonly bool _removeAllMatchedTokens;

      /// <summary>
      /// Keywords that you want to remove from the response.
      /// This property is used for JSON serialization.
      /// </summary>
      [JsonPropertyName("keywords")]
      public HashSet<string> Keywords => _keywords;

      /// <summary>
      /// Maximum length of the keywords.
      /// This property is used for JSON serialization.
      /// </summary>
      [JsonPropertyName("maxKeywordLength")]
      public int MaxKeywordLength => _maxKeywordLength;

      /// <summary>
      /// If set to true, when getting a matched keyword, all the related tokens will be removed.
      /// Otherwise only the part of keyword will be removed.
      /// This property is used for JSON serialization.
      /// </summary>
      [JsonPropertyName("removeAllMatchedTokens")]
      public bool RemoveAllMatchedTokens => _removeAllMatchedTokens;

      /// <summary>
      /// JSON constructor. Takes the already computed maximum keyword length as-is.
      /// </summary>
      /// <param name="keywords">Keywords that you want to remove from the response; the set is copied.</param>
      /// <param name="maxKeywordLength">Buffered text length at which pending tokens are released.</param>
      /// <param name="removeAllMatchedTokens">If set to true, when getting a matched keyword, all the related tokens will be removed. Otherwise only the part of keyword will be removed.</param>
      [JsonConstructor]
      public KeywordTextOutputStreamTransform(
          HashSet<string> keywords,
          int maxKeywordLength,
          bool removeAllMatchedTokens)
      {
          _keywords = new(keywords);
          _maxKeywordLength = maxKeywordLength;
          _removeAllMatchedTokens = removeAllMatchedTokens;
      }

      /// <summary>
      /// Creates a transform that removes the given keywords from the streamed response.
      /// </summary>
      /// <param name="keywords">Keywords that you want to remove from the response.</param>
      /// <param name="redundancyLength">The extra length when searching for the keyword. For example, if your only keyword is "highlight",
      /// maybe the token you get is "\r\nhighlight". In this condition, if redundancyLength=0, the token cannot be successfully matched because the length of "\r\nhighlight" (11)
      /// has already exceeded the maximum length of the keywords (9). On the contrary, setting redundancyLength &gt;= 2 leads to successful match.
      /// The larger the redundancyLength is, the lower the processing speed. But as an experience, it won't introduce too much performance impact when redundancyLength &lt;= 5 </param>
      /// <param name="removeAllMatchedTokens">If set to true, when getting a matched keyword, all the related tokens will be removed. Otherwise only the part of keyword will be removed.</param>
      public KeywordTextOutputStreamTransform(IEnumerable<string> keywords, int redundancyLength = 3, bool removeAllMatchedTokens = false)
      {
          _keywords = new(keywords);

          // Max() throws on an empty sequence, so an empty keyword set counts as length 0.
          int longestKeyword = _keywords.Count == 0 ? 0 : _keywords.Max(x => x.Length);
          _maxKeywordLength = longestKeyword + redundancyLength;
          _removeAllMatchedTokens = removeAllMatchedTokens;
      }

      /// <inheritdoc />
      public ITextStreamTransform Clone()
      {
          return new KeywordTextOutputStreamTransform(_keywords, _maxKeywordLength, _removeAllMatchedTokens);
      }

      /// <inheritdoc />
      public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
      {
          // Tokens are held back here until it is known whether they are part of a keyword.
          var window = new Queue<string>();

          await foreach (var token in tokens)
          {
              window.Enqueue(token);
              var current = string.Join("", window);

              if (ContainsAnyKeyword(current))
              {
                  // The buffered tokens are consumed: either dropped entirely, or replaced
                  // by their concatenation with the keywords cut out.
                  window.Clear();

                  if (!_removeAllMatchedTokens)
                  {
                      foreach (var keyword in _keywords)
                      {
                          // Replace throws for an empty search string, so empty keywords are skipped.
                          if (keyword.Length > 0)
                          {
                              current = current.Replace(keyword, "");
                          }
                      }
                      yield return current;
                  }
              }

              // Once the buffered text is at least as long as the search window, release
              // whatever is still pending (nothing if a keyword was just handled above).
              if (current.Length >= _maxKeywordLength)
              {
                  while (window.Count > 0)
                  {
                      yield return window.Dequeue();
                  }
              }
          }

          // Flush the tokens that were still buffered when the stream ended.
          while (window.Count > 0)
          {
              yield return window.Dequeue();
          }
      }

      /// <summary>
      /// Returns true when <paramref name="text"/> contains at least one non-empty keyword.
      /// </summary>
      private bool ContainsAnyKeyword(string text)
      {
          foreach (var keyword in _keywords)
          {
              // An empty keyword would match every text, so it is ignored.
              if (keyword.Length > 0 && text.Contains(keyword))
              {
                  return true;
              }
          }
          return false;
      }
  }
  248. }
  249. }