LLamaTemplate.cs

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Native;

namespace LLama;

/// <summary>
/// Converts a sequence of messages into text according to a model template
/// </summary>
public sealed class LLamaTemplate
{
    #region private state
    /// <summary>
    /// The model this template is for. May be null if a custom template was supplied to the constructor.
    /// </summary>
    private readonly SafeLlamaModelHandle? _model;

    /// <summary>
    /// Custom template. May be null if a model was supplied to the constructor.
    /// </summary>
    private readonly byte[]? _customTemplate;

    /// <summary>
    /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
    /// </summary>
    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();

    /// <summary>
    /// Array of messages. The <see cref="Count"/> property indicates how many messages there are
    /// </summary>
    private Message[] _messages = new Message[4];

    /// <summary>
    /// Backing field for <see cref="AddAssistant"/>
    /// </summary>
    private bool _addAssistant;

    /// <summary>
    /// Temporary array of messages in the format llama.cpp needs, used when applying the template
    /// </summary>
    private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4];

    /// <summary>
    /// Indicates how many bytes are in the <see cref="_result"/> array
    /// </summary>
    private int _resultLength;

    /// <summary>
    /// Result bytes of the last call to <see cref="Apply"/>
    /// </summary>
    private byte[] _result = Array.Empty<byte>();

    /// <summary>
    /// Indicates if this template has been modified and needs regenerating
    /// </summary>
    private bool _dirty = true;
    #endregion

    #region properties
    /// <summary>
    /// Number of messages added to this template
    /// </summary>
    public int Count { get; private set; }

    /// <summary>
    /// Get the message at the given index
    /// </summary>
    /// <param name="index"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
    public (string role, string content) this[int index]
    {
        get
        {
            if (index < 0)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0");
            if (index >= Count)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");

            return (_messages[index].Role, _messages[index].Content);
        }
    }

    /// <summary>
    /// Whether to end the prompt with the token(s) that indicate the start of an assistant message.
    /// </summary>
    public bool AddAssistant
    {
        get => _addAssistant;
        set
        {
            if (value != _addAssistant)
            {
                _dirty = true;
                _addAssistant = value;
            }
        }
    }
    #endregion

    #region construction
    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="model"></param>
    public LLamaTemplate(SafeLlamaModelHandle model)
    {
        _model = model;
    }

    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="weights"></param>
    public LLamaTemplate(LLamaWeights weights)
        : this(weights.NativeHandle)
    {
    }

    /// <summary>
    /// Construct a new template, using a custom template.
    /// </summary>
    /// <remarks>Only a pre-defined list of templates is supported. See: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template</remarks>
    /// <param name="customTemplate"></param>
    public LLamaTemplate(string customTemplate)
    {
        // Append a null terminator, since the native API expects a C string
        _customTemplate = Encoding.UTF8.GetBytes(customTemplate + "\0");
    }
    #endregion

    /// <summary>
    /// Add a new message to the end of this template
    /// </summary>
    /// <param name="role"></param>
    /// <param name="content"></param>
    public void Add(string role, string content)
    {
        // Expand messages array if necessary
        if (Count == _messages.Length)
            Array.Resize(ref _messages, _messages.Length * 2);

        // Add message
        _messages[Count] = new Message(role, content, _roleCache);
        Count++;

        // Mark as dirty to ensure template is recalculated
        _dirty = true;
    }

    /// <summary>
    /// Remove a message at the given index
    /// </summary>
    /// <param name="index"></param>
    public void RemoveAt(int index)
    {
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
        if (index >= Count)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count");

        _dirty = true;
        Count--;

        // Copy all items after index down by one
        if (index < Count)
            Array.Copy(_messages, index + 1, _messages, index, Count - index);

        _messages[Count] = default;
    }

    /// <summary>
    /// Apply the template to the messages and write it into the output buffer
    /// </summary>
    /// <param name="dest">Destination to write template bytes into</param>
    /// <returns>The length of the template. If this is longer than dest.Length, this method should be called again with a larger dest buffer</returns>
    public int Apply(Memory<byte> dest)
    {
        // Recalculate template if necessary
        if (_dirty)
        {
            _dirty = false;
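
            // The disposable group below collects every MemoryHandle pinned in the message loop,
            // keeping the role/content byte arrays pinned while the native templater runs and
            // releasing them all together when the group is disposed at the end of this scope.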
            using var group = new GroupDisposable();
            unsafe
            {
                // Convert all the messages
                var totalInputBytes = 0;
                if (_nativeChatMessages.Length < _messages.Length)
                    Array.Resize(ref _nativeChatMessages, _messages.Length);
                for (var i = 0; i < Count; i++)
                {
                    ref var m = ref _messages[i];
                    totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;

                    // Pin byte arrays in place
                    var r = m.RoleBytes.Pin();
                    group.Add(r);
                    var c = m.ContentBytes.Pin();
                    group.Add(c);

                    _nativeChatMessages[i] = new LLamaChatMessage
                    {
                        role = (byte*)r.Pointer,
                        content = (byte*)c.Pointer
                    };
                }

                // Get an array that's twice as large as the amount of input, hopefully that's large enough!
                var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
                try
                {
                    // Run templater and discover true length
                    var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);

                    // If length was too big for output buffer run it again
                    if (outputLength > output.Length)
                    {
                        // Array was too small, rent another one that's exactly the size needed
                        ArrayPool<byte>.Shared.Return(output, true);
                        output = ArrayPool<byte>.Shared.Rent(outputLength);

                        // Run again, but this time with an output that is definitely large enough
                        ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
                    }

                    // Grow result buffer if necessary
                    if (_result.Length < outputLength)
                        Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength));

                    // Copy to result buffer
                    output.AsSpan(0, outputLength).CopyTo(_result);
                    _resultLength = outputLength;
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(output, true);
                }
            }
        }

        // Now that the template has been applied and is in the result buffer, copy it to the dest
        _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
        return _resultLength;

        unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
        {
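            // The native templater returns the total number of bytes required for the formatted
            // template, which may be larger than output.Length; the caller above checks for that
            // and retries with a bigger buffer.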
            fixed (byte* customTemplatePtr = _customTemplate)
            fixed (byte* outputPtr = output)
            fixed (LLamaChatMessage* messagesPtr = messages)
            {
                return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length);
            }
        }
    }

    /// <summary>
    /// A message that has been added to the template; contains the role and content converted into UTF8 bytes.
    /// </summary>
    private readonly record struct Message
    {
        public string Role { get; }
        public string Content { get; }

        public ReadOnlyMemory<byte> RoleBytes { get; }
        public ReadOnlyMemory<byte> ContentBytes { get; }

        public Message(string role, string content, Dictionary<string, ReadOnlyMemory<byte>> roleCache)
        {
            Role = role;
            Content = content;

            // Get bytes for role from cache
            if (!roleCache.TryGetValue(role, out var roleBytes))
            {
                // Convert role. Add one to length so there is a null byte at the end.
                var rArr = new byte[Encoding.UTF8.GetByteCount(role) + 1];
                var encodedRoleLength = Encoding.UTF8.GetBytes(role.AsSpan(), rArr);
                Debug.Assert(rArr.Length == encodedRoleLength + 1);

                // Add to cache for future use.
                // To ensure the cache cannot grow infinitely add a hard limit to size.
                if (roleCache.Count < 128)
                    roleCache.Add(role, rArr);

                // Use the freshly encoded role bytes even if they were not cached
                roleBytes = rArr;
            }
            RoleBytes = roleBytes;

            // Convert content. Add one to length so there is a null byte at the end.
            var contentArray = new byte[Encoding.UTF8.GetByteCount(content) + 1];
            var encodedContentLength = Encoding.UTF8.GetBytes(content.AsSpan(), contentArray);
            Debug.Assert(contentArray.Length == encodedContentLength + 1);
            ContentBytes = contentArray;
        }
    }
}
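
A minimal usage sketch, assuming `weights` is an already loaded LLamaWeights instance; the messages and the 1024-byte starting buffer are illustrative. Build the conversation with Add, set AddAssistant so the prompt ends with the assistant prefix, then call Apply and retry with a larger buffer if the returned length does not fit:

using System.Text;
using LLama;

// Assumption: `weights` has been loaded elsewhere in the application
var template = new LLamaTemplate(weights)
{
    // End the prompt with the token(s) that start an assistant message
    AddAssistant = true
};
template.Add("system", "You are a helpful assistant.");
template.Add("user", "What is the capital of France?");

// Apply with a guessed buffer size, then retry once with the exact size if it was too small
var buffer = new byte[1024];
var length = template.Apply(buffer);
if (length > buffer.Length)
{
    buffer = new byte[length];
    length = template.Apply(buffer);
}

var prompt = Encoding.UTF8.GetString(buffer, 0, length);

Because Apply reports the required length instead of throwing when the destination is too small, callers can use the same rent-and-retry pattern the class itself uses internally with its pooled buffer.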