
- Added `LLamaTemplate`, which efficiently formats a series of messages according to the model's chat template (see the usage sketch after the file list below).

- Fixed the `llama_chat_apply_template` binding, which used the wrong native entrypoint and couldn't handle a null model.
pull/715/head
Martin Evans committed 1 year ago
commit a0335f67a4
5 changed files with 566 additions and 4 deletions
  1. LLama.Unittest/TemplateTests.cs (+245 −0)
  2. LLama/LLamaTemplate.cs (+301 −0)
  3. LLama/Native/LLamaChatMessage.cs (+11 −1)
  4. LLama/Native/NativeApi.cs (+8 −3)
  5. LLama/Usings.cs (+1 −0)
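
For orientation, here is a minimal usage sketch of the new class (the model path is a placeholder, not part of this commit; the two-call Apply pattern mirrors the unit tests below):

using System;
using System.Text;
using LLama;
using LLama.Common;

// Sketch: load any chat model (the path here is a placeholder)
using var weights = LLamaWeights.LoadFromFile(new ModelParams("path/to/model.gguf"));

var template = new LLamaTemplate(weights) { AddAssistant = true };
template.Add("user", "hello");

// Call once with an empty buffer to discover the required length...
var length = template.Apply(Array.Empty<byte>());
var buffer = new byte[length];

// ...then call again to write the formatted prompt into the buffer
template.Apply(buffer);
var prompt = Encoding.UTF8.GetString(buffer.AsSpan(0, length));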

LLama.Unittest/TemplateTests.cs (+245 −0)

@@ -0,0 +1,245 @@
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Unittest;

public sealed class TemplateTests
    : IDisposable
{
    private readonly LLamaWeights _model;

    public TemplateTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void BasicTemplate()
    {
        var templater = new LLamaTemplate(_model);

        Assert.Equal(0, templater.Count);
        templater.Add("assistant", "hello");
        Assert.Equal(1, templater.Count);
        templater.Add("user", "world");
        Assert.Equal(2, templater.Count);
        templater.Add("assistant", "111");
        Assert.Equal(3, templater.Count);
        templater.Add("user", "aaa");
        Assert.Equal(4, templater.Count);
        templater.Add("assistant", "222");
        Assert.Equal(5, templater.Count);
        templater.Add("user", "bbb");
        Assert.Equal(6, templater.Count);
        templater.Add("assistant", "333");
        Assert.Equal(7, templater.Count);
        templater.Add("user", "ccc");
        Assert.Equal(8, templater.Count);

        // Call once with empty array to discover length
        var length = templater.Apply(Array.Empty<byte>());
        var dest = new byte[length];

        Assert.Equal(8, templater.Count);

        // Call again to get contents
        length = templater.Apply(dest);

        Assert.Equal(8, templater.Count);

        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                "<|im_start|>user\nworld<|im_end|>\n" +
                                "<|im_start|>assistant\n111<|im_end|>\n" +
                                "<|im_start|>user\naaa<|im_end|>\n" +
                                "<|im_start|>assistant\n222<|im_end|>\n" +
                                "<|im_start|>user\nbbb<|im_end|>\n" +
                                "<|im_start|>assistant\n333<|im_end|>\n" +
                                "<|im_start|>user\nccc<|im_end|>\n";

        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void CustomTemplate()
    {
        var templater = new LLamaTemplate("gemma");

        Assert.Equal(0, templater.Count);
        templater.Add("assistant", "hello");
        Assert.Equal(1, templater.Count);
        templater.Add("user", "world");
        Assert.Equal(2, templater.Count);
        templater.Add("assistant", "111");
        Assert.Equal(3, templater.Count);
        templater.Add("user", "aaa");
        Assert.Equal(4, templater.Count);

        // Call once with empty array to discover length
        var length = templater.Apply(Array.Empty<byte>());
        var dest = new byte[length];

        Assert.Equal(4, templater.Count);

        // Call again to get contents
        length = templater.Apply(dest);

        Assert.Equal(4, templater.Count);

        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
        const string expected = "<start_of_turn>model\nhello<end_of_turn>\n" +
                                "<start_of_turn>user\nworld<end_of_turn>\n" +
                                "<start_of_turn>model\n111<end_of_turn>\n" +
                                "<start_of_turn>user\naaa<end_of_turn>\n";

        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void BasicTemplateWithAddAssistant()
    {
        var templater = new LLamaTemplate(_model)
        {
            AddAssistant = true,
        };

        Assert.Equal(0, templater.Count);
        templater.Add("assistant", "hello");
        Assert.Equal(1, templater.Count);
        templater.Add("user", "world");
        Assert.Equal(2, templater.Count);
        templater.Add("assistant", "111");
        Assert.Equal(3, templater.Count);
        templater.Add("user", "aaa");
        Assert.Equal(4, templater.Count);
        templater.Add("assistant", "222");
        Assert.Equal(5, templater.Count);
        templater.Add("user", "bbb");
        Assert.Equal(6, templater.Count);
        templater.Add("assistant", "333");
        Assert.Equal(7, templater.Count);
        templater.Add("user", "ccc");
        Assert.Equal(8, templater.Count);

        // Call once with empty array to discover length
        var length = templater.Apply(Array.Empty<byte>());
        var dest = new byte[length];

        Assert.Equal(8, templater.Count);

        // Call again to get contents
        length = templater.Apply(dest);

        Assert.Equal(8, templater.Count);

        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                "<|im_start|>user\nworld<|im_end|>\n" +
                                "<|im_start|>assistant\n111<|im_end|>\n" +
                                "<|im_start|>user\naaa<|im_end|>\n" +
                                "<|im_start|>assistant\n222<|im_end|>\n" +
                                "<|im_start|>user\nbbb<|im_end|>\n" +
                                "<|im_start|>assistant\n333<|im_end|>\n" +
                                "<|im_start|>user\nccc<|im_end|>\n" +
                                "<|im_start|>assistant\n";

        Assert.Equal(expected, templateResult);
    }

    [Fact]
    public void GetOutOfRangeThrows()
    {
        var templater = new LLamaTemplate(_model);

        Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]);

        templater.Add("assistant", "1");
        templater.Add("user", "2");

        Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]);
        Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]);
    }

    [Fact]
    public void RemoveMid()
    {
        var templater = new LLamaTemplate(_model);

        templater.Add("assistant", "1");
        templater.Add("user", "2");
        templater.Add("assistant", "3");
        templater.Add("user", "4a");
        templater.Add("user", "4b");
        templater.Add("assistant", "5");

        Assert.Equal(("user", "4a"), templater[3]);
        Assert.Equal(("assistant", "5"), templater[5]);

        Assert.Equal(6, templater.Count);
        templater.RemoveAt(3);
        Assert.Equal(5, templater.Count);

        Assert.Equal(("user", "4b"), templater[3]);
        Assert.Equal(("assistant", "5"), templater[4]);
    }

    [Fact]
    public void RemoveLast()
    {
        var templater = new LLamaTemplate(_model);

        templater.Add("assistant", "1");
        templater.Add("user", "2");
        templater.Add("assistant", "3");
        templater.Add("user", "4a");
        templater.Add("user", "4b");
        templater.Add("assistant", "5");

        Assert.Equal(6, templater.Count);
        templater.RemoveAt(5);
        Assert.Equal(5, templater.Count);

        Assert.Equal(("user", "4b"), templater[4]);
    }

    [Fact]
    public void RemoveOutOfRange()
    {
        var templater = new LLamaTemplate(_model);

        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0));

        templater.Add("assistant", "1");
        templater.Add("user", "2");

        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
    }
}

LLama/LLamaTemplate.cs (+301 −0)

@@ -0,0 +1,301 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using LLama.Native;

namespace LLama;

/// <summary>
/// Converts a sequence of messages into text according to a model template
/// </summary>
public sealed class LLamaTemplate
{
    #region private state
    /// <summary>
    /// The model this template is for. May be null if a custom template was supplied to the constructor.
    /// </summary>
    private readonly SafeLlamaModelHandle? _model;

    /// <summary>
    /// Custom template. May be null if a model was supplied to the constructor.
    /// </summary>
    private readonly byte[]? _customTemplate;

    /// <summary>
    /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
    /// </summary>
    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();

    /// <summary>
    /// Array of messages. The <see cref="Count"/> property indicates how many messages there are.
    /// </summary>
    private Message[] _messages = new Message[4];

    /// <summary>
    /// Backing field for <see cref="AddAssistant"/>
    /// </summary>
    private bool _addAssistant;

    /// <summary>
    /// Temporary array of messages in the format llama.cpp needs, used when applying the template
    /// </summary>
    private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4];

    /// <summary>
    /// Indicates how many bytes are in the <see cref="_result"/> array
    /// </summary>
    private int _resultLength;

    /// <summary>
    /// Result bytes of the last call to <see cref="Apply"/>
    /// </summary>
    private byte[] _result = Array.Empty<byte>();

    /// <summary>
    /// Indicates if this template has been modified and needs regenerating
    /// </summary>
    private bool _dirty = true;
    #endregion

    #region properties
    /// <summary>
    /// Number of messages added to this template
    /// </summary>
    public int Count { get; private set; }

    /// <summary>
    /// Get the message at the given index
    /// </summary>
    /// <param name="index"></param>
    /// <returns></returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
    public (string role, string content) this[int index]
    {
        get
        {
            if (index < 0)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0");
            if (index >= Count)
                throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");

            return (_messages[index].Role, _messages[index].Content);
        }
    }

    /// <summary>
    /// Whether to end the prompt with the token(s) that indicate the start of an assistant message.
    /// </summary>
    public bool AddAssistant
    {
        get => _addAssistant;
        set
        {
            if (value != _addAssistant)
            {
                _dirty = true;
                _addAssistant = value;
            }
        }
    }
    #endregion

    #region construction
    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="model"></param>
    public LLamaTemplate(SafeLlamaModelHandle model)
    {
        _model = model;
    }

    /// <summary>
    /// Construct a new template, using the default model template
    /// </summary>
    /// <param name="weights"></param>
    public LLamaTemplate(LLamaWeights weights)
        : this(weights.NativeHandle)
    {
    }

    /// <summary>
    /// Construct a new template, using a custom template.
    /// </summary>
    /// <remarks>Only supports a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template</remarks>
    /// <param name="customTemplate"></param>
    public LLamaTemplate(string customTemplate)
    {
        _customTemplate = Encoding.UTF8.GetBytes(customTemplate + "\0");
    }
    #endregion

    /// <summary>
    /// Add a new message to the end of this template
    /// </summary>
    /// <param name="role"></param>
    /// <param name="content"></param>
    public void Add(string role, string content)
    {
        // Expand messages array if necessary
        if (Count == _messages.Length)
            Array.Resize(ref _messages, _messages.Length * 2);

        // Add message
        _messages[Count] = new Message(role, content, _roleCache);
        Count++;

        // Mark as dirty to ensure template is recalculated
        _dirty = true;
    }

    /// <summary>
    /// Remove a message at the given index
    /// </summary>
    /// <param name="index"></param>
    public void RemoveAt(int index)
    {
        if (index < 0)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
        if (index >= Count)
            throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count");

        _dirty = true;
        Count--;

        // Copy all items after index down by one
        if (index < Count)
            Array.Copy(_messages, index + 1, _messages, index, Count - index);

        _messages[Count] = default;
    }

    /// <summary>
    /// Apply the template to the messages and write it into the output buffer
    /// </summary>
    /// <param name="dest">Destination to write template bytes into</param>
    /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
    public int Apply(Memory<byte> dest)
    {
        // Recalculate template if necessary
        if (_dirty)
        {
            _dirty = false;

            using var group = new GroupDisposable();
            unsafe
            {
                // Convert all the messages
                var totalInputBytes = 0;
                if (_nativeChatMessages.Length < _messages.Length)
                    Array.Resize(ref _nativeChatMessages, _messages.Length);
                for (var i = 0; i < Count; i++)
                {
                    ref var m = ref _messages[i];
                    totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;

                    // Pin byte arrays in place
                    var r = m.RoleBytes.Pin();
                    group.Add(r);
                    var c = m.ContentBytes.Pin();
                    group.Add(c);

                    _nativeChatMessages[i] = new LLamaChatMessage
                    {
                        role = (byte*)r.Pointer,
                        content = (byte*)c.Pointer
                    };
                }

                // Get an array that's twice as large as the amount of input, hopefully that's large enough!
                var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
                try
                {
                    // Run templater and discover true length
                    var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);

                    // If length was too big for output buffer run it again
                    if (outputLength > output.Length)
                    {
                        // Array was too small, rent another one that's exactly the size needed
                        ArrayPool<byte>.Shared.Return(output, true);
                        output = ArrayPool<byte>.Shared.Rent(outputLength);

                        // Run again, but this time with an output that is definitely large enough
                        ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
                    }

                    // Grow result buffer if necessary
                    if (_result.Length < outputLength)
                        Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength));

                    // Copy to result buffer
                    output.AsSpan(0, outputLength).CopyTo(_result);
                    _resultLength = outputLength;
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(output, true);
                }
            }
        }

        // Now that the template has been applied and is in the result buffer, copy it to the dest
        _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
        return _resultLength;

        unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
        {
            fixed (byte* customTemplatePtr = _customTemplate)
            fixed (byte* outputPtr = output)
            fixed (LLamaChatMessage* messagesPtr = messages)
            {
                return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length);
            }
        }
    }

    /// <summary>
    /// A message that has been added to the template, contains role and content converted into UTF8 bytes.
    /// </summary>
    private readonly record struct Message
    {
        public string Role { get; }
        public string Content { get; }

        public ReadOnlyMemory<byte> RoleBytes { get; }
        public ReadOnlyMemory<byte> ContentBytes { get; }

        public Message(string role, string content, Dictionary<string, ReadOnlyMemory<byte>> roleCache)
        {
            Role = role;
            Content = content;

            // Get bytes for role from cache
            if (!roleCache.TryGetValue(role, out var roleBytes))
            {
                // Convert role. Add one to length so there is a null byte at the end.
                var rArr = new byte[Encoding.UTF8.GetByteCount(role) + 1];
                var encodedRoleLength = Encoding.UTF8.GetBytes(role.AsSpan(), rArr);
                Debug.Assert(rArr.Length == encodedRoleLength + 1);
                roleBytes = rArr;

                // Add to cache for future use. To ensure the cache cannot grow
                // infinitely, add a hard limit to its size. (Note: the bytes are
                // used either way; only caching is skipped when the limit is hit.)
                if (roleCache.Count < 128)
                    roleCache.Add(role, rArr);
            }
            RoleBytes = roleBytes;

            // Convert content. Add one to length so there is a null byte at the end.
            var contentArray = new byte[Encoding.UTF8.GetByteCount(content) + 1];
            var encodedContentLength = Encoding.UTF8.GetBytes(content.AsSpan(), contentArray);
            Debug.Assert(contentArray.Length == encodedContentLength + 1);
            ContentBytes = contentArray;
        }
    }
}
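
Note the caching design above: Apply only re-runs the native templater when the message list or AddAssistant has changed since the last call. A small sketch of the resulting behaviour (weights here is any already-loaded LLamaWeights):

var template = new LLamaTemplate(weights);
template.Add("user", "hello");

var buf = new byte[template.Apply(Array.Empty<byte>())];
template.Apply(buf);   // formats via llama.cpp and caches the result bytes
template.Apply(buf);   // template unchanged, so the cached bytes are copied out

template.Add("assistant", "hi");                      // marks the template dirty
var newLength = template.Apply(Array.Empty<byte>());  // recomputed on this call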

LLama/Native/LLamaChatMessage.cs (+11 −1)

@@ -1,11 +1,21 @@
+using System.Runtime.InteropServices;
+
 namespace LLama.Native;
 
 /// <summary>
 /// 
 /// </summary>
 /// <remarks>llama_chat_message</remarks>
+[StructLayout(LayoutKind.Sequential)]
 public unsafe struct LLamaChatMessage
 {
+    /// <summary>
+    /// Pointer to the null terminated bytes that make up the role string
+    /// </summary>
     public byte* role;
+
+    /// <summary>
+    /// Pointer to the null terminated bytes that make up the content string
+    /// </summary>
     public byte* content;
 }

LLama/Native/NativeApi.cs (+8 −3)

@@ -1,4 +1,4 @@
 using System;
 using System.Runtime.InteropServices;
 
 #pragma warning disable IDE1006 // Naming Styles
@@ -174,8 +174,13 @@ namespace LLama.Native
         /// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
         /// <param name="length">The size of the allocated buffer</param>
         /// <returns>The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template.</returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
-        public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, char* buf, int length);
+        public static unsafe int llama_chat_apply_template(SafeLlamaModelHandle? model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length)
+        {
+            return internal_llama_chat_apply_template(model?.DangerousGetHandle() ?? IntPtr.Zero, tmpl, chat, n_msg, add_ass, buf, length);
+
+            [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
+            static extern int internal_llama_chat_apply_template(IntPtr model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length);
+        }
 
         /// <summary>
         /// Returns -1 if unknown, 1 for true or 0 for false.
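
To illustrate the null-model path this fix enables, here is a hedged sketch (not part of the commit) that formats one message with the built-in "gemma" template and no model loaded, which the old signature could not express:

using System;
using System.Text;
using LLama.Native;

static unsafe string FormatWithNamedTemplate()
{
    // Null-terminated UTF8 strings, as the native API expects
    var tmpl = Encoding.UTF8.GetBytes("gemma\0");
    var role = Encoding.UTF8.GetBytes("user\0");
    var content = Encoding.UTF8.GetBytes("hello\0");

    fixed (byte* tmplPtr = tmpl)
    fixed (byte* rolePtr = role)
    fixed (byte* contentPtr = content)
    {
        var msg = new LLamaChatMessage { role = rolePtr, content = contentPtr };

        var buf = new byte[256];
        fixed (byte* bufPtr = buf)
        {
            // Passing a null model is now legal; the wrapper substitutes IntPtr.Zero
            var len = NativeApi.llama_chat_apply_template(null, tmplPtr, &msg, (nuint)1, false, bufPtr, buf.Length);
            if (len > buf.Length)
                throw new InvalidOperationException("Buffer too small; re-alloc and re-apply");
            return Encoding.UTF8.GetString(buf, 0, len);
        }
    }
}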


LLama/Usings.cs (+1 −0)

@@ -0,0 +1 @@
global using LLama.Extensions;
