- Fixed `llama_chat_apply_template` method (wrong entrypoint, couldn't handle null model)
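A minimal sketch of the new `LLamaTemplate` API, distilled from the unit tests below. It assumes `weights` is an already-loaded `LLamaWeights`; the two-call pattern mirrors the native length-discovery convention:

```csharp
// Build from the model's embedded template, or name one explicitly: new LLamaTemplate("gemma")
var template = new LLamaTemplate(weights)
{
    AddAssistant = true // end the prompt with the assistant prefix
};
template.Add("user", "hello");

// First call with an empty buffer reports the required length...
var length = template.Apply(Array.Empty<byte>());

// ...second call fills a buffer that is large enough
var buffer = new byte[length];
template.Apply(buffer);
var prompt = Encoding.UTF8.GetString(buffer, 0, length);
```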
@@ -0,0 +1,245 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using Xunit;

namespace LLama.Unittest;

public sealed class TemplateTests
    : IDisposable
{
    private readonly LLamaWeights _model;

    public TemplateTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void BasicTemplate()
    {
        var templater = new LLamaTemplate(_model);
        Assert.Equal(0, templater.Count);

        templater.Add("assistant", "hello");
        Assert.Equal(1, templater.Count);
        templater.Add("user", "world");
        Assert.Equal(2, templater.Count);
        templater.Add("assistant", "111");
        Assert.Equal(3, templater.Count);
        templater.Add("user", "aaa");
        Assert.Equal(4, templater.Count);
        templater.Add("assistant", "222");
        Assert.Equal(5, templater.Count);
        templater.Add("user", "bbb");
        Assert.Equal(6, templater.Count);
        templater.Add("assistant", "333");
        Assert.Equal(7, templater.Count);
        templater.Add("user", "ccc");
        Assert.Equal(8, templater.Count);

        // Call once with an empty array to discover the required length
        var length = templater.Apply(Array.Empty<byte>());
        var dest = new byte[length];
        Assert.Equal(8, templater.Count);

        // Call again to get the contents
        length = templater.Apply(dest);
        Assert.Equal(8, templater.Count);

        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                "<|im_start|>user\nworld<|im_end|>\n" +
                                "<|im_start|>assistant\n111<|im_end|>\n" +
                                "<|im_start|>user\naaa<|im_end|>\n" +
                                "<|im_start|>assistant\n222<|im_end|>\n" +
                                "<|im_start|>user\nbbb<|im_end|>\n" +
                                "<|im_start|>assistant\n333<|im_end|>\n" +
                                "<|im_start|>user\nccc<|im_end|>\n";
        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void CustomTemplate()
    {
        var templater = new LLamaTemplate("gemma");
        Assert.Equal(0, templater.Count);

        templater.Add("assistant", "hello");
        Assert.Equal(1, templater.Count);
        templater.Add("user", "world");
        Assert.Equal(2, templater.Count);
        templater.Add("assistant", "111");
        Assert.Equal(3, templater.Count);
        templater.Add("user", "aaa");
        Assert.Equal(4, templater.Count);

        // Call once with an empty array to discover the required length
        var length = templater.Apply(Array.Empty<byte>());
        var dest = new byte[length];
        Assert.Equal(4, templater.Count);

        // Call again to get the contents
        length = templater.Apply(dest);
        Assert.Equal(4, templater.Count);

        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
        const string expected = "<start_of_turn>model\nhello<end_of_turn>\n" +
                                "<start_of_turn>user\nworld<end_of_turn>\n" +
                                "<start_of_turn>model\n111<end_of_turn>\n" +
                                "<start_of_turn>user\naaa<end_of_turn>\n";
        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void BasicTemplateWithAddAssistant()
    {
        var templater = new LLamaTemplate(_model)
        {
            AddAssistant = true,
        };
        Assert.Equal(0, templater.Count);

        templater.Add("assistant", "hello");
        Assert.Equal(1, templater.Count);
        templater.Add("user", "world");
        Assert.Equal(2, templater.Count);
        templater.Add("assistant", "111");
        Assert.Equal(3, templater.Count);
        templater.Add("user", "aaa");
        Assert.Equal(4, templater.Count);
        templater.Add("assistant", "222");
        Assert.Equal(5, templater.Count);
        templater.Add("user", "bbb");
        Assert.Equal(6, templater.Count);
        templater.Add("assistant", "333");
        Assert.Equal(7, templater.Count);
        templater.Add("user", "ccc");
        Assert.Equal(8, templater.Count);

        // Call once with an empty array to discover the required length
        var length = templater.Apply(Array.Empty<byte>());
        var dest = new byte[length];
        Assert.Equal(8, templater.Count);

        // Call again to get the contents
        length = templater.Apply(dest);
        Assert.Equal(8, templater.Count);

        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                "<|im_start|>user\nworld<|im_end|>\n" +
                                "<|im_start|>assistant\n111<|im_end|>\n" +
                                "<|im_start|>user\naaa<|im_end|>\n" +
                                "<|im_start|>assistant\n222<|im_end|>\n" +
                                "<|im_start|>user\nbbb<|im_end|>\n" +
                                "<|im_start|>assistant\n333<|im_end|>\n" +
                                "<|im_start|>user\nccc<|im_end|>\n" +
                                "<|im_start|>assistant\n";
        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void GetOutOfRangeThrows()
    {
        var templater = new LLamaTemplate(_model);
        Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]);

        templater.Add("assistant", "1");
        templater.Add("user", "2");

        Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]);
        Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]);
    }

    [Fact]
    public void RemoveMid()
    {
        var templater = new LLamaTemplate(_model);
        templater.Add("assistant", "1");
        templater.Add("user", "2");
        templater.Add("assistant", "3");
        templater.Add("user", "4a");
        templater.Add("user", "4b");
        templater.Add("assistant", "5");

        Assert.Equal(("user", "4a"), templater[3]);
        Assert.Equal(("assistant", "5"), templater[5]);
        Assert.Equal(6, templater.Count);

        templater.RemoveAt(3);
        Assert.Equal(5, templater.Count);
        Assert.Equal(("user", "4b"), templater[3]);
        Assert.Equal(("assistant", "5"), templater[4]);
    }

    [Fact]
    public void RemoveLast()
    {
        var templater = new LLamaTemplate(_model);
        templater.Add("assistant", "1");
        templater.Add("user", "2");
        templater.Add("assistant", "3");
        templater.Add("user", "4a");
        templater.Add("user", "4b");
        templater.Add("assistant", "5");
        Assert.Equal(6, templater.Count);

        templater.RemoveAt(5);
        Assert.Equal(5, templater.Count);
        Assert.Equal(("user", "4b"), templater[4]);
    }

    [Fact]
    public void RemoveOutOfRange()
    {
        var templater = new LLamaTemplate(_model);
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0));

        templater.Add("assistant", "1");
        templater.Add("user", "2");

        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
    }
}
@@ -0,0 +1,301 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Native;

namespace LLama;

/// <summary>
/// Converts a sequence of messages into text according to a model template
/// </summary>
public sealed class LLamaTemplate
{
    #region private state
    /// <summary>
    /// The model this template is for. May be null if a custom template was supplied to the constructor.
    /// </summary>
    private readonly SafeLlamaModelHandle? _model;

    /// <summary>
    /// Custom template. May be null if a model was supplied to the constructor.
    /// </summary>
    private readonly byte[]? _customTemplate;

    /// <summary>
    /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
    /// </summary>
    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();

    /// <summary>
    /// Array of messages. The <see cref="Count"/> property indicates how many messages there are.
    /// </summary>
    private Message[] _messages = new Message[4];

    /// <summary>
    /// Backing field for <see cref="AddAssistant"/>
    /// </summary>
    private bool _addAssistant;

    /// <summary>
    /// Temporary array of messages in the format llama.cpp needs, used when applying the template
    /// </summary>
    private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4];

    /// <summary>
    /// Indicates how many bytes are in the <see cref="_result"/> array
    /// </summary>
    private int _resultLength;

    /// <summary>
    /// Result bytes of the last call to <see cref="Apply"/>
    /// </summary>
    private byte[] _result = Array.Empty<byte>();

    /// <summary>
    /// Indicates if this template has been modified and needs regenerating
    /// </summary>
    private bool _dirty = true;
    #endregion

    #region properties
    /// <summary>
    /// Number of messages added to this template
    /// </summary>
    public int Count { get; private set; }

    /// <summary>
    /// Get the message at the given index
    /// </summary>
    /// <param name="index"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
    public (string role, string content) this[int index]
    {
        get
        {
            if (index < 0)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0");
            if (index >= Count)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");

            return (_messages[index].Role, _messages[index].Content);
        }
    }

    /// <summary>
    /// Whether to end the prompt with the token(s) that indicate the start of an assistant message.
    /// </summary>
    public bool AddAssistant
    {
        get => _addAssistant;
        set
        {
            if (value != _addAssistant)
            {
                _dirty = true;
                _addAssistant = value;
            }
        }
    }
    #endregion

    #region construction
    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="model"></param>
    public LLamaTemplate(SafeLlamaModelHandle model)
    {
        _model = model;
    }

    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="weights"></param>
    public LLamaTemplate(LLamaWeights weights)
        : this(weights.NativeHandle)
    {
    }

    /// <summary>
    /// Construct a new template, using a custom template.
    /// </summary>
    /// <remarks>Only a pre-defined list of templates is supported. See: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template</remarks>
    /// <param name="customTemplate"></param>
    public LLamaTemplate(string customTemplate)
    {
        // The native API expects a null terminated string, so append the terminator up front
        _customTemplate = Encoding.UTF8.GetBytes(customTemplate + "\0");
    }
    #endregion

    /// <summary>
    /// Add a new message to the end of this template
    /// </summary>
    /// <param name="role"></param>
    /// <param name="content"></param>
    public void Add(string role, string content)
    {
        // Expand the messages array if necessary
        if (Count == _messages.Length)
            Array.Resize(ref _messages, _messages.Length * 2);

        // Add the message
        _messages[Count] = new Message(role, content, _roleCache);
        Count++;

        // Mark as dirty to ensure the template is recalculated
        _dirty = true;
    }

    /// <summary>
    /// Remove the message at the given index
    /// </summary>
    /// <param name="index"></param>
    public void RemoveAt(int index)
    {
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
        if (index >= Count)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count");

        _dirty = true;
        Count--;

        // Copy all items after the index down by one
        if (index < Count)
            Array.Copy(_messages, index + 1, _messages, index, Count - index);
        _messages[Count] = default;
    }

    /// <summary>
    /// Apply the template to the messages and write it into the output buffer
    /// </summary>
    /// <param name="dest">Destination to write template bytes into</param>
    /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
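    /// <example>
    /// A sketch of handling a destination buffer that turns out to be too small:
    /// <code>
    /// var buffer = new byte[256];
    /// var length = template.Apply(buffer);
    /// if (length > buffer.Length)
    /// {
    ///     // Not everything fit; retry with a buffer of exactly the reported size
    ///     buffer = new byte[length];
    ///     template.Apply(buffer);
    /// }
    /// </code>
    /// </example>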
    public int Apply(Memory<byte> dest)
    {
        // Recalculate the template if necessary
        if (_dirty)
        {
            _dirty = false;

            using var group = new GroupDisposable();
            unsafe
            {
                // Convert all the messages
                var totalInputBytes = 0;
                if (_nativeChatMessages.Length < _messages.Length)
                    Array.Resize(ref _nativeChatMessages, _messages.Length);
                for (var i = 0; i < Count; i++)
                {
                    ref var m = ref _messages[i];
                    totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;

                    // Pin the byte arrays in place
                    var r = m.RoleBytes.Pin();
                    group.Add(r);
                    var c = m.ContentBytes.Pin();
                    group.Add(c);

                    _nativeChatMessages[i] = new LLamaChatMessage
                    {
                        role = (byte*)r.Pointer,
                        content = (byte*)c.Pointer
                    };
                }

                // Rent an array that's twice as large as the amount of input, which will hopefully be large enough
                var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
                try
                {
                    // Run the templater and discover the true length
                    var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);

                    // If the length was too big for the output buffer run it again
                    if (outputLength > output.Length)
                    {
                        // The array was too small, rent another one that's exactly the size needed
                        ArrayPool<byte>.Shared.Return(output, true);
                        output = ArrayPool<byte>.Shared.Rent(outputLength);

                        // Run again, but this time with an output that is definitely large enough
                        ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
                    }

                    // Grow the result buffer if necessary
                    if (_result.Length < outputLength)
                        Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength));

                    // Copy to the result buffer
                    output.AsSpan(0, outputLength).CopyTo(_result);
                    _resultLength = outputLength;
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(output, true);
                }
            }
        }

        // Now that the template has been applied and is in the result buffer, copy it to the dest
        _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
        return _resultLength;

        unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
        {
            fixed (byte* customTemplatePtr = _customTemplate)
            fixed (byte* outputPtr = output)
            fixed (LLamaChatMessage* messagesPtr = messages)
            {
                return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length);
            }
        }
    }

    /// <summary>
    /// A message that has been added to the template. Contains the role and content converted into UTF-8 bytes.
    /// </summary>
    private readonly record struct Message
    {
        public string Role { get; }
        public string Content { get; }

        public ReadOnlyMemory<byte> RoleBytes { get; }
        public ReadOnlyMemory<byte> ContentBytes { get; }

        public Message(string role, string content, Dictionary<string, ReadOnlyMemory<byte>> roleCache)
        {
            Role = role;
            Content = content;

            // Get the bytes for the role from the cache
            if (!roleCache.TryGetValue(role, out var roleBytes))
            {
                // Convert the role. Add one to the length so there is a null byte at the end.
                var rArr = new byte[Encoding.UTF8.GetByteCount(role) + 1];
                var encodedRoleLength = Encoding.UTF8.GetBytes(role.AsSpan(), rArr);
                Debug.Assert(rArr.Length == encodedRoleLength + 1);
                roleBytes = rArr;

                // Add it to the cache for future use.
                // To ensure the cache cannot grow infinitely, apply a hard limit on its size.
                // (Note: the bytes are used even when the cache is full, they just aren't cached.)
                if (roleCache.Count < 128)
                    roleCache.Add(role, rArr);
            }
            RoleBytes = roleBytes;

            // Convert the content. Add one to the length so there is a null byte at the end.
            var contentArray = new byte[Encoding.UTF8.GetByteCount(content) + 1];
            var encodedContentLength = Encoding.UTF8.GetBytes(content.AsSpan(), contentArray);
            Debug.Assert(contentArray.Length == encodedContentLength + 1);
            ContentBytes = contentArray;
        }
    }
}
@@ -1,11 +1,21 @@
+using System.Runtime.InteropServices;
+
 namespace LLama.Native;
 
 /// <summary>
 /// 
 /// </summary>
 /// <remarks>llama_chat_message</remarks>
+[StructLayout(LayoutKind.Sequential)]
 public unsafe struct LLamaChatMessage
 {
+    /// <summary>
+    /// Pointer to the null terminated bytes that make up the role string
+    /// </summary>
     public byte* role;
+
+    /// <summary>
+    /// Pointer to the null terminated bytes that make up the content string
+    /// </summary>
     public byte* content;
 }
@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Runtime.InteropServices;
 
 #pragma warning disable IDE1006 // Naming Styles
@@ -174,8 +174,13 @@ namespace LLama.Native
         /// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
         /// <param name="length">The size of the allocated buffer</param>
         /// <returns>The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template.</returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
-        public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, char* buf, int length);
+        public static unsafe int llama_chat_apply_template(SafeLlamaModelHandle? model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length)
+        {
+            return internal_llama_chat_apply_template(model?.DangerousGetHandle() ?? IntPtr.Zero, tmpl, chat, n_msg, add_ass, buf, length);
+
+            [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
+            static extern int internal_llama_chat_apply_template(IntPtr model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length);
+        }
 
         /// <summary>
         /// Returns -1 if unknown, 1 for true or 0 for false.
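With the corrected entry point and `IntPtr`-based binding, the wrapper can now express the "custom template, no model" call that llama.cpp supports (when `tmpl` is non-null it overrides the model's embedded template). A hedged sketch of a direct call (not code from this PR; the 256-byte buffer and the "gemma" template name are illustrative only):

```csharp
using System;
using System.Text;
using LLama.Native;

unsafe
{
    // Null-terminated UTF-8 inputs, as the native struct expects
    var role = Encoding.UTF8.GetBytes("user\0");
    var content = Encoding.UTF8.GetBytes("hello\0");
    var tmpl = Encoding.UTF8.GetBytes("gemma\0");
    var buf = new byte[256];

    fixed (byte* rolePtr = role)
    fixed (byte* contentPtr = content)
    fixed (byte* tmplPtr = tmpl)
    fixed (byte* bufPtr = buf)
    {
        var chat = new LLamaChatMessage { role = rolePtr, content = contentPtr };

        // A null model is now legal because the wrapper passes IntPtr.Zero through
        var length = NativeApi.llama_chat_apply_template(null, tmplPtr, &chat, 1, false, bufPtr, buf.Length);
        Console.WriteLine(Encoding.UTF8.GetString(buf, 0, Math.Min(length, buf.Length)));
    }
}
```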
@@ -0,0 +1 @@
global using LLama.Extensions;