You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TemplateTests.cs 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. using System.Text;
  2. using LLama.Common;
  3. using LLama.Native;
  4. namespace LLama.Unittest;
  /// <summary>
  /// Tests for <see cref="LLamaTemplate"/>.
  /// They cover rendering a conversation with the template taken from the loaded model
  /// (ChatML, per the expected strings) and with a template selected by name ("gemma").
  /// They also cover indexing into and removing messages from a template.
  /// </summary>
  /// <remarks>
  /// NOTE(review): xUnit creates a new instance of this class for every test.
  /// The model is therefore reloaded for each [Fact], including <see cref="CustomTemplate"/>,
  /// which never uses it. Consider an <c>IClassFixture</c> if load time becomes a problem.
  /// </remarks>
  public sealed class TemplateTests : IDisposable
  {
      /// <summary>
      /// Eight alternating assistant/user turns shared by the model-template rendering tests.
      /// </summary>
      private static readonly (string Role, string Content)[] EightTurnConversation =
      {
          ("assistant", "hello"),
          ("user", "world"),
          ("assistant", "111"),
          ("user", "aaa"),
          ("assistant", "222"),
          ("user", "bbb"),
          ("assistant", "333"),
          ("user", "ccc"),
      };

      /// <summary>
      /// Expected ChatML rendering of <see cref="EightTurnConversation"/>, one turn per line.
      /// </summary>
      private const string EightTurnChatMl =
          "<|im_start|>assistant\nhello<|im_end|>\n" +
          "<|im_start|>user\nworld<|im_end|>\n" +
          "<|im_start|>assistant\n111<|im_end|>\n" +
          "<|im_start|>user\naaa<|im_end|>\n" +
          "<|im_start|>assistant\n222<|im_end|>\n" +
          "<|im_start|>user\nbbb<|im_end|>\n" +
          "<|im_start|>assistant\n333<|im_end|>\n" +
          "<|im_start|>user\nccc<|im_end|>\n";

      /// <summary>
      /// Six messages used by the removal tests. Two consecutive "user" turns ("4a", "4b")
      /// make it observable which element a removal dropped.
      /// </summary>
      private static readonly (string Role, string Content)[] RemovalConversation =
      {
          ("assistant", "1"),
          ("user", "2"),
          ("assistant", "3"),
          ("user", "4a"),
          ("user", "4b"),
          ("assistant", "5"),
      };

      private readonly LLamaWeights _model;

      public TemplateTests()
      {
          var @params = new ModelParams(Constants.GenerativeModelPath)
          {
              ContextSize = 1,
              GpuLayerCount = Constants.CIGpuLayerCount
          };
          _model = LLamaWeights.LoadFromFile(@params);
      }

      public void Dispose()
      {
          _model.Dispose();
      }

      [Fact]
      public void BasicTemplate()
      {
          var templater = new LLamaTemplate(_model);
          Assert.Equal(0, templater.Count);
          AddMessages(templater, EightTurnConversation);

          var templateResult = ApplyToString(templater);

          Assert.Equal(EightTurnChatMl, templateResult);
      }

      [Fact]
      public void CustomTemplate()
      {
          var templater = new LLamaTemplate("gemma");
          Assert.Equal(0, templater.Count);
          AddMessages(templater,
              ("assistant", "hello"),
              ("user", "world"),
              ("assistant", "111"),
              ("user", "aaa"));

          var templateResult = ApplyToString(templater);

          // The gemma template renders the "assistant" role as "model".
          const string expected =
              "<start_of_turn>model\nhello<end_of_turn>\n" +
              "<start_of_turn>user\nworld<end_of_turn>\n" +
              "<start_of_turn>model\n111<end_of_turn>\n" +
              "<start_of_turn>user\naaa<end_of_turn>\n";
          Assert.Equal(expected, templateResult);
      }

      [Fact]
      public void BasicTemplateWithAddAssistant()
      {
          var templater = new LLamaTemplate(_model)
          {
              AddAssistant = true,
          };
          Assert.Equal(0, templater.Count);
          AddMessages(templater, EightTurnConversation);

          var templateResult = ApplyToString(templater);

          // With AddAssistant the output ends with an opened (empty) assistant turn.
          Assert.Equal(EightTurnChatMl + "<|im_start|>assistant\n", templateResult);
      }

      [Fact]
      public void GetOutOfRangeThrows()
      {
          var templater = new LLamaTemplate(_model);
          Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]);

          templater.Add("assistant", "1");
          templater.Add("user", "2");

          Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]);
          Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]);
      }

      [Fact]
      public void RemoveMid()
      {
          var templater = new LLamaTemplate(_model);
          AddMessages(templater, RemovalConversation);
          Assert.Equal(("user", "4a"), templater[3]);
          Assert.Equal(("assistant", "5"), templater[5]);
          Assert.Equal(6, templater.Count);

          templater.RemoveAt(3);

          // Later messages shift down by one index.
          Assert.Equal(5, templater.Count);
          Assert.Equal(("user", "4b"), templater[3]);
          Assert.Equal(("assistant", "5"), templater[4]);
      }

      [Fact]
      public void RemoveLast()
      {
          var templater = new LLamaTemplate(_model);
          AddMessages(templater, RemovalConversation);
          Assert.Equal(6, templater.Count);

          templater.RemoveAt(5);

          Assert.Equal(5, templater.Count);
          Assert.Equal(("user", "4b"), templater[4]);
      }

      [Fact]
      public void RemoveOutOfRange()
      {
          var templater = new LLamaTemplate(_model);
          Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0));

          templater.Add("assistant", "1");
          templater.Add("user", "2");

          Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
          Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
      }

      /// <summary>
      /// Adds each message to <paramref name="templater"/> in order.
      /// Asserts that every call to <c>Add</c> grows <c>Count</c> by exactly one.
      /// </summary>
      /// <param name="templater">Template to append to.</param>
      /// <param name="messages">Role/content pairs to add, in order.</param>
      private static void AddMessages(LLamaTemplate templater, params (string Role, string Content)[] messages)
      {
          foreach (var (role, content) in messages)
          {
              var expectedCount = templater.Count + 1;
              templater.Add(role, content);
              Assert.Equal(expectedCount, templater.Count);
          }
      }

      /// <summary>
      /// Renders <paramref name="templater"/> to a string in two passes.
      /// The first call passes an empty buffer to discover the required length. The second call
      /// fills a buffer of that size. Also asserts that neither call changes the message count.
      /// </summary>
      /// <param name="templater">Template to render.</param>
      /// <returns>The rendered template, decoded as UTF-8.</returns>
      private static string ApplyToString(LLamaTemplate templater)
      {
          var count = templater.Count;

          // Call once with an empty array to discover the length.
          var length = templater.Apply(Array.Empty<byte>());
          var dest = new byte[length];
          Assert.Equal(count, templater.Count);

          // Call again to get the contents.
          length = templater.Apply(dest);
          Assert.Equal(count, templater.Count);

          return Encoding.UTF8.GetString(dest.AsSpan(0, length));
      }
  }