using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Native;

namespace LLama;

/// <summary>
/// Converts a sequence of messages into text according to a model template
/// </summary>
public sealed class LLamaTemplate
{
    #region private state
    private static readonly Encoding Encoding = Encoding.UTF8;

    /// <summary>
    /// The model this template is for. May be null if a custom template was supplied to the constructor.
    /// </summary>
    private readonly SafeLlamaModelHandle? _model;

    /// <summary>
    /// Custom template. May be null if a model was supplied to the constructor.
    /// </summary>
    private readonly byte[]? _customTemplate;

    /// <summary>
    /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
    /// </summary>
    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();

    /// <summary>
    /// Array of messages. The <see cref="Count"/> property indicates how many messages there are.
    /// </summary>
    private TextMessage?[] _messages = new TextMessage[4];

    /// <summary>
    /// Backing field for <see cref="AddAssistant"/>
    /// </summary>
    private bool _addAssistant;

    /// <summary>
    /// Temporary array of messages in the format llama.cpp needs, used when applying the template
    /// </summary>
    private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4];

    /// <summary>
    /// Indicates how many bytes are in the <see cref="_result"/> array
    /// </summary>
    private int _resultLength;

    /// <summary>
    /// Result bytes of the last call to <see cref="Apply"/>
    /// </summary>
    private byte[] _result = Array.Empty<byte>();

    /// <summary>
    /// Indicates if this template has been modified and needs regenerating
    /// </summary>
    private bool _dirty = true;
    #endregion

    #region properties
    /// <summary>
    /// Number of messages added to this template
    /// </summary>
    public int Count { get; private set; }

    /// <summary>
    /// Get the message at the given index
    /// </summary>
    /// <param name="index"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
    public TextMessage this[int index]
    {
        get
        {
            if (index < 0)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0");
            if (index >= Count)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");

            return _messages[index]!;
        }
    }

    /// <summary>
    /// Whether to end the prompt with the token(s) that indicate the start of an assistant message.
    /// </summary>
    public bool AddAssistant
    {
        get => _addAssistant;
        set
        {
            if (value != _addAssistant)
            {
                _dirty = true;
                _addAssistant = value;
            }
        }
    }
    #endregion

    #region construction
    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="model"></param>
    public LLamaTemplate(SafeLlamaModelHandle model)
    {
        _model = model;
    }

    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="weights"></param>
    public LLamaTemplate(LLamaWeights weights)
        : this(weights.NativeHandle)
    {
    }

    /// <summary>
    /// Construct a new template, using a custom template.
    /// </summary>
    /// <remarks>Only supports a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template</remarks>
    /// <param name="customTemplate"></param>
    public LLamaTemplate(string customTemplate)
    {
        _customTemplate = Encoding.GetBytes(customTemplate + "\0");
    }
    #endregion
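
    // Construction sketch (illustration only, not part of this file's API surface).
    // `weights` stands in for an already-loaded LLamaWeights instance, and "chatml" is
    // an assumed example of a template name llama.cpp recognises; see the wiki link above
    // for the actual supported list.
    //
    //     var fromModel = new LLamaTemplate(weights);    // use the template embedded in the model
    //     var fromName = new LLamaTemplate("chatml");    // use a named built-in template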
    #region modify
    /// <summary>
    /// Add a new message to the end of this template
    /// </summary>
    /// <param name="role"></param>
    /// <param name="content"></param>
    /// <returns>This template, for chaining calls.</returns>
    public LLamaTemplate Add(string role, string content)
    {
        return Add(new TextMessage(role, content, _roleCache));
    }

    /// <summary>
    /// Add a new message to the end of this template
    /// </summary>
    /// <param name="message"></param>
    /// <returns>This template, for chaining calls.</returns>
    public LLamaTemplate Add(TextMessage message)
    {
        // Expand messages array if necessary
        if (Count == _messages.Length)
            Array.Resize(ref _messages, _messages.Length * 2);

        // Add message
        _messages[Count] = message;
        Count++;

        // Mark as dirty to ensure template is recalculated
        _dirty = true;

        return this;
    }

    /// <summary>
    /// Remove a message at the given index
    /// </summary>
    /// <param name="index"></param>
    /// <returns>This template, for chaining calls.</returns>
    public LLamaTemplate RemoveAt(int index)
    {
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
        if (index >= Count)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count");

        _dirty = true;
        Count--;

        // Copy all items after index down by one
        if (index < Count)
            Array.Copy(_messages, index + 1, _messages, index, Count - index);

        _messages[Count] = default;

        return this;
    }
    #endregion
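
    // Caller-side sketch for Apply (an assumed typical usage, derived from the contract
    // documented below: Apply always returns the full template length, even when the
    // destination buffer is too small, so callers can grow the buffer and retry).
    //
    //     var buffer = new byte[1024];
    //     var length = template.Apply(buffer);
    //     if (length > buffer.Length)
    //     {
    //         buffer = new byte[length];             // grow to the exact required size
    //         length = template.Apply(buffer);       // second call is guaranteed to fit
    //     }
    //     var prompt = Encoding.UTF8.GetString(buffer, 0, length);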
    /// <summary>
    /// Apply the template to the messages and write it into the output buffer
    /// </summary>
    /// <param name="dest">Destination to write template bytes into</param>
    /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
    public int Apply(Memory<byte> dest)
    {
        // Recalculate template if necessary
        if (_dirty)
        {
            _dirty = false;

            using var group = new GroupDisposable();
            unsafe
            {
                // Convert all the messages
                var totalInputBytes = 0;
                if (_nativeChatMessages.Length < _messages.Length)
                    Array.Resize(ref _nativeChatMessages, _messages.Length);
                for (var i = 0; i < Count; i++)
                {
                    ref var m = ref _messages[i]!;
                    Debug.Assert(m != null);
                    totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;

                    // Pin byte arrays in place
                    var r = m.RoleBytes.Pin();
                    group.Add(r);
                    var c = m.ContentBytes.Pin();
                    group.Add(c);

                    _nativeChatMessages[i] = new LLamaChatMessage
                    {
                        role = (byte*)r.Pointer,
                        content = (byte*)c.Pointer
                    };
                }

                // Get an array that's twice as large as the amount of input, hopefully that's large enough!
                var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
                try
                {
                    // Run templater and discover true length
                    var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);

                    // If length was too big for output buffer run it again
                    if (outputLength > output.Length)
                    {
                        // Array was too small, rent another one that's exactly the size needed
                        ArrayPool<byte>.Shared.Return(output, true);
                        output = ArrayPool<byte>.Shared.Rent(outputLength);

                        // Run again, but this time with an output that is definitely large enough
                        ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
                    }

                    // Grow result buffer if necessary
                    if (_result.Length < outputLength)
                        Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength));

                    // Copy to result buffer
                    output.AsSpan(0, outputLength).CopyTo(_result);
                    _resultLength = outputLength;
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(output, true);
                }
            }
        }

        // Now that the template has been applied and is in the result buffer, copy it to the dest
        _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
        return _resultLength;

        unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
        {
            fixed (byte* customTemplatePtr = _customTemplate)
            fixed (byte* outputPtr = output)
            fixed (LLamaChatMessage* messagesPtr = messages)
            {
                return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length);
            }
        }
    }

    /// <summary>
    /// A message that has been added to a template
    /// </summary>
    public sealed class TextMessage
    {
        /// <summary>
        /// The "role" string for this message
        /// </summary>
        public string Role { get; }

        /// <summary>
        /// The text content of this message
        /// </summary>
        public string Content { get; }

        internal ReadOnlyMemory<byte> RoleBytes { get; }
        internal ReadOnlyMemory<byte> ContentBytes { get; }

        internal TextMessage(string role, string content, IDictionary<string, ReadOnlyMemory<byte>> roleCache)
        {
            Role = role;
            Content = content;

            // Get bytes for role from cache
            if (!roleCache.TryGetValue(role, out var roleBytes))
            {
                // Convert role. Add one to length so there is a null byte at the end.
                var rArr = new byte[Encoding.GetByteCount(role) + 1];
                var encodedRoleLength = Encoding.GetBytes(role.AsSpan(), rArr);
                Debug.Assert(rArr.Length == encodedRoleLength + 1);

                // Always use the converted bytes, even if the cache is full. Assigning
                // only inside the cache-size check would leave RoleBytes empty once the
                // cache limit is reached.
                roleBytes = rArr;

                // Add to cache for future use.
                // To ensure the cache cannot grow infinitely add a hard limit to size.
                if (roleCache.Count < 128)
                    roleCache.Add(role, roleBytes);
            }
            RoleBytes = roleBytes;

            // Convert content. Add one to length so there is a null byte at the end.
            var contentArray = new byte[Encoding.GetByteCount(content) + 1];
            var encodedContentLength = Encoding.GetBytes(content.AsSpan(), contentArray);
            Debug.Assert(contentArray.Length == encodedContentLength + 1);
            ContentBytes = contentArray;
        }

        /// <summary>
        /// Deconstruct this message into role and content
        /// </summary>
        /// <param name="role"></param>
        /// <param name="content"></param>
        public void Deconstruct(out string role, out string content)
        {
            role = Role;
            content = Content;
        }
    }
}
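
// End-to-end usage sketch (hedged: `weights` is a hypothetical already-loaded
// LLamaWeights; the roles shown are conventional chat roles, and whether a model
// honours them depends on its template).
//
//     var template = new LLamaTemplate(weights)
//         .Add("system", "You are a helpful assistant.")
//         .Add("user", "Hello!");
//     template.AddAssistant = true;   // end the prompt with the assistant start token(s)
//     // ...then call Apply with a byte buffer, as shown above, to get the templated prompt.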