using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Native;
namespace LLama;
/// <summary>
/// Converts a sequence of messages into text according to a model template
/// </summary>
public sealed class LLamaTemplate
{
    #region private state
    private static readonly Encoding Encoding = Encoding.UTF8;

    /// <summary>
    /// The model this template is for. May be null if a custom template was supplied to the constructor.
    /// </summary>
    private readonly SafeLlamaModelHandle? _model;

    /// <summary>
    /// Custom template. May be null if a model was supplied to the constructor.
    /// </summary>
    private readonly byte[]? _customTemplate;

    /// <summary>
    /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
    /// </summary>
    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();

    /// <summary>
    /// Array of messages. The <see cref="Count"/> property indicates how many messages there are
    /// </summary>
    private TextMessage?[] _messages = new TextMessage[4];

    /// <summary>
    /// Backing field for <see cref="AddAssistant"/>
    /// </summary>
    private bool _addAssistant;

    /// <summary>
    /// Temporary array of messages in the format llama.cpp needs, used when applying the template
    /// </summary>
    private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4];

    /// <summary>
    /// Indicates how many bytes are in the <see cref="_result"/> array
    /// </summary>
    private int _resultLength;

    /// <summary>
    /// Result bytes of last call to <see cref="Apply"/>
    /// </summary>
    private byte[] _result = Array.Empty<byte>();

    /// <summary>
    /// Indicates if this template has been modified and needs regenerating
    /// </summary>
    private bool _dirty = true;
    #endregion

    #region properties
    /// <summary>
    /// Number of messages added to this template
    /// </summary>
    public int Count { get; private set; }

    /// <summary>
    /// Get the message at the given index
    /// </summary>
    /// <param name="index">Zero-based index of the message to retrieve</param>
    /// <returns>The message at the given index</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
    public TextMessage this[int index]
    {
        get
        {
            if (index < 0)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0");
            if (index >= Count)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");

            // Count check above guarantees this slot is non-null
            return _messages[index]!;
        }
    }

    /// <summary>
    /// Whether to end the prompt with the token(s) that indicate the start of an assistant message.
    /// </summary>
    public bool AddAssistant
    {
        get => _addAssistant;
        set
        {
            if (value != _addAssistant)
            {
                _dirty = true;
                _addAssistant = value;
            }
        }
    }
    #endregion

    #region construction
    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="model">The model to load the template from</param>
    public LLamaTemplate(SafeLlamaModelHandle model)
    {
        _model = model;
    }

    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="weights">The model weights to load the template from</param>
    public LLamaTemplate(LLamaWeights weights)
        : this(weights.NativeHandle)
    {
    }

    /// <summary>
    /// Construct a new template, using a custom template.
    /// </summary>
    /// <remarks>Only support a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template</remarks>
    /// <param name="customTemplate">The custom template string</param>
    public LLamaTemplate(string customTemplate)
    {
        // Append a null byte so the native code receives a null terminated string
        _customTemplate = Encoding.GetBytes(customTemplate + "\0");
    }
    #endregion

    #region modify
    /// <summary>
    /// Add a new message to the end of this template
    /// </summary>
    /// <param name="role">The "role" of the message (e.g. "user", "assistant")</param>
    /// <param name="content">The text content of the message</param>
    /// <returns>This template, for chaining calls.</returns>
    public LLamaTemplate Add(string role, string content)
    {
        return Add(new TextMessage(role, content, _roleCache));
    }

    /// <summary>
    /// Add a new message to the end of this template
    /// </summary>
    /// <param name="message">The message to add</param>
    /// <returns>This template, for chaining calls.</returns>
    public LLamaTemplate Add(TextMessage message)
    {
        // Expand messages array if necessary
        if (Count == _messages.Length)
            Array.Resize(ref _messages, _messages.Length * 2);

        // Add message
        _messages[Count] = message;
        Count++;

        // Mark as dirty to ensure template is recalculated
        _dirty = true;

        return this;
    }

    /// <summary>
    /// Remove a message at the given index
    /// </summary>
    /// <param name="index">Zero-based index of the message to remove</param>
    /// <returns>This template, for chaining calls.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
    public LLamaTemplate RemoveAt(int index)
    {
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
        if (index >= Count)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count");

        _dirty = true;
        Count--;

        // Copy all items after index down by one
        if (index < Count)
            Array.Copy(_messages, index + 1, _messages, index, Count - index);

        // Clear the now-unused final slot so the removed message can be collected
        _messages[Count] = default;

        return this;
    }
    #endregion

    /// <summary>
    /// Apply the template to the messages and write it into the output buffer
    /// </summary>
    /// <param name="dest">Destination to write template bytes into</param>
    /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
    public int Apply(Memory<byte> dest)
    {
        // Recalculate template if necessary
        if (_dirty)
        {
            _dirty = false;

            using var group = new GroupDisposable();
            unsafe
            {
                // Convert all the messages
                var totalInputBytes = 0;
                if (_nativeChatMessages.Length < _messages.Length)
                    Array.Resize(ref _nativeChatMessages, _messages.Length);
                for (var i = 0; i < Count; i++)
                {
                    ref var m = ref _messages[i]!;
                    Debug.Assert(m != null);
                    totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;

                    // Pin byte arrays in place so the native code can safely read them.
                    // The pins are released when `group` is disposed at the end of this scope.
                    var r = m.RoleBytes.Pin();
                    group.Add(r);
                    var c = m.ContentBytes.Pin();
                    group.Add(c);

                    _nativeChatMessages[i] = new LLamaChatMessage
                    {
                        role = (byte*)r.Pointer,
                        content = (byte*)c.Pointer
                    };
                }

                // Get an array that's twice as large as the amount of input, hopefully that's large enough!
                var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
                try
                {
                    // Run templater and discover true length
                    var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);

                    // If length was too big for output buffer run it again
                    if (outputLength > output.Length)
                    {
                        // Array was too small, rent another one that's exactly the size needed
                        ArrayPool<byte>.Shared.Return(output, true);
                        output = ArrayPool<byte>.Shared.Rent(outputLength);

                        // Run again, but this time with an output that is definitely large enough
                        ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
                    }

                    // Grow result buffer if necessary
                    if (_result.Length < outputLength)
                        Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength));

                    // Copy to result buffer
                    output.AsSpan(0, outputLength).CopyTo(_result);
                    _resultLength = outputLength;
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(output, true);
                }
            }
        }

        // Now that the template has been applied and is in the result buffer, copy it to the dest
        _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
        return _resultLength;

        unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
        {
            fixed (byte* customTemplatePtr = _customTemplate)
            fixed (byte* outputPtr = output)
            fixed (LLamaChatMessage* messagesPtr = messages)
            {
                return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length);
            }
        }
    }

    /// <summary>
    /// A message that has been added to a template
    /// </summary>
    public sealed class TextMessage
    {
        /// <summary>
        /// The "role" string for this message
        /// </summary>
        public string Role { get; }

        /// <summary>
        /// The text content of this message
        /// </summary>
        public string Content { get; }

        // Null-terminated UTF8 bytes for role/content, kept so Apply() can pin and pass them to native code
        internal ReadOnlyMemory<byte> RoleBytes { get; }
        internal ReadOnlyMemory<byte> ContentBytes { get; }

        internal TextMessage(string role, string content, IDictionary<string, ReadOnlyMemory<byte>> roleCache)
        {
            Role = role;
            Content = content;

            // Get bytes for role from cache
            if (!roleCache.TryGetValue(role, out var roleBytes))
            {
                // Convert role. Add one to length so there is a null byte at the end.
                var rArr = new byte[Encoding.GetByteCount(role) + 1];
                var encodedRoleLength = Encoding.GetBytes(role.AsSpan(), rArr);
                Debug.Assert(rArr.Length == encodedRoleLength + 1);

                // Always use the freshly converted bytes. Previously these were only assigned
                // inside the cache-size check below, so once the cache was full RoleBytes was
                // left as the default (empty) memory and the role was silently lost.
                roleBytes = rArr;

                // Add to cache for future use.
                // To ensure the cache cannot grow infinitely add a hard limit to size.
                if (roleCache.Count < 128)
                    roleCache.Add(role, rArr);
            }
            RoleBytes = roleBytes;

            // Convert content. Add one to length so there is a null byte at the end.
            var contentArray = new byte[Encoding.GetByteCount(content) + 1];
            var encodedContentLength = Encoding.GetBytes(content.AsSpan(), contentArray);
            Debug.Assert(contentArray.Length == encodedContentLength + 1);
            ContentBytes = contentArray;
        }

        /// <summary>
        /// Deconstruct this message into role and content
        /// </summary>
        /// <param name="role">The "role" string for this message</param>
        /// <param name="content">The text content of this message</param>
        public void Deconstruct(out string role, out string content)
        {
            role = Role;
            content = Content;
        }
    }
}