You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SafeLLamaGrammarHandle.cs 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. using System;
  2. using System.Buffers;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Runtime.CompilerServices;
  7. using LLama.Exceptions;
  8. using LLama.Grammars;
  namespace LLama.Native
  {
      /// <summary>
      /// A safe reference to a `llama_grammar`
      /// </summary>
      public class SafeLLamaGrammarHandle
          : SafeLLamaHandleBase
      {
          #region construction/destruction
          /// <inheritdoc />
          protected override bool ReleaseHandle()
          {
              // Free the native grammar, then clear the stored pointer so it cannot be reused.
              NativeApi.llama_grammar_free(handle);
              SetHandle(IntPtr.Zero);
              return true;
          }

          /// <summary>
          /// Create a new llama_grammar from managed rules.
          /// All rule elements are copied into one flat, pinned buffer and a table of pointers
          /// (one per rule) is handed to the pointer-based <c>Create</c> overload.
          /// </summary>
          /// <param name="rules">A list of list of elements, each inner list makes up one grammar rule</param>
          /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
          /// <returns>A handle to the newly created grammar</returns>
          /// <exception cref="ArgumentNullException">Thrown when <paramref name="rules"/> is null</exception>
          /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="start_rule_index"/> does not refer to a rule in <paramref name="rules"/></exception>
          /// <exception cref="ArgumentException">Thrown when a rule contains no elements</exception>
          /// <exception cref="RuntimeError">Thrown when the native grammar could not be created</exception>
          public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
          {
              if (rules is null)
              {
                  throw new ArgumentNullException(nameof(rules));
              }

              var nrules = (ulong)rules.Count;

              // Reject a bad start index before renting anything; this also rejects an empty rule list.
              if (start_rule_index >= nrules)
              {
                  throw new ArgumentOutOfRangeException(nameof(start_rule_index), start_rule_index, "Start rule index must be less than the number of rules");
              }

              unsafe
              {
                  var totalElements = rules.Sum(a => a.Elements.Count);

                  // Borrow an array large enough to hold every single element
                  // and another array large enough to hold a pointer to each rule
                  var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
                  var rulePointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
                  try
                  {
                      // We're taking pointers into `allElements` below, so this pin is required to fix
                      // that memory in place while those pointers are in use!
                      using var pin = allElements.AsMemory().Pin();

                      var elementIndex = 0;
                      var ruleIndex = 0;
                      foreach (var rule in rules)
                      {
                          // An empty rule has no first element to point at: indexing could run past the
                          // end of the buffer, or the pointer would alias the next rule's elements.
                          if (rule.Elements.Count == 0)
                          {
                              throw new ArgumentException("Every grammar rule must contain at least one element", nameof(rules));
                          }

                          // Save a pointer to the start of this rule
                          rulePointers[ruleIndex++] = (IntPtr)Unsafe.AsPointer(ref allElements[elementIndex]);

                          // Copy all of the rule elements into the flat array
                          foreach (var element in rule.Elements)
                          {
                              allElements[elementIndex++] = element;
                          }
                      }

                      // Sanity check some things that should be true if the copy worked as planned
                      Debug.Assert((ulong)ruleIndex == nrules);
                      Debug.Assert(elementIndex == totalElements);

                      // Make the actual call through to llama.cpp
                      fixed (void* ptr = rulePointers)
                      {
                          return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
                      }
                  }
                  finally
                  {
                      // NOTE(review): returning the buffers here assumes the native side does not keep
                      // these pointers after llama_grammar_init returns - confirm against llama.cpp.
                      ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
                      ArrayPool<IntPtr>.Shared.Return(rulePointers);
                  }
              }
          }

          /// <summary>
          /// Create a new llama_grammar
          /// </summary>
          /// <param name="rules">rules list, each rule is a list of rule elements (terminated by a LLamaGrammarElementType.END element)</param>
          /// <param name="nrules">total number of rules</param>
          /// <param name="start_rule_index">index of the start rule of the grammar</param>
          /// <returns>A handle to the newly created grammar</returns>
          /// <exception cref="ArgumentNullException">Thrown when <paramref name="rules"/> is a null pointer</exception>
          /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="start_rule_index"/> is not less than <paramref name="nrules"/></exception>
          /// <exception cref="RuntimeError">Thrown when the native grammar could not be created</exception>
          public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, ulong nrules, ulong start_rule_index)
          {
              if (rules == null)
              {
                  throw new ArgumentNullException(nameof(rules));
              }

              // Never hand an out-of-range start index to native code.
              if (start_rule_index >= nrules)
              {
                  throw new ArgumentOutOfRangeException(nameof(start_rule_index), start_rule_index, "Start rule index must be less than the number of rules");
              }

              var grammar = NativeApi.llama_grammar_init(rules, nrules, start_rule_index);

              // NOTE(review): if llama_grammar_init is a direct P/Invoke, a NULL native pointer comes back
              // as a non-null handle that is merely invalid, so the reference check alone is not enough.
              if (grammar is null || grammar.IsInvalid)
              {
                  throw new RuntimeError("Failed to create grammar from rules");
              }

              return grammar;
          }
          #endregion

          /// <summary>
          /// Create a copy of this grammar instance
          /// </summary>
          /// <returns>The handle returned by <c>llama_grammar_copy</c> for this grammar</returns>
          public SafeLLamaGrammarHandle Clone()
          {
              return NativeApi.llama_grammar_copy(this);
          }

          /// <summary>
          /// Accepts the sampled token into the grammar
          /// </summary>
          /// <param name="ctx">The context handle passed through to <c>llama_grammar_accept_token</c></param>
          /// <param name="token">The sampled token to accept</param>
          public void AcceptToken(SafeLLamaContextHandle ctx, LLamaToken token)
          {
              NativeApi.llama_grammar_accept_token(ctx, this, token);
          }
      }
  }