|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- using System.Text;
- using LLama.Common;
- using LLama.Native;
-
- namespace LLama.Unittest;
-
/// <summary>
/// Tests for <see cref="LLamaTemplate"/>: rendering a conversation with the template taken from
/// the test model, rendering with a named template ("gemma"), the <c>AddAssistant</c> option,
/// and indexing / removing messages.
/// </summary>
public sealed class TemplateTests
    : IDisposable
{
    /// <summary>
    /// Alternating assistant/user conversation shared by the ChatML rendering tests.
    /// </summary>
    private static readonly (string Role, string Content)[] Conversation =
    {
        ("assistant", "hello"),
        ("user", "world"),
        ("assistant", "111"),
        ("user", "aaa"),
        ("assistant", "222"),
        ("user", "bbb"),
        ("assistant", "333"),
        ("user", "ccc"),
    };

    /// <summary>
    /// Expected rendering of <see cref="Conversation"/> with the test model's template
    /// (ChatML-style turns), without a trailing assistant prompt.
    /// </summary>
    private const string ExpectedChatMlConversation =
        "<|im_start|>assistant\nhello<|im_end|>\n" +
        "<|im_start|>user\nworld<|im_end|>\n" +
        "<|im_start|>assistant\n111<|im_end|>\n" +
        "<|im_start|>user\naaa<|im_end|>\n" +
        "<|im_start|>assistant\n222<|im_end|>\n" +
        "<|im_start|>user\nbbb<|im_end|>\n" +
        "<|im_start|>assistant\n333<|im_end|>\n" +
        "<|im_start|>user\nccc<|im_end|>\n";

    private readonly LLamaWeights _model;

    /// <summary>
    /// Loads the generative test model. xUnit creates a fresh instance of this class per test,
    /// so the weights are loaded for each test and released again in <see cref="Dispose"/>.
    /// </summary>
    public TemplateTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            // Only the weights are needed: these tests never create a context or run inference.
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    /// <summary>
    /// Releases the model weights loaded by the constructor.
    /// </summary>
    public void Dispose()
    {
        _model.Dispose();
    }

    /// <summary>
    /// Adds each (role, content) pair to <paramref name="templater"/> in order, asserting after
    /// every add that <see cref="LLamaTemplate.Count"/> grew by exactly one.
    /// </summary>
    private static void AddMessages(LLamaTemplate templater, params (string Role, string Content)[] messages)
    {
        var initialCount = templater.Count;
        for (var i = 0; i < messages.Length; i++)
        {
            templater.Add(messages[i].Role, messages[i].Content);
            Assert.Equal(initialCount + i + 1, templater.Count);
        }
    }

    /// <summary>
    /// Renders <paramref name="templater"/> with the two-call protocol: first with an empty
    /// buffer to discover the required length, then with a buffer of exactly that size.
    /// Asserts that neither call changes the message count and that the second call reports
    /// the same length as the first (a shorter value would mean truncated output).
    /// </summary>
    /// <returns>The rendered template decoded as UTF-8.</returns>
    private static string ApplyToString(LLamaTemplate templater)
    {
        var count = templater.Count;

        // Call once with an empty buffer to discover the required length.
        var required = templater.Apply(Array.Empty<byte>());
        Assert.Equal(count, templater.Count);

        // Call again with a correctly-sized buffer to get the contents.
        var dest = new byte[required];
        var written = templater.Apply(dest);
        Assert.Equal(count, templater.Count);
        Assert.Equal(required, written);

        return Encoding.UTF8.GetString(dest.AsSpan(0, written));
    }

    [Fact]
    public void BasicTemplate()
    {
        var templater = new LLamaTemplate(_model);

        Assert.Equal(0, templater.Count);
        AddMessages(templater, Conversation);

        var templateResult = ApplyToString(templater);

        Assert.Equal(ExpectedChatMlConversation, templateResult);
    }

    [Fact]
    public void CustomTemplate()
    {
        var templater = new LLamaTemplate("gemma");

        Assert.Equal(0, templater.Count);
        AddMessages(templater,
            ("assistant", "hello"),
            ("user", "world"),
            ("assistant", "111"),
            ("user", "aaa"));

        var templateResult = ApplyToString(templater);

        // The gemma template renders the "assistant" role as "model".
        const string expected =
            "<start_of_turn>model\nhello<end_of_turn>\n" +
            "<start_of_turn>user\nworld<end_of_turn>\n" +
            "<start_of_turn>model\n111<end_of_turn>\n" +
            "<start_of_turn>user\naaa<end_of_turn>\n";

        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void BasicTemplateWithAddAssistant()
    {
        var templater = new LLamaTemplate(_model)
        {
            AddAssistant = true,
        };

        Assert.Equal(0, templater.Count);
        AddMessages(templater, Conversation);

        var templateResult = ApplyToString(templater);

        // Same as BasicTemplate, plus an opened assistant turn for the model to complete.
        const string expected = ExpectedChatMlConversation + "<|im_start|>assistant\n";

        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void GetOutOfRangeThrows()
    {
        var templater = new LLamaTemplate(_model);

        // Indexing an empty template must throw.
        Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]);

        AddMessages(templater, ("assistant", "1"), ("user", "2"));

        // Negative index and one-past-the-end must both throw.
        Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]);
        Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]);
    }

    [Fact]
    public void RemoveMid()
    {
        var templater = new LLamaTemplate(_model);

        AddMessages(templater,
            ("assistant", "1"),
            ("user", "2"),
            ("assistant", "3"),
            ("user", "4a"),
            ("user", "4b"),
            ("assistant", "5"));

        Assert.Equal(("user", "4a"), templater[3]);
        Assert.Equal(("assistant", "5"), templater[5]);

        Assert.Equal(6, templater.Count);
        templater.RemoveAt(3);
        Assert.Equal(5, templater.Count);

        // Later messages shift down by one to fill the gap.
        Assert.Equal(("user", "4b"), templater[3]);
        Assert.Equal(("assistant", "5"), templater[4]);
    }

    [Fact]
    public void RemoveLast()
    {
        var templater = new LLamaTemplate(_model);

        AddMessages(templater,
            ("assistant", "1"),
            ("user", "2"),
            ("assistant", "3"),
            ("user", "4a"),
            ("user", "4b"),
            ("assistant", "5"));

        Assert.Equal(6, templater.Count);
        templater.RemoveAt(5);
        Assert.Equal(5, templater.Count);

        Assert.Equal(("user", "4b"), templater[4]);
    }

    [Fact]
    public void RemoveOutOfRange()
    {
        var templater = new LLamaTemplate(_model);

        // Removing from an empty template must throw.
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0));

        AddMessages(templater, ("assistant", "1"), ("user", "2"));

        // Negative index and one-past-the-end must both throw.
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
    }
}
|