You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SafeLLamaGrammarHandle.cs 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using LLama.Exceptions;
  7. namespace LLama.Native
  8. {
  9. /// <summary>
  10. /// A safe reference to a `llama_grammar`
  11. /// </summary>
  12. public class SafeLLamaGrammarHandle
  13. : SafeLLamaHandleBase
  14. {
  15. #region construction/destruction
  16. /// <summary>
  17. ///
  18. /// </summary>
  19. /// <param name="handle"></param>
  20. internal SafeLLamaGrammarHandle(IntPtr handle)
  21. : base(handle)
  22. {
  23. }
  24. /// <inheritdoc />
  25. protected override bool ReleaseHandle()
  26. {
  27. NativeApi.llama_grammar_free(handle);
  28. SetHandle(IntPtr.Zero);
  29. return true;
  30. }
  31. /// <summary>
  32. /// Create a new llama_grammar
  33. /// </summary>
  34. /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
  35. /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
  36. /// <returns></returns>
  37. /// <exception cref="RuntimeError"></exception>
  38. public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
  39. {
  40. unsafe
  41. {
  42. var totalElements = rules.Sum(a => a.Count);
  43. var nrules = (ulong)rules.Count;
  44. // Borrow an array large enough to hold every single element
  45. // and another array large enough to hold a pointer to each rule
  46. var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
  47. var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
  48. try
  49. {
  50. fixed (LLamaGrammarElement* allElementsPtr = allElements)
  51. {
  52. var elementIndex = 0;
  53. var pointerIndex = 0;
  54. foreach (var rule in rules)
  55. {
  56. // Save a pointer to the start of this rule
  57. pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);
  58. // Copy all of the rule elements into the flat array
  59. foreach (var element in rule)
  60. allElementsPtr[elementIndex++] = element;
  61. }
  62. // Sanity check some things that should be true if the copy worked as planned
  63. Debug.Assert((ulong)pointerIndex == nrules);
  64. Debug.Assert(elementIndex == totalElements);
  65. // Make the actual call through to llama.cpp
  66. fixed (void* ptr = pointers)
  67. {
  68. return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
  69. }
  70. }
  71. }
  72. finally
  73. {
  74. ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
  75. ArrayPool<IntPtr>.Shared.Return(pointers);
  76. }
  77. }
  78. }
  79. /// <summary>
  80. /// Create a new llama_grammar
  81. /// </summary>
  82. /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
  83. /// <param name="nrules">total number of rules</param>
  84. /// <param name="start_rule_index">index of the start rule of the grammar</param>
  85. /// <returns></returns>
  86. /// <exception cref="RuntimeError"></exception>
  87. public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
  88. {
  89. var grammar_ptr = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);
  90. if (grammar_ptr == IntPtr.Zero)
  91. throw new RuntimeError("Failed to create grammar from rules");
  92. return new(grammar_ptr);
  93. }
  94. #endregion
  95. }
  96. }