You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SafeLLamaGrammarHandle.cs 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using LLama.Exceptions;
  7. using LLama.Grammars;
  8. namespace LLama.Native
  9. {
  10. /// <summary>
  11. /// A safe reference to a `llama_grammar`
  12. /// </summary>
  13. public class SafeLLamaGrammarHandle
  14. : SafeLLamaHandleBase
  15. {
  16. #region construction/destruction
  17. /// <summary>
  18. ///
  19. /// </summary>
  20. /// <param name="handle"></param>
  21. internal SafeLLamaGrammarHandle(IntPtr handle)
  22. : base(handle)
  23. {
  24. }
  25. /// <inheritdoc />
  26. protected override bool ReleaseHandle()
  27. {
  28. NativeApi.llama_grammar_free(handle);
  29. SetHandle(IntPtr.Zero);
  30. return true;
  31. }
  32. /// <summary>
  33. /// Create a new llama_grammar
  34. /// </summary>
  35. /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
  36. /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
  37. /// <returns></returns>
  38. /// <exception cref="RuntimeError"></exception>
  39. public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
  40. {
  41. unsafe
  42. {
  43. var totalElements = rules.Sum(a => a.Elements.Count);
  44. var nrules = (ulong)rules.Count;
  45. // Borrow an array large enough to hold every single element
  46. // and another array large enough to hold a pointer to each rule
  47. var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
  48. var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
  49. try
  50. {
  51. fixed (LLamaGrammarElement* allElementsPtr = allElements)
  52. {
  53. var elementIndex = 0;
  54. var pointerIndex = 0;
  55. foreach (var rule in rules)
  56. {
  57. // Save a pointer to the start of this rule
  58. pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);
  59. // Copy all of the rule elements into the flat array
  60. foreach (var element in rule.Elements)
  61. allElementsPtr[elementIndex++] = element;
  62. }
  63. // Sanity check some things that should be true if the copy worked as planned
  64. Debug.Assert((ulong)pointerIndex == nrules);
  65. Debug.Assert(elementIndex == totalElements);
  66. // Make the actual call through to llama.cpp
  67. fixed (void* ptr = pointers)
  68. {
  69. return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
  70. }
  71. }
  72. }
  73. finally
  74. {
  75. ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
  76. ArrayPool<IntPtr>.Shared.Return(pointers);
  77. }
  78. }
  79. }
  80. /// <summary>
  81. /// Create a new llama_grammar
  82. /// </summary>
  83. /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
  84. /// <param name="nrules">total number of rules</param>
  85. /// <param name="start_rule_index">index of the start rule of the grammar</param>
  86. /// <returns></returns>
  87. /// <exception cref="RuntimeError"></exception>
  88. public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
  89. {
  90. var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
  91. if (grammar_ptr == IntPtr.Zero)
  92. throw new RuntimeError("Failed to create grammar from rules");
  93. return new(grammar_ptr);
  94. }
  95. #endregion
  96. }
  97. }