You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

LLamaTemplate.cs 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Text;
  6. using LLama.Native;
  7. namespace LLama;
  8. /// <summary>
  9. /// Converts a sequence of messages into text according to a model template
  10. /// </summary>
  11. public sealed class LLamaTemplate
  12. {
  13. #region private state
  14. private static readonly Encoding Encoding = Encoding.UTF8;
  15. /// <summary>
  16. /// The model this template is for. May be null if a custom template was supplied to the constructor.
  17. /// </summary>
  18. private readonly SafeLlamaModelHandle? _model;
  19. /// <summary>
  20. /// Custom template. May be null if a model was supplied to the constructor.
  21. /// </summary>
  22. private readonly byte[]? _customTemplate;
  23. /// <summary>
  24. /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
  25. /// </summary>
  26. private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();
  27. /// <summary>
  28. /// Array of messages. The <see cref="Count"/> property indicates how many messages there are
  29. /// </summary>
  30. private TextMessage?[] _messages = new TextMessage[4];
  31. /// <summary>
  32. /// Backing field for <see cref="AddAssistant"/>
  33. /// </summary>
  34. private bool _addAssistant;
  35. /// <summary>
  36. /// Temporary array of messages in the format llama.cpp needs, used when applying the template
  37. /// </summary>
  38. private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4];
  39. /// <summary>
  40. /// Indicates how many bytes are in <see cref="_result"/> array
  41. /// </summary>
  42. private int _resultLength;
  43. /// <summary>
  44. /// Result bytes of last call to <see cref="Apply"/>
  45. /// </summary>
  46. private byte[] _result = Array.Empty<byte>();
  47. /// <summary>
  48. /// Indicates if this template has been modified and needs regenerating
  49. /// </summary>
  50. private bool _dirty = true;
  51. #endregion
  52. #region properties
  53. /// <summary>
  54. /// Number of messages added to this template
  55. /// </summary>
  56. public int Count { get; private set; }
  57. /// <summary>
  58. /// Get the message at the given index
  59. /// </summary>
  60. /// <param name="index"></param>
  61. /// <returns></returns>
  62. /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
  63. public TextMessage this[int index]
  64. {
  65. get
  66. {
  67. if (index < 0)
  68. throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0");
  69. if (index >= Count)
  70. throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");
  71. return _messages[index]!;
  72. }
  73. }
  74. /// <summary>
  75. /// Whether to end the prompt with the token(s) that indicate the start of an assistant message.
  76. /// </summary>
  77. public bool AddAssistant
  78. {
  79. get => _addAssistant;
  80. set
  81. {
  82. if (value != _addAssistant)
  83. {
  84. _dirty = true;
  85. _addAssistant = value;
  86. }
  87. }
  88. }
  89. #endregion
  90. #region construction
  91. /// <summary>
  92. /// Construct a new template, using the default model template
  93. /// </summary>
  94. /// <param name="model"></param>
  95. public LLamaTemplate(SafeLlamaModelHandle model)
  96. {
  97. _model = model;
  98. }
  99. /// <summary>
  100. /// Construct a new template, using the default model template
  101. /// </summary>
  102. /// <param name="weights"></param>
  103. public LLamaTemplate(LLamaWeights weights)
  104. : this(weights.NativeHandle)
  105. {
  106. }
  107. /// <summary>
  108. /// Construct a new template, using a custom template.
  109. /// </summary>
  110. /// <remarks>Only support a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template</remarks>
  111. /// <param name="customTemplate"></param>
  112. public LLamaTemplate(string customTemplate)
  113. {
  114. _customTemplate = Encoding.GetBytes(customTemplate + "\0");
  115. }
  116. #endregion
  117. #region modify
  118. /// <summary>
  119. /// Add a new message to the end of this template
  120. /// </summary>
  121. /// <param name="role"></param>
  122. /// <param name="content"></param>
  123. /// <returns>This template, for chaining calls.</returns>
  124. public LLamaTemplate Add(string role, string content)
  125. {
  126. return Add(new TextMessage(role, content, _roleCache));
  127. }
  128. /// <summary>
  129. /// Add a new message to the end of this template
  130. /// </summary>
  131. /// <param name="message"></param>
  132. /// <returns>This template, for chaining calls.</returns>
  133. public LLamaTemplate Add(TextMessage message)
  134. {
  135. // Expand messages array if necessary
  136. if (Count == _messages.Length)
  137. Array.Resize(ref _messages, _messages.Length * 2);
  138. // Add message
  139. _messages[Count] = message;
  140. Count++;
  141. // Mark as dirty to ensure template is recalculated
  142. _dirty = true;
  143. return this;
  144. }
  145. /// <summary>
  146. /// Remove a message at the given index
  147. /// </summary>
  148. /// <param name="index"></param>
  149. /// <returns>This template, for chaining calls.</returns>
  150. public LLamaTemplate RemoveAt(int index)
  151. {
  152. if (index < 0)
  153. throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
  154. if (index >= Count)
  155. throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count");
  156. _dirty = true;
  157. Count--;
  158. // Copy all items after index down by one
  159. if (index < Count)
  160. Array.Copy(_messages, index + 1, _messages, index, Count - index);
  161. _messages[Count] = default;
  162. return this;
  163. }
  164. #endregion
  165. /// <summary>
  166. /// Apply the template to the messages and write it into the output buffer
  167. /// </summary>
  168. /// <param name="dest">Destination to write template bytes into</param>
  169. /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
  170. public int Apply(Memory<byte> dest)
  171. {
  172. // Recalculate template if necessary
  173. if (_dirty)
  174. {
  175. _dirty = false;
  176. using var group = new GroupDisposable();
  177. unsafe
  178. {
  179. // Convert all the messages
  180. var totalInputBytes = 0;
  181. if (_nativeChatMessages.Length < _messages.Length)
  182. Array.Resize(ref _nativeChatMessages, _messages.Length);
  183. for (var i = 0; i < Count; i++)
  184. {
  185. ref var m = ref _messages[i]!;
  186. Debug.Assert(m != null);
  187. totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
  188. // Pin byte arrays in place
  189. var r = m.RoleBytes.Pin();
  190. group.Add(r);
  191. var c = m.ContentBytes.Pin();
  192. group.Add(c);
  193. _nativeChatMessages[i] = new LLamaChatMessage
  194. {
  195. role = (byte*)r.Pointer,
  196. content = (byte*)c.Pointer
  197. };
  198. }
  199. // Get an array that's twice as large as the amount of input, hopefully that's large enough!
  200. var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
  201. try
  202. {
  203. // Run templater and discover true length
  204. var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
  205. // If length was too big for output buffer run it again
  206. if (outputLength > output.Length)
  207. {
  208. // Array was too small, rent another one that's exactly the size needed
  209. ArrayPool<byte>.Shared.Return(output, true);
  210. output = ArrayPool<byte>.Shared.Rent(outputLength);
  211. // Run again, but this time with an output that is definitely large enough
  212. ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
  213. }
  214. // Grow result buffer if necessary
  215. if (_result.Length < outputLength)
  216. Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength));
  217. // Copy to result buffer
  218. output.AsSpan(0, outputLength).CopyTo(_result);
  219. _resultLength = outputLength;
  220. }
  221. finally
  222. {
  223. ArrayPool<byte>.Shared.Return(output, true);
  224. }
  225. }
  226. }
  227. // Now that the template has been applied and is in the result buffer, copy it to the dest
  228. _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
  229. return _resultLength;
  230. unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
  231. {
  232. fixed (byte* customTemplatePtr = _customTemplate)
  233. fixed (byte* outputPtr = output)
  234. fixed (LLamaChatMessage* messagesPtr = messages)
  235. {
  236. return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length);
  237. }
  238. }
  239. }
  240. /// <summary>
  241. /// A message that has been added to a template
  242. /// </summary>
  243. public sealed class TextMessage
  244. {
  245. /// <summary>
  246. /// The "role" string for this message
  247. /// </summary>
  248. public string Role { get; }
  249. /// <summary>
  250. /// The text content of this message
  251. /// </summary>
  252. public string Content { get; }
  253. internal ReadOnlyMemory<byte> RoleBytes { get; }
  254. internal ReadOnlyMemory<byte> ContentBytes { get; }
  255. internal TextMessage(string role, string content, IDictionary<string, ReadOnlyMemory<byte>> roleCache)
  256. {
  257. Role = role;
  258. Content = content;
  259. // Get bytes for role from cache
  260. if (!roleCache.TryGetValue(role, out var roleBytes))
  261. {
  262. // Convert role. Add one to length so there is a null byte at the end.
  263. var rArr = new byte[Encoding.GetByteCount(role) + 1];
  264. var encodedRoleLength = Encoding.GetBytes(role.AsSpan(), rArr);
  265. Debug.Assert(rArr.Length == encodedRoleLength + 1);
  266. // Add to cache for future use.
  267. // To ensure the cache cannot grow infinitely add a hard limit to size.
  268. if (roleCache.Count < 128)
  269. {
  270. roleCache.Add(role, rArr);
  271. roleBytes = rArr;
  272. }
  273. }
  274. RoleBytes = roleBytes;
  275. // Convert content. Add one to length so there is a null byte at the end.
  276. var contentArray = new byte[Encoding.GetByteCount(content) + 1];
  277. var encodedContentLength = Encoding.GetBytes(content.AsSpan(), contentArray);
  278. Debug.Assert(contentArray.Length == encodedContentLength + 1);
  279. ContentBytes = contentArray;
  280. }
  281. /// <summary>
  282. /// Deconstruct this message into role and content
  283. /// </summary>
  284. /// <param name="role"></param>
  285. /// <param name="content"></param>
  286. public void Deconstruct(out string role, out string content)
  287. {
  288. role = Role;
  289. content = Content;
  290. }
  291. }
  292. }