using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using LLama.Exceptions;
using LLama.Grammars;

namespace LLama.Native
{
    /// <summary>
    /// A safe reference to a `llama_grammar`
    /// </summary>
    public class SafeLLamaGrammarHandle
        : SafeLLamaHandleBase
    {
        #region construction/destruction
        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            NativeApi.llama_grammar_free(handle);
            SetHandle(IntPtr.Zero);
            return true;
        }

        /// <summary>
        /// Create a new llama_grammar
        /// </summary>
        /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
        /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
        /// <returns>A handle to the newly created grammar.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rules"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="start_rule_index"/> is not a valid index into <paramref name="rules"/>
        /// (always the case when <paramref name="rules"/> is empty).
        /// </exception>
        /// <exception cref="RuntimeError">The native call failed to create the grammar.</exception>
        public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
        {
            if (rules is null)
                throw new ArgumentNullException(nameof(rules));

            var nrules = (ulong)rules.Count;

            // NOTE(review): the native side is not assumed to bounds-check this index, so reject
            // it here rather than letting an invalid index read past the end of the rule array.
            if (start_rule_index >= nrules)
                throw new ArgumentOutOfRangeException(nameof(start_rule_index), start_rule_index, $"Start rule index must be less than the number of rules ({nrules})");

            unsafe
            {
                var totalElements = rules.Sum(a => a.Elements.Count);

                // Borrow an array large enough to hold every single element
                // and another array large enough to hold a pointer to each rule
                var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
                var rulePointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
                try
                {
                    // We're taking pointers into `allElements` below, so this pin is required to fix
                    // that memory in place while those pointers are in use! It is disposed at the end
                    // of this try block, i.e. before the arrays are returned to the pool in `finally`.
                    using var pin = allElements.AsMemory().Pin();

                    // Derive rule pointers from the pinned base address instead of taking a ref to
                    // `allElements[elementIndex]`, which would throw if a trailing rule were empty.
                    var basePtr = (LLamaGrammarElement*)pin.Pointer;

                    var elementIndex = 0;
                    var ruleIndex = 0;
                    foreach (var rule in rules)
                    {
                        // Save a pointer to the start of this rule
                        rulePointers[ruleIndex++] = (IntPtr)(basePtr + elementIndex);

                        // Copy all of the rule elements into the flat array
                        foreach (var element in rule.Elements)
                            allElements[elementIndex++] = element;
                    }

                    // Sanity check some things that should be true if the copy worked as planned
                    Debug.Assert((ulong)ruleIndex == nrules);
                    Debug.Assert(elementIndex == totalElements);

                    // Make the actual call through to llama.cpp
                    fixed (void* ptr = rulePointers)
                    {
                        return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                    }
                }
                finally
                {
                    ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
                    ArrayPool<IntPtr>.Shared.Return(rulePointers);
                }
            }
        }

        /// <summary>
        /// Create a new llama_grammar
        /// </summary>
        /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
        /// <param name="nrules">total number of rules</param>
        /// <param name="start_rule_index">index of the start rule of the grammar</param>
        /// <returns>A handle to the newly created grammar.</returns>
        /// <exception cref="RuntimeError">The native call failed to create the grammar.</exception>
        public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
        {
            var grammar = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);

            // When a P/Invoke is declared to return a SafeHandle type (presumably the case for
            // llama_grammar_init), the marshaller always constructs an instance, even if the native
            // function returned NULL. A null check alone therefore never fires; a failed call
            // surfaces as an invalid handle instead.
            if (grammar is null || grammar.IsInvalid)
            {
                grammar?.Dispose();
                throw new RuntimeError("Failed to create grammar from rules");
            }

            return grammar;
        }
        #endregion

        /// <summary>
        /// Create a copy of this grammar instance
        /// </summary>
        /// <returns>The handle returned by <c>llama_grammar_copy</c> for this grammar.</returns>
        public SafeLLamaGrammarHandle Clone()
        {
            return NativeApi.llama_grammar_copy(this);
        }

        /// <summary>
        /// Accepts the sampled token into the grammar
        /// </summary>
        /// <param name="ctx">The llama context, passed through to <c>llama_grammar_accept_token</c>.</param>
        /// <param name="token">The token to accept.</param>
        public void AcceptToken(SafeLLamaContextHandle ctx, LLamaToken token)
        {
            NativeApi.llama_grammar_accept_token(ctx, this, token);
        }
    }
}