You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

SafeLLamaGrammarHandle.cs 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Runtime.CompilerServices;
  7. using LLama.Exceptions;
  8. using LLama.Grammars;
  9. namespace LLama.Native
  10. {
  11. /// <summary>
  12. /// A safe reference to a `llama_grammar`
  13. /// </summary>
  14. public class SafeLLamaGrammarHandle
  15. : SafeLLamaHandleBase
  16. {
  17. #region construction/destruction
  18. /// <summary>
  19. ///
  20. /// </summary>
  21. /// <param name="handle"></param>
  22. internal SafeLLamaGrammarHandle(IntPtr handle)
  23. : base(handle, true)
  24. {
  25. }
  26. /// <inheritdoc />
  27. protected override bool ReleaseHandle()
  28. {
  29. NativeApi.llama_grammar_free(handle);
  30. SetHandle(IntPtr.Zero);
  31. return true;
  32. }
  33. /// <summary>
  34. /// Create a new llama_grammar
  35. /// </summary>
  36. /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
  37. /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
  38. /// <returns></returns>
  39. /// <exception cref="RuntimeError"></exception>
  40. public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
  41. {
  42. unsafe
  43. {
  44. var totalElements = rules.Sum(a => a.Elements.Count);
  45. var nrules = (ulong)rules.Count;
  46. // Borrow an array large enough to hold every single element
  47. // and another array large enough to hold a pointer to each rule
  48. var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
  49. var rulePointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
  50. try
  51. {
  52. // We're taking pointers into `allElements` below, so this pin is required to fix
  53. // that memory in place while those pointers are in use!
  54. using var pin = allElements.AsMemory().Pin();
  55. var elementIndex = 0;
  56. var ruleIndex = 0;
  57. foreach (var rule in rules)
  58. {
  59. // Save a pointer to the start of this rule
  60. rulePointers[ruleIndex++] = (IntPtr)Unsafe.AsPointer(ref allElements[elementIndex]);
  61. // Copy all of the rule elements into the flat array
  62. foreach (var element in rule.Elements)
  63. allElements[elementIndex++] = element;
  64. }
  65. // Sanity check some things that should be true if the copy worked as planned
  66. Debug.Assert((ulong)ruleIndex == nrules);
  67. Debug.Assert(elementIndex == totalElements);
  68. // Make the actual call through to llama.cpp
  69. fixed (void* ptr = rulePointers)
  70. {
  71. return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
  72. }
  73. }
  74. finally
  75. {
  76. ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
  77. ArrayPool<IntPtr>.Shared.Return(rulePointers);
  78. }
  79. }
  80. }
  81. /// <summary>
  82. /// Create a new llama_grammar
  83. /// </summary>
  84. /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
  85. /// <param name="nrules">total number of rules</param>
  86. /// <param name="start_rule_index">index of the start rule of the grammar</param>
  87. /// <returns></returns>
  88. /// <exception cref="RuntimeError"></exception>
  89. public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
  90. {
  91. var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
  92. if (grammar_ptr == IntPtr.Zero)
  93. throw new RuntimeError("Failed to create grammar from rules");
  94. return new(grammar_ptr);
  95. }
  96. #endregion
  97. /// <summary>
  98. /// Create a copy of this grammar instance
  99. /// </summary>
  100. /// <returns></returns>
  101. public SafeLLamaGrammarHandle Clone()
  102. {
  103. return new SafeLLamaGrammarHandle(NativeApi.llama_grammar_copy(this));
  104. }
  105. /// <summary>
  106. /// Accepts the sampled token into the grammar
  107. /// </summary>
  108. /// <param name="ctx"></param>
  109. /// <param name="token"></param>
  110. public void AcceptToken(SafeLLamaContextHandle ctx, LLamaToken token)
  111. {
  112. NativeApi.llama_grammar_accept_token(ctx, this, token);
  113. }
  114. }
  115. }