You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ParseState.cs 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. using LLama.Exceptions;
  2. using LLama.Native;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. namespace LLama.Grammar
  7. {
  /// <summary>
  /// Result of grammar parsing: a symbol-name-to-id map plus the list of rules, with helpers
  /// that write the grammar back out in readable (<c>name ::= ...</c>) or per-element form.
  /// Ported from:
  /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h
  ///
  /// The commit hash in the URL identifies the llama.cpp revision that this C# code mirrors.
  /// </summary>
  public class ParseState
  {
      /// <summary>
      /// Maps each symbol name to its numeric id. When printing, rule <c>i</c> of
      /// <see cref="Rules"/> is labelled with the name whose id is <c>i</c>.
      /// </summary>
      public SortedDictionary<string, uint> SymbolIds { get; } = new SortedDictionary<string, uint>();

      /// <summary>
      /// Grammar rules, indexed by rule id. For printing, each rule must be a flat element
      /// list whose last element (and only that element) is END.
      /// </summary>
      public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>();

      /// <summary>
      /// Lazily yields every rule in <see cref="Rules"/>, in order.
      /// </summary>
      public IEnumerable<List<LLamaGrammarElement>> CRules()
      {
          foreach (var currentRule in Rules)
          {
              yield return currentRule;
          }
      }

      /// <summary>
      /// Writes every rule of <paramref name="state"/> to <paramref name="file"/>, one
      /// <c>name ::= ...</c> line per rule, in rule-id order.
      /// </summary>
      /// <param name="file">Destination writer.</param>
      /// <param name="state">Parse state whose symbols and rules are printed (need not be <c>this</c>).</param>
      /// <remarks>
      /// Best effort: any exception raised while printing is caught, and its message is written
      /// to <see cref="Console.Error"/>. Output already written to <paramref name="file"/> is left as is.
      /// </remarks>
      public void PrintGrammar(StreamWriter file, ParseState state)
      {
          try
          {
              // Invert SymbolIds so that rule ids can be turned back into names. If two names
              // share an id, the one enumerated last (in SymbolIds' sorted order) wins.
              var namesById = new Dictionary<uint, string>();
              foreach (var entry in state.SymbolIds)
              {
                  namesById[entry.Value] = entry.Key;
              }

              uint ruleId = 0;
              foreach (var rule in state.Rules)
              {
                  PrintRule(file, ruleId, rule, namesById);
                  ruleId++;
              }
          }
          catch (Exception error)
          {
              Console.Error.WriteLine($"\nError printing grammar: {error.Message}");
          }
      }

      /// <summary>
      /// Writes <paramref name="rule"/> element by element on a single line as <c>TYPE(value)</c>
      /// tokens, e.g. <c>RULE_REF(3) CHAR("a")</c>, followed by a newline. Elements whose type
      /// is not one of the seven known element types are skipped.
      /// </summary>
      /// <param name="file">Destination writer.</param>
      /// <param name="rule">Elements to dump.</param>
      public void PrintRuleBinary(StreamWriter file, List<LLamaGrammarElement> rule)
      {
          foreach (var element in rule)
          {
              var type = element.Type;
              if (type == LLamaGrammarElementType.END
                  || type == LLamaGrammarElementType.ALT
                  || type == LLamaGrammarElementType.RULE_REF)
              {
                  // Non-character element: the value is printed as a plain number.
                  file.Write(BinaryTypeLabel(type));
                  file.Write($"({element.Value}) ");
              }
              else if (IsCharElement(element))
              {
                  // Character element: the value is printed as a (possibly escaped) character.
                  file.Write(BinaryTypeLabel(type));
                  file.Write("(\"");
                  PrintGrammarChar(file, element.Value);
                  file.Write("\") ");
              }
          }

          file.WriteLine();
      }

      /// <summary>
      /// Returns the label that <see cref="PrintRuleBinary"/> prints for an element type.
      /// </summary>
      private static string BinaryTypeLabel(LLamaGrammarElementType type)
      {
          switch (type)
          {
              case LLamaGrammarElementType.END: return "END";
              case LLamaGrammarElementType.ALT: return "ALT";
              case LLamaGrammarElementType.RULE_REF: return "RULE_REF";
              case LLamaGrammarElementType.CHAR: return "CHAR";
              case LLamaGrammarElementType.CHAR_NOT: return "CHAR_NOT";
              case LLamaGrammarElementType.CHAR_RNG_UPPER: return "CHAR_RNG_UPPER";
              case LLamaGrammarElementType.CHAR_ALT: return "CHAR_ALT";
              default: return string.Empty;
          }
      }

      /// <summary>
      /// Writes one rule as <c>name ::= ...</c> followed by a newline. The trailing END
      /// terminator is validated but not printed.
      /// </summary>
      /// <param name="file">Destination writer.</param>
      /// <param name="ruleId">Id of the rule; used to look up its name and in error messages.</param>
      /// <param name="rule">Rule elements; END must be the last element and must not appear earlier.</param>
      /// <param name="symbolIdNames">Map from symbol id to symbol name.</param>
      /// <exception cref="GrammarFormatException">
      /// Thrown when the rule is empty or does not end with END, when END appears before the last
      /// element, or when a CHAR_ALT / CHAR_RNG_UPPER element is not directly preceded by a
      /// character element. Text written before the failure is not rolled back.
      /// </exception>
      /// <exception cref="KeyNotFoundException">
      /// Thrown when <paramref name="ruleId"/> or a referenced rule id has no entry in
      /// <paramref name="symbolIdNames"/>.
      /// </exception>
      private void PrintRule(
          StreamWriter file,
          uint ruleId,
          List<LLamaGrammarElement> rule,
          Dictionary<uint, string> symbolIdNames)
      {
          var lastIndex = rule.Count - 1;
          if (lastIndex < 0 || rule[lastIndex].Type != LLamaGrammarElementType.END)
          {
              throw new GrammarFormatException(
                  $"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}");
          }

          file.Write($"{symbolIdNames[ruleId]} ::= ");

          for (var index = 0; index < lastIndex; index++)
          {
              var element = rule[index];

              switch (element.Type)
              {
                  case LLamaGrammarElementType.END:
                      // END is only allowed as the final element, which this loop never visits.
                      throw new GrammarFormatException(
                          $"Unexpected end of rule: {ruleId}, {index}");

                  case LLamaGrammarElementType.ALT:
                      file.Write("| ");
                      break;

                  case LLamaGrammarElementType.RULE_REF:
                      file.Write($"{symbolIdNames[element.Value]} ");
                      break;

                  case LLamaGrammarElementType.CHAR:
                      // Opens a character class: "[c".
                      file.Write("[");
                      PrintGrammarChar(file, element.Value);
                      break;

                  case LLamaGrammarElementType.CHAR_NOT:
                      // Opens a negated character class: "[^c".
                      file.Write("[^");
                      PrintGrammarChar(file, element.Value);
                      break;

                  case LLamaGrammarElementType.CHAR_RNG_UPPER:
                      // Printed as "-c" right after the preceding char, giving range syntax like "[a-z".
                      if (!FollowsCharElement(rule, index))
                      {
                          throw new GrammarFormatException(
                              $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{index}");
                      }
                      file.Write("-");
                      PrintGrammarChar(file, element.Value);
                      break;

                  case LLamaGrammarElementType.CHAR_ALT:
                      // Adds another character to the currently open class.
                      if (!FollowsCharElement(rule, index))
                      {
                          throw new GrammarFormatException(
                              $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{index}");
                      }
                      PrintGrammarChar(file, element.Value);
                      break;
              }

              // Close the open character class unless the next element extends it. The index + 1
              // lookup is safe because the loop stops before the trailing END element.
              if (IsCharElement(element))
              {
                  var nextType = rule[index + 1].Type;
                  var classContinues = nextType == LLamaGrammarElementType.CHAR_ALT
                      || nextType == LLamaGrammarElementType.CHAR_RNG_UPPER;
                  if (!classContinues)
                  {
                      file.Write("] ");
                  }
              }
          }

          file.WriteLine();
      }

      /// <summary>
      /// True when the element at <paramref name="index"/> is directly preceded by a character element.
      /// </summary>
      private bool FollowsCharElement(List<LLamaGrammarElement> rule, int index)
      {
          return index > 0 && IsCharElement(rule[index - 1]);
      }

      /// <summary>
      /// Writes a single grammar character. Code points 0x20..0x7F are written as the character
      /// itself. Anything else is written as a <c>&lt;U+XXXX&gt;</c> escape (upper-case hex, at
      /// least four digits) instead of being UTF-8 encoded.
      /// </summary>
      /// <remarks>
      /// NOTE(review): the pass-through range includes 0x7F (DEL), which is a control character.
      /// Confirm whether DEL should be escaped instead.
      /// </remarks>
      private void PrintGrammarChar(StreamWriter file, uint c)
      {
          var writeDirectly = c >= 0x20 && c <= 0x7F;
          if (!writeDirectly)
          {
              // Emit a readable code-point escape rather than encoding to UTF-8.
              file.Write($"<U+{c:X4}>");
              return;
          }

          file.Write((char)c);
      }

      /// <summary>
      /// True for the four character element types: CHAR, CHAR_NOT, CHAR_ALT and CHAR_RNG_UPPER.
      /// </summary>
      private bool IsCharElement(LLamaGrammarElement elem)
      {
          var type = elem.Type;
          return type == LLamaGrammarElementType.CHAR
              || type == LLamaGrammarElementType.CHAR_NOT
              || type == LLamaGrammarElementType.CHAR_ALT
              || type == LLamaGrammarElementType.CHAR_RNG_UPPER;
      }
  }
  170. }