using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Unittest;

/// <summary>
/// Tests for <see cref="LLamaTemplate"/>: message bookkeeping (Count, indexer, RemoveAt) and
/// rendering through both the template embedded in the test model and a named built-in template.
/// </summary>
public sealed class TemplateTests
    : IDisposable
{
    /// <summary>Alternating assistant/user conversation rendered by the model-template tests.</summary>
    private static readonly (string Role, string Content)[] ChatMlConversation =
    {
        ("assistant", "hello"),
        ("user", "world"),
        ("assistant", "111"),
        ("user", "aaa"),
        ("assistant", "222"),
        ("user", "bbb"),
        ("assistant", "333"),
        ("user", "ccc"),
    };

    /// <summary>
    /// Expected rendering of <see cref="ChatMlConversation"/> by the test model's template
    /// (ChatML-style im_start / im_end markers around every message).
    /// </summary>
    private const string ChatMlConversationExpected =
        "<|im_start|>assistant\nhello<|im_end|>\n" +
        "<|im_start|>user\nworld<|im_end|>\n" +
        "<|im_start|>assistant\n111<|im_end|>\n" +
        "<|im_start|>user\naaa<|im_end|>\n" +
        "<|im_start|>assistant\n222<|im_end|>\n" +
        "<|im_start|>user\nbbb<|im_end|>\n" +
        "<|im_start|>assistant\n333<|im_end|>\n" +
        "<|im_start|>user\nccc<|im_end|>\n";

    /// <summary>Six messages (two consecutive "user" turns at indices 3 and 4) used by the removal tests.</summary>
    private static readonly (string Role, string Content)[] RemovalConversation =
    {
        ("assistant", "1"),
        ("user", "2"),
        ("assistant", "3"),
        ("user", "4a"),
        ("user", "4b"),
        ("assistant", "5"),
    };

    private readonly LLamaWeights _model;

    /// <summary>
    /// Loads the generative test model. Only the weights (and their embedded chat template) are used
    /// here; no inference context is created, which is presumably why ContextSize is minimal.
    /// </summary>
    public TemplateTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    /// <summary>Releases the model weights loaded by the constructor.</summary>
    public void Dispose()
    {
        _model.Dispose();
    }

    /// <summary>
    /// Adds each (role, content) pair to <paramref name="templater"/> in order, asserting after every
    /// call that Count grew by exactly one.
    /// </summary>
    private static void AddMessages(LLamaTemplate templater, params (string Role, string Content)[] messages)
    {
        foreach (var (role, content) in messages)
        {
            var expectedCount = templater.Count + 1;
            templater.Add(role, content);
            Assert.Equal(expectedCount, templater.Count);
        }
    }

    /// <summary>
    /// Renders <paramref name="templater"/> to a UTF-8 decoded string using the two-call pattern.
    /// The first call uses an empty buffer to discover the required length, and the second fills a
    /// buffer of that size. Also asserts that neither call changes Count.
    /// </summary>
    private static string ApplyToString(LLamaTemplate templater)
    {
        var count = templater.Count;

        // Call once with empty array to discover length
        var length = templater.Apply(Array.Empty<byte>());
        var dest = new byte[length];
        Assert.Equal(count, templater.Count);

        // Call again to get contents
        length = templater.Apply(dest);
        Assert.Equal(count, templater.Count);

        return Encoding.UTF8.GetString(dest.AsSpan(0, length));
    }

    [Fact]
    public void BasicTemplate()
    {
        var templater = new LLamaTemplate(_model);
        Assert.Equal(0, templater.Count);

        AddMessages(templater, ChatMlConversation);
        Assert.Equal(8, templater.Count);

        Assert.Equal(ChatMlConversationExpected, ApplyToString(templater));
    }

    [Fact]
    public void CustomTemplate()
    {
        var templater = new LLamaTemplate("gemma");
        Assert.Equal(0, templater.Count);

        AddMessages(templater,
            ("assistant", "hello"),
            ("user", "world"),
            ("assistant", "111"),
            ("user", "aaa"));
        Assert.Equal(4, templater.Count);

        // Gemma renders the "assistant" role as "model" and wraps each turn in
        // start_of_turn / end_of_turn markers.
        const string expected = "<start_of_turn>model\nhello<end_of_turn>\n" +
                                "<start_of_turn>user\nworld<end_of_turn>\n" +
                                "<start_of_turn>model\n111<end_of_turn>\n" +
                                "<start_of_turn>user\naaa<end_of_turn>\n";
        Assert.Equal(expected, ApplyToString(templater));
    }

    [Fact]
    public void BasicTemplateWithAddAssistant()
    {
        var templater = new LLamaTemplate(_model)
        {
            AddAssistant = true,
        };
        Assert.Equal(0, templater.Count);

        AddMessages(templater, ChatMlConversation);
        Assert.Equal(8, templater.Count);

        // AddAssistant appends an open assistant turn after the last message.
        const string expected = ChatMlConversationExpected + "<|im_start|>assistant\n";
        Assert.Equal(expected, ApplyToString(templater));
    }

    [Fact]
    public void GetOutOfRangeThrows()
    {
        var templater = new LLamaTemplate(_model);

        Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]);

        templater.Add("assistant", "1");
        templater.Add("user", "2");

        Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]);
        Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]);
    }

    [Fact]
    public void RemoveMid()
    {
        var templater = new LLamaTemplate(_model);
        AddMessages(templater, RemovalConversation);

        Assert.Equal("user", templater[3].Role);
        Assert.Equal("4a", templater[3].Content);
        Assert.Equal("assistant", templater[5].Role);
        Assert.Equal("5", templater[5].Content);
        Assert.Equal(6, templater.Count);

        templater.RemoveAt(3);

        // Messages after the removed one shift down by one index.
        Assert.Equal(5, templater.Count);
        Assert.Equal("user", templater[3].Role);
        Assert.Equal("4b", templater[3].Content);
        Assert.Equal("assistant", templater[4].Role);
        Assert.Equal("5", templater[4].Content);
    }

    [Fact]
    public void RemoveLast()
    {
        var templater = new LLamaTemplate(_model);
        AddMessages(templater, RemovalConversation);
        Assert.Equal(6, templater.Count);

        templater.RemoveAt(5);

        Assert.Equal(5, templater.Count);
        Assert.Equal("user", templater[4].Role);
        Assert.Equal("4b", templater[4].Content);
    }

    [Fact]
    public void RemoveOutOfRange()
    {
        var templater = new LLamaTemplate(_model);

        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0));

        templater.Add("assistant", "1");
        templater.Add("user", "2");

        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
    }
}