using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using LLama.Exceptions;
using LLama.Grammars;
namespace LLama.Native
{
/// <summary>
/// A safe reference to a `llama_grammar`
/// </summary>
public class SafeLLamaGrammarHandle
    : SafeLLamaHandleBase
{
    #region construction/destruction
    /// <summary>
    /// Wrap a native grammar pointer. This handle takes ownership of it: the grammar is
    /// freed with <c>llama_grammar_free</c> when the handle is released.
    /// </summary>
    /// <param name="handle">Native grammar pointer (e.g. one returned by <c>llama_grammar_init</c>)</param>
    internal SafeLLamaGrammarHandle(IntPtr handle)
        : base(handle)
    {
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        NativeApi.llama_grammar_free(handle);
        SetHandle(IntPtr.Zero);
        return true;
    }

    /// <summary>
    /// Create a new llama_grammar
    /// </summary>
    /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
    /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
    /// <returns>A handle owning the newly created native grammar</returns>
    /// <exception cref="ArgumentNullException"><paramref name="rules"/> is null</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="start_rule_index"/> is not less than the number of rules (always the case for an empty list)</exception>
    /// <exception cref="RuntimeError">llama.cpp failed to create the grammar</exception>
    public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
    {
        if (rules == null)
            throw new ArgumentNullException(nameof(rules));

        var nrules = (ulong)rules.Count;

        // Validate before renting/pinning: for an empty list the pooled arrays would be empty,
        // and `fixed` on an empty array produces a null pointer that would reach native code.
        ValidateStartRuleIndex(start_rule_index, nrules);

        unsafe
        {
            var totalElements = rules.Sum(a => a.Elements.Count);

            // Borrow an array large enough to hold every single element
            // and another array large enough to hold a pointer to each rule
            var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
            var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
            try
            {
                fixed (LLamaGrammarElement* allElementsPtr = allElements)
                {
                    var elementIndex = 0;
                    var pointerIndex = 0;
                    foreach (var rule in rules)
                    {
                        // Save a pointer to the start of this rule
                        pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);

                        // Copy all of the rule elements into the flat array
                        foreach (var element in rule.Elements)
                            allElementsPtr[elementIndex++] = element;
                    }

                    // Sanity check some things that should be true if the copy worked as planned
                    Debug.Assert((ulong)pointerIndex == nrules);
                    Debug.Assert(elementIndex == totalElements);

                    // Make the actual call through to llama.cpp. Both buffers stay pinned only for the
                    // duration of this call and are handed back to the pool in the finally block below,
                    // which relies on llama_grammar_init copying the rules rather than retaining these pointers.
                    fixed (IntPtr* ptr = pointers)
                    {
                        return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                    }
                }
            }
            finally
            {
                ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
                ArrayPool<IntPtr>.Shared.Return(pointers);
            }
        }
    }

    /// <summary>
    /// Create a new llama_grammar
    /// </summary>
    /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
    /// <param name="nrules">total number of rules</param>
    /// <param name="start_rule_index">index of the start rule of the grammar</param>
    /// <returns>A handle owning the newly created native grammar</returns>
    /// <exception cref="ArgumentNullException"><paramref name="rules"/> is a null pointer</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="start_rule_index"/> is not less than <paramref name="nrules"/></exception>
    /// <exception cref="RuntimeError">llama.cpp failed to create the grammar</exception>
    public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
    {
        if (rules == null)
            throw new ArgumentNullException(nameof(rules));
        ValidateStartRuleIndex(start_rule_index, nrules);

        var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
        if (grammar_ptr == IntPtr.Zero)
            throw new RuntimeError("Failed to create grammar from rules");

        return new(grammar_ptr);
    }

    /// <summary>
    /// Reject a start rule index that does not refer to one of the <paramref name="nrules"/> rules.
    /// </summary>
    /// <remarks>
    /// The index is forwarded to native <c>llama_grammar_init</c>; an invalid value is expected to
    /// cause an out-of-bounds native read rather than a catchable error
    /// (NOTE(review): confirm against the bundled llama.cpp version).
    /// </remarks>
    private static void ValidateStartRuleIndex(ulong start_rule_index, ulong nrules)
    {
        if (start_rule_index >= nrules)
            throw new ArgumentOutOfRangeException(nameof(start_rule_index), start_rule_index, "Start rule index must be less than the number of rules");
    }
    #endregion
}
}