- Fixed `llama_chat_apply_template` method (wrong entrypoint, couldn't handle null model)pull/715/head
| @@ -0,0 +1,245 @@ | |||
| using System.Text; | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| namespace LLama.Unittest; | |||
| public sealed class TemplateTests | |||
| : IDisposable | |||
| { | |||
| private readonly LLamaWeights _model; | |||
| public TemplateTests() | |||
| { | |||
| var @params = new ModelParams(Constants.GenerativeModelPath) | |||
| { | |||
| ContextSize = 1, | |||
| GpuLayerCount = Constants.CIGpuLayerCount | |||
| }; | |||
| _model = LLamaWeights.LoadFromFile(@params); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| _model.Dispose(); | |||
| } | |||
| [Fact] | |||
| public void BasicTemplate() | |||
| { | |||
| var templater = new LLamaTemplate(_model); | |||
| Assert.Equal(0, templater.Count); | |||
| templater.Add("assistant", "hello"); | |||
| Assert.Equal(1, templater.Count); | |||
| templater.Add("user", "world"); | |||
| Assert.Equal(2, templater.Count); | |||
| templater.Add("assistant", "111"); | |||
| Assert.Equal(3, templater.Count); | |||
| templater.Add("user", "aaa"); | |||
| Assert.Equal(4, templater.Count); | |||
| templater.Add("assistant", "222"); | |||
| Assert.Equal(5, templater.Count); | |||
| templater.Add("user", "bbb"); | |||
| Assert.Equal(6, templater.Count); | |||
| templater.Add("assistant", "333"); | |||
| Assert.Equal(7, templater.Count); | |||
| templater.Add("user", "ccc"); | |||
| Assert.Equal(8, templater.Count); | |||
| // Call once with empty array to discover length | |||
| var length = templater.Apply(Array.Empty<byte>()); | |||
| var dest = new byte[length]; | |||
| Assert.Equal(8, templater.Count); | |||
| // Call again to get contents | |||
| length = templater.Apply(dest); | |||
| Assert.Equal(8, templater.Count); | |||
| var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length)); | |||
| const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" + | |||
| "<|im_start|>user\nworld<|im_end|>\n" + | |||
| "<|im_start|>assistant\n" + | |||
| "111<|im_end|>" + | |||
| "\n<|im_start|>user\n" + | |||
| "aaa<|im_end|>\n" + | |||
| "<|im_start|>assistant\n" + | |||
| "222<|im_end|>\n" + | |||
| "<|im_start|>user\n" + | |||
| "bbb<|im_end|>\n" + | |||
| "<|im_start|>assistant\n" + | |||
| "333<|im_end|>\n" + | |||
| "<|im_start|>user\n" + | |||
| "ccc<|im_end|>\n"; | |||
| Assert.Equal(expected, templateResult); | |||
| } | |||
| [Fact] | |||
| public void CustomTemplate() | |||
| { | |||
| var templater = new LLamaTemplate("gemma"); | |||
| Assert.Equal(0, templater.Count); | |||
| templater.Add("assistant", "hello"); | |||
| Assert.Equal(1, templater.Count); | |||
| templater.Add("user", "world"); | |||
| Assert.Equal(2, templater.Count); | |||
| templater.Add("assistant", "111"); | |||
| Assert.Equal(3, templater.Count); | |||
| templater.Add("user", "aaa"); | |||
| Assert.Equal(4, templater.Count); | |||
| // Call once with empty array to discover length | |||
| var length = templater.Apply(Array.Empty<byte>()); | |||
| var dest = new byte[length]; | |||
| Assert.Equal(4, templater.Count); | |||
| // Call again to get contents | |||
| length = templater.Apply(dest); | |||
| Assert.Equal(4, templater.Count); | |||
| var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length)); | |||
| const string expected = "<start_of_turn>model\n" + | |||
| "hello<end_of_turn>\n" + | |||
| "<start_of_turn>user\n" + | |||
| "world<end_of_turn>\n" + | |||
| "<start_of_turn>model\n" + | |||
| "111<end_of_turn>\n" + | |||
| "<start_of_turn>user\n" + | |||
| "aaa<end_of_turn>\n"; | |||
| Assert.Equal(expected, templateResult); | |||
| } | |||
| [Fact] | |||
| public void BasicTemplateWithAddAssistant() | |||
| { | |||
| var templater = new LLamaTemplate(_model) | |||
| { | |||
| AddAssistant = true, | |||
| }; | |||
| Assert.Equal(0, templater.Count); | |||
| templater.Add("assistant", "hello"); | |||
| Assert.Equal(1, templater.Count); | |||
| templater.Add("user", "world"); | |||
| Assert.Equal(2, templater.Count); | |||
| templater.Add("assistant", "111"); | |||
| Assert.Equal(3, templater.Count); | |||
| templater.Add("user", "aaa"); | |||
| Assert.Equal(4, templater.Count); | |||
| templater.Add("assistant", "222"); | |||
| Assert.Equal(5, templater.Count); | |||
| templater.Add("user", "bbb"); | |||
| Assert.Equal(6, templater.Count); | |||
| templater.Add("assistant", "333"); | |||
| Assert.Equal(7, templater.Count); | |||
| templater.Add("user", "ccc"); | |||
| Assert.Equal(8, templater.Count); | |||
| // Call once with empty array to discover length | |||
| var length = templater.Apply(Array.Empty<byte>()); | |||
| var dest = new byte[length]; | |||
| Assert.Equal(8, templater.Count); | |||
| // Call again to get contents | |||
| length = templater.Apply(dest); | |||
| Assert.Equal(8, templater.Count); | |||
| var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length)); | |||
| const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" + | |||
| "<|im_start|>user\nworld<|im_end|>\n" + | |||
| "<|im_start|>assistant\n" + | |||
| "111<|im_end|>" + | |||
| "\n<|im_start|>user\n" + | |||
| "aaa<|im_end|>\n" + | |||
| "<|im_start|>assistant\n" + | |||
| "222<|im_end|>\n" + | |||
| "<|im_start|>user\n" + | |||
| "bbb<|im_end|>\n" + | |||
| "<|im_start|>assistant\n" + | |||
| "333<|im_end|>\n" + | |||
| "<|im_start|>user\n" + | |||
| "ccc<|im_end|>\n" + | |||
| "<|im_start|>assistant\n"; | |||
| Assert.Equal(expected, templateResult); | |||
| } | |||
| [Fact] | |||
| public void GetOutOfRangeThrows() | |||
| { | |||
| var templater = new LLamaTemplate(_model); | |||
| Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]); | |||
| templater.Add("assistant", "1"); | |||
| templater.Add("user", "2"); | |||
| Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]); | |||
| Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]); | |||
| } | |||
| [Fact] | |||
| public void RemoveMid() | |||
| { | |||
| var templater = new LLamaTemplate(_model); | |||
| templater.Add("assistant", "1"); | |||
| templater.Add("user", "2"); | |||
| templater.Add("assistant", "3"); | |||
| templater.Add("user", "4a"); | |||
| templater.Add("user", "4b"); | |||
| templater.Add("assistant", "5"); | |||
| Assert.Equal(("user", "4a"), templater[3]); | |||
| Assert.Equal(("assistant", "5"), templater[5]); | |||
| Assert.Equal(6, templater.Count); | |||
| templater.RemoveAt(3); | |||
| Assert.Equal(5, templater.Count); | |||
| Assert.Equal(("user", "4b"), templater[3]); | |||
| Assert.Equal(("assistant", "5"), templater[4]); | |||
| } | |||
| [Fact] | |||
| public void RemoveLast() | |||
| { | |||
| var templater = new LLamaTemplate(_model); | |||
| templater.Add("assistant", "1"); | |||
| templater.Add("user", "2"); | |||
| templater.Add("assistant", "3"); | |||
| templater.Add("user", "4a"); | |||
| templater.Add("user", "4b"); | |||
| templater.Add("assistant", "5"); | |||
| Assert.Equal(6, templater.Count); | |||
| templater.RemoveAt(5); | |||
| Assert.Equal(5, templater.Count); | |||
| Assert.Equal(("user", "4b"), templater[4]); | |||
| } | |||
| [Fact] | |||
| public void RemoveOutOfRange() | |||
| { | |||
| var templater = new LLamaTemplate(_model); | |||
| Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0)); | |||
| templater.Add("assistant", "1"); | |||
| templater.Add("user", "2"); | |||
| Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1)); | |||
| Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2)); | |||
| } | |||
| } | |||
| @@ -0,0 +1,301 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using LLama.Native; | |||
| namespace LLama; | |||
| /// <summary> | |||
| /// Converts a sequence of messages into text according to a model template | |||
| /// </summary> | |||
| public sealed class LLamaTemplate | |||
| { | |||
| #region private state | |||
| /// <summary> | |||
| /// The model this template is for. May be null if a custom template was supplied to the constructor. | |||
| /// </summary> | |||
| private readonly SafeLlamaModelHandle? _model; | |||
| /// <summary> | |||
| /// Custom template. May be null if a model was supplied to the constructor. | |||
| /// </summary> | |||
| private readonly byte[]? _customTemplate; | |||
| /// <summary> | |||
| /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times. | |||
| /// </summary> | |||
| private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new(); | |||
| /// <summary> | |||
| /// Array of messages. The <see cref="Count"/> property indicates how many messages there are | |||
| /// </summary> | |||
| private Message[] _messages = new Message[4]; | |||
| /// <summary> | |||
| /// Backing field for <see cref="AddAssistant"/> | |||
| /// </summary> | |||
| private bool _addAssistant; | |||
| /// <summary> | |||
| /// Temporary array of messages in the format llama.cpp needs, used when applying the template | |||
| /// </summary> | |||
| private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4]; | |||
| /// <summary> | |||
| /// Indicates how many bytes are in <see cref="_result"/> array | |||
| /// </summary> | |||
| private int _resultLength; | |||
| /// <summary> | |||
| /// Result bytes of last call to <see cref="Apply"/> | |||
| /// </summary> | |||
| private byte[] _result = Array.Empty<byte>(); | |||
| /// <summary> | |||
| /// Indicates if this template has been modified and needs regenerating | |||
| /// </summary> | |||
| private bool _dirty = true; | |||
| #endregion | |||
| #region properties | |||
| /// <summary> | |||
| /// Number of messages added to this template | |||
| /// </summary> | |||
| public int Count { get; private set; } | |||
| /// <summary> | |||
| /// Get the message at the given index | |||
| /// </summary> | |||
| /// <param name="index"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception> | |||
| public (string role, string content) this[int index] | |||
| { | |||
| get | |||
| { | |||
| if (index < 0) | |||
| throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0"); | |||
| if (index >= Count) | |||
| throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count"); | |||
| return (_messages[index].Role, _messages[index].Content); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Whether to end the prompt with the token(s) that indicate the start of an assistant message. | |||
| /// </summary> | |||
| public bool AddAssistant | |||
| { | |||
| get => _addAssistant; | |||
| set | |||
| { | |||
| if (value != _addAssistant) | |||
| { | |||
| _dirty = true; | |||
| _addAssistant = value; | |||
| } | |||
| } | |||
| } | |||
| #endregion | |||
| #region construction | |||
| /// <summary> | |||
| /// Construct a new template, using the default model template | |||
| /// </summary> | |||
| /// <param name="model"></param> | |||
| public LLamaTemplate(SafeLlamaModelHandle model) | |||
| { | |||
| _model = model; | |||
| } | |||
| /// <summary> | |||
| /// Construct a new template, using the default model template | |||
| /// </summary> | |||
| /// <param name="weights"></param> | |||
| public LLamaTemplate(LLamaWeights weights) | |||
| : this(weights.NativeHandle) | |||
| { | |||
| } | |||
| /// <summary> | |||
| /// Construct a new template, using a custom template. | |||
| /// </summary> | |||
| /// <remarks>Only support a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template</remarks> | |||
| /// <param name="customTemplate"></param> | |||
| public LLamaTemplate(string customTemplate) | |||
| { | |||
| _customTemplate = Encoding.UTF8.GetBytes(customTemplate + "\0"); | |||
| } | |||
| #endregion | |||
| /// <summary> | |||
| /// Add a new message to the end of this template | |||
| /// </summary> | |||
| /// <param name="role"></param> | |||
| /// <param name="content"></param> | |||
| public void Add(string role, string content) | |||
| { | |||
| // Expand messages array if necessary | |||
| if (Count == _messages.Length) | |||
| Array.Resize(ref _messages, _messages.Length * 2); | |||
| // Add message | |||
| _messages[Count] = new Message(role, content, _roleCache); | |||
| Count++; | |||
| // Mark as dirty to ensure template is recalculated | |||
| _dirty = true; | |||
| } | |||
| /// <summary> | |||
| /// Remove a message at the given index | |||
| /// </summary> | |||
| /// <param name="index"></param> | |||
| public void RemoveAt(int index) | |||
| { | |||
| if (index < 0) | |||
| throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero"); | |||
| if (index >= Count) | |||
| throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count"); | |||
| _dirty = true; | |||
| Count--; | |||
| // Copy all items after index down by one | |||
| if (index < Count) | |||
| Array.Copy(_messages, index + 1, _messages, index, Count - index); | |||
| _messages[Count] = default; | |||
| } | |||
| /// <summary> | |||
| /// Apply the template to the messages and write it into the output buffer | |||
| /// </summary> | |||
| /// <param name="dest">Destination to write template bytes into</param> | |||
| /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns> | |||
| public int Apply(Memory<byte> dest) | |||
| { | |||
| // Recalculate template if necessary | |||
| if (_dirty) | |||
| { | |||
| _dirty = false; | |||
| using var group = new GroupDisposable(); | |||
| unsafe | |||
| { | |||
| // Convert all the messages | |||
| var totalInputBytes = 0; | |||
| if (_nativeChatMessages.Length < _messages.Length) | |||
| Array.Resize(ref _nativeChatMessages, _messages.Length); | |||
| for (var i = 0; i < Count; i++) | |||
| { | |||
| ref var m = ref _messages[i]; | |||
| totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length; | |||
| // Pin byte arrays in place | |||
| var r = m.RoleBytes.Pin(); | |||
| group.Add(r); | |||
| var c = m.ContentBytes.Pin(); | |||
| group.Add(c); | |||
| _nativeChatMessages[i] = new LLamaChatMessage | |||
| { | |||
| role = (byte*)r.Pointer, | |||
| content = (byte*)c.Pointer | |||
| }; | |||
| } | |||
| // Get an array that's twice as large as the amount of input, hopefully that's large enough! | |||
| var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2)); | |||
| try | |||
| { | |||
| // Run templater and discover true length | |||
| var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output); | |||
| // If length was too big for output buffer run it again | |||
| if (outputLength > output.Length) | |||
| { | |||
| // Array was too small, rent another one that's exactly the size needed | |||
| ArrayPool<byte>.Shared.Return(output, true); | |||
| output = ArrayPool<byte>.Shared.Rent(outputLength); | |||
| // Run again, but this time with an output that is definitely large enough | |||
| ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output); | |||
| } | |||
| // Grow result buffer if necessary | |||
| if (_result.Length < outputLength) | |||
| Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength)); | |||
| // Copy to result buffer | |||
| output.AsSpan(0, outputLength).CopyTo(_result); | |||
| _resultLength = outputLength; | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<byte>.Shared.Return(output, true); | |||
| } | |||
| } | |||
| } | |||
| // Now that the template has been applied and is in the result buffer, copy it to the dest | |||
| _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span); | |||
| return _resultLength; | |||
| unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output) | |||
| { | |||
| fixed (byte* customTemplatePtr = _customTemplate) | |||
| fixed (byte* outputPtr = output) | |||
| fixed (LLamaChatMessage* messagesPtr = messages) | |||
| { | |||
| return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length); | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// A message that has been added to the template, contains role and content converted into UTF8 bytes. | |||
| /// </summary> | |||
| private readonly record struct Message | |||
| { | |||
| public string Role { get; } | |||
| public string Content { get; } | |||
| public ReadOnlyMemory<byte> RoleBytes { get; } | |||
| public ReadOnlyMemory<byte> ContentBytes { get; } | |||
| public Message(string role, string content, Dictionary<string, ReadOnlyMemory<byte>> roleCache) | |||
| { | |||
| Role = role; | |||
| Content = content; | |||
| // Get bytes for role from cache | |||
| if (!roleCache.TryGetValue(role, out var roleBytes)) | |||
| { | |||
| // Convert role. Add one to length so there is a null byte at the end. | |||
| var rArr = new byte[Encoding.UTF8.GetByteCount(role) + 1]; | |||
| var encodedRoleLength = Encoding.UTF8.GetBytes(role.AsSpan(), rArr); | |||
| Debug.Assert(rArr.Length == encodedRoleLength + 1); | |||
| // Add to cache for future use. | |||
| // To ensure the cache cannot grow infinitely add a hard limit to size. | |||
| if (roleCache.Count < 128) | |||
| { | |||
| roleCache.Add(role, rArr); | |||
| roleBytes = rArr; | |||
| } | |||
| } | |||
| RoleBytes = roleBytes; | |||
| // Convert content. Add one to length so there is a null byte at the end. | |||
| var contentArray = new byte[Encoding.UTF8.GetByteCount(content) + 1]; | |||
| var encodedContentLength = Encoding.UTF8.GetBytes(content.AsSpan(), contentArray); | |||
| Debug.Assert(contentArray.Length == encodedContentLength + 1); | |||
| ContentBytes = contentArray; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,11 +1,21 @@ | |||
| namespace LLama.Native; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native; | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <remarks>llama_chat_message</remarks> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public unsafe struct LLamaChatMessage | |||
| { | |||
| /// <summary> | |||
| /// Pointer to the null terminated bytes that make up the role string | |||
| /// </summary> | |||
| public byte* role; | |||
| /// <summary> | |||
| /// Pointer to the null terminated bytes that make up the content string | |||
| /// </summary> | |||
| public byte* content; | |||
| } | |||
| @@ -1,4 +1,4 @@ | |||
| using System; | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| #pragma warning disable IDE1006 // Naming Styles | |||
| @@ -174,8 +174,13 @@ namespace LLama.Native | |||
| /// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param> | |||
| /// <param name="length">The size of the allocated buffer</param> | |||
| /// <returns>The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.</returns> | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")] | |||
| public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, char* buf, int length); | |||
| public static unsafe int llama_chat_apply_template(SafeLlamaModelHandle? model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length) | |||
| { | |||
| return internal_llama_chat_apply_template(model?.DangerousGetHandle() ?? IntPtr.Zero, tmpl, chat, n_msg, add_ass, buf, length); | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")] | |||
| static extern int internal_llama_chat_apply_template(IntPtr model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length); | |||
| } | |||
| /// <summary> | |||
| /// Returns -1 if unknown, 1 for true or 0 for false. | |||
| @@ -0,0 +1 @@ | |||
| global using LLama.Extensions; | |||