using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using LLama.Exceptions;
using LLama.Grammars;

namespace LLama.Native
{
    /// <summary>
    /// A safe reference to a `llama_grammar`
    /// </summary>
    public class SafeLLamaGrammarHandle
        : SafeLLamaHandleBase
    {
        #region construction/destruction
        /// <summary>
        /// Wrap a native grammar pointer. The pointer is released with
        /// <c>llama_grammar_free</c> when <see cref="ReleaseHandle"/> runs.
        /// </summary>
        /// <param name="handle">A native <c>llama_grammar</c> pointer (e.g. one returned by <c>llama_grammar_init</c>)</param>
        internal SafeLLamaGrammarHandle(IntPtr handle)
            : base(handle)
        {
        }

        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            NativeApi.llama_grammar_free(handle);
            SetHandle(IntPtr.Zero);
            return true;
        }

        /// <summary>
        /// Create a new llama_grammar
        /// </summary>
        /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
        /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
        /// <returns>A handle owning the newly created native grammar</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rules"/> is null</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="start_rule_index"/> does not index an element of <paramref name="rules"/> (including when <paramref name="rules"/> is empty)</exception>
        /// <exception cref="RuntimeError">llama.cpp failed to create the grammar</exception>
        public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
        {
            if (rules == null)
                throw new ArgumentNullException(nameof(rules));

            // The native side indexes rules[start_rule_index] without bounds checking, so an
            // out-of-range index (or an empty rule list) must be rejected here, before any pinning.
            if (start_rule_index >= (ulong)rules.Count)
                throw new ArgumentOutOfRangeException(nameof(start_rule_index), "Start rule index must be less than the number of rules");

            unsafe
            {
                var totalElements = rules.Sum(a => a.Elements.Count);
                var nrules = (ulong)rules.Count;

                // Borrow an array large enough to hold every single element
                // and another array large enough to hold a pointer to each rule.
                // Rented arrays may be longer than requested; only the first
                // totalElements / nrules entries are used.
                var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
                var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
                try
                {
                    // Pin the flat element buffer so the per-rule pointers stay valid for the native call
                    fixed (LLamaGrammarElement* allElementsPtr = allElements)
                    {
                        var elementIndex = 0;
                        var pointerIndex = 0;
                        foreach (var rule in rules)
                        {
                            // Save a pointer to the start of this rule
                            pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);

                            // Copy all of the rule elements into the flat array
                            foreach (var element in rule.Elements)
                                allElementsPtr[elementIndex++] = element;
                        }

                        // Sanity check some things that should be true if the copy worked as planned
                        Debug.Assert((ulong)pointerIndex == nrules);
                        Debug.Assert(elementIndex == totalElements);

                        // Make the actual call through to llama.cpp
                        fixed (void* ptr = pointers)
                        {
                            return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                        }
                    }
                }
                finally
                {
                    ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
                    ArrayPool<IntPtr>.Shared.Return(pointers);
                }
            }
        }

        /// <summary>
        /// Create a new llama_grammar
        /// </summary>
        /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
        /// <param name="nrules">total number of rules</param>
        /// <param name="start_rule_index">index of the start rule of the grammar</param>
        /// <returns>A handle owning the newly created native grammar</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rules"/> is a null pointer</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="start_rule_index"/> is not less than <paramref name="nrules"/></exception>
        /// <exception cref="RuntimeError">llama.cpp failed to create the grammar</exception>
        public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
        {
            if (rules == null)
                throw new ArgumentNullException(nameof(rules));

            // Guard against an out-of-bounds read of rules[start_rule_index] in native code
            if (start_rule_index >= nrules)
                throw new ArgumentOutOfRangeException(nameof(start_rule_index), "Start rule index must be less than the number of rules");

            var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
            if (grammar_ptr == IntPtr.Zero)
                throw new RuntimeError("Failed to create grammar from rules");

            return new(grammar_ptr);
        }
        #endregion
    }
}