You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

GBNFGrammarParser.cs 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using LLama.Exceptions;
  6. using LLama.Native;
  7. namespace LLama.Grammars
  8. {
  9. /// <summary>
  10. /// Source:
  11. /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp
  12. ///
  13. /// The commit hash from URL is the actual commit hash that reflects current C# code.
  14. /// </summary>
  15. internal sealed class GBNFGrammarParser
  16. {
  /// <summary>
  /// Decodes a single UTF-8 encoded value from the front of <paramref name="src"/> and
  /// advances <paramref name="src"/> past the bytes that were read.
  /// </summary>
  /// <remarks>
  /// Copied from llama.cpp. The input is assumed to be valid UTF-8: continuation bytes are
  /// not validated, only reading past the end of the span is guarded against.
  /// </remarks>
  /// <param name="src">Remaining input; its first byte is read unconditionally.</param>
  /// <returns>The decoded value.</returns>
  private uint DecodeUTF8(ref ReadOnlySpan<byte> src)
  {
      byte lead = src[0];

      // The high nibble of the lead byte selects the sequence length:
      // 0xF -> 4 bytes, 0xE -> 3 bytes, 0xC/0xD -> 2 bytes, anything else -> 1 byte.
      int sequenceLength = (lead >> 4) switch
      {
          0xF => 4,
          0xE => 3,
          0xC or 0xD => 2,
          _ => 1,
      };

      // Strip the length marker bits from the lead byte, keeping only its payload bits.
      uint decoded = (uint)(lead & ((1 << (8 - sequenceLength)) - 1));

      // Fold in six payload bits per following byte, stopping early if the input is truncated.
      int consumed = 1;
      while (consumed < sequenceLength && consumed < src.Length)
      {
          decoded = (decoded << 6) + (uint)(src[consumed] & 0x3F);
          consumed++;
      }

      src = src.Slice(consumed);
      return decoded;
  }
  /// <summary>
  /// Looks up the id of the symbol whose name is at the front of <paramref name="src"/>,
  /// registering the name with a fresh id if it has not been seen before.
  /// </summary>
  /// <param name="state">Parse state holding the name-to-id table.</param>
  /// <param name="src">Input whose leading bytes are the UTF-8 encoded symbol name.</param>
  /// <param name="len">Number of trailing bytes of <paramref name="src"/> that are NOT part of the name.</param>
  /// <returns>The existing id for the name, or the newly assigned one.</returns>
  private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
  {
      var symbols = state.SymbolIds;

      byte[] nameBytes = src.Slice(0, src.Length - len).ToArray();
      string name = Encoding.UTF8.GetString(nameBytes);

      if (!symbols.TryGetValue(name, out uint id))
      {
          // New ids are handed out sequentially: the next id is the current table size.
          id = (uint)symbols.Count;
          symbols[name] = id;
      }

      return id;
  }
  /// <summary>
  /// Registers a synthesized symbol named <c>{baseName}_{id}</c> and returns its id.
  /// </summary>
  /// <param name="state">Parse state holding the name-to-id table.</param>
  /// <param name="baseName">Name that the generated symbol name is derived from.</param>
  /// <returns>The id assigned to the generated symbol (the table size before insertion).</returns>
  private uint GenerateSymbolId(ParseState state, string baseName)
  {
      var symbols = state.SymbolIds;
      uint id = (uint)symbols.Count;

      string generatedName = baseName + "_" + id;
      symbols[generatedName] = id;

      return id;
  }
  /// <summary>
  /// Stores <paramref name="rule"/> at index <paramref name="ruleId"/> of the rule list,
  /// first padding the list with empty rules so that the index exists.
  /// </summary>
  /// <param name="state">Parse state holding the rule list.</param>
  /// <param name="ruleId">Index at which the rule is stored.</param>
  /// <param name="rule">Elements making up the rule; replaces whatever was at that index.</param>
  private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
  {
      var rules = state.Rules;
      long requiredIndex = ruleId;

      // Grow the list with placeholder rules until requiredIndex is a valid position.
      while (rules.Count <= requiredIndex)
      {
          rules.Add(new());
      }

      rules[(int)ruleId] = rule;
  }
  /// <summary>
  /// Tells whether <paramref name="c"/> may appear in a name: an ASCII letter, an ASCII digit, or '-'.
  /// </summary>
  /// <param name="c">Byte to classify.</param>
  /// <returns><c>true</c> if the byte is a name character; otherwise <c>false</c>.</returns>
  private bool IsWordChar(byte c)
  {
      return c switch
      {
          >= (byte)'a' and <= (byte)'z' => true,
          >= (byte)'A' and <= (byte)'Z' => true,
          >= (byte)'0' and <= (byte)'9' => true,
          (byte)'-' => true,
          _ => false,
      };
  }
  /// <summary>
  /// Reads exactly <paramref name="size"/> hexadecimal digits from the front of
  /// <paramref name="src"/> and advances <paramref name="src"/> past them.
  /// </summary>
  /// <param name="src">Remaining input, positioned at the first hex digit.</param>
  /// <param name="size">Number of hex digits that must be present.</param>
  /// <returns>The numeric value of the digits.</returns>
  /// <exception cref="GrammarFormatException">
  /// Thrown if fewer than <paramref name="size"/> hex digits are available.
  /// </exception>
  private uint ParseHex(ref ReadOnlySpan<byte> src, int size)
  {
      uint result = 0;
      int consumed = 0;

      while (consumed < size && consumed < src.Length)
      {
          byte c = src[consumed];

          // Map the byte to its digit value; -1 marks a non-hex byte.
          int digit = c switch
          {
              >= (byte)'a' and <= (byte)'f' => c - 'a' + 10,
              >= (byte)'A' and <= (byte)'F' => c - 'A' + 10,
              >= (byte)'0' and <= (byte)'9' => c - '0',
              _ => -1,
          };

          if (digit < 0)
          {
              break;
          }

          result = (result << 4) + (uint)digit;
          consumed++;
      }

      // Stopping early (bad digit or end of input) means the escape was too short.
      if (consumed != size)
      {
          throw new GrammarFormatException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}");
      }

      src = src.Slice(consumed);
      return result;
  }
  /// <summary>
  /// Skips spaces, tabs and <c>#</c> comments (and line breaks, when allowed) at the front of the input.
  /// </summary>
  /// <param name="src">Input to skip whitespace in.</param>
  /// <param name="newlineOk">Whether '\r' and '\n' are skipped as well.</param>
  /// <returns>The input starting at the first byte that was not skipped.</returns>
  private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
  {
      int index = 0;

      while (index < src.Length)
      {
          byte current = src[index];

          if (current == '#')
          {
              // A comment runs up to, but does not include, the next line terminator.
              do
              {
                  index++;
              }
              while (index < src.Length && src[index] != '\r' && src[index] != '\n');

              continue;
          }

          bool isBlank = current == ' ' || current == '\t';
          bool isLineBreak = current == '\r' || current == '\n';

          if (!isBlank && !(newlineOk && isLineBreak))
          {
              break;
          }

          index++;
      }

      return src.Slice(index);
  }
  /// <summary>
  /// Skips over the name (a non-empty run of word characters) at the front of the input.
  /// </summary>
  /// <param name="src">Input positioned at the start of a name.</param>
  /// <returns>The input immediately after the name.</returns>
  /// <exception cref="GrammarFormatException">Thrown if the input does not start with a word character.</exception>
  private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
  {
      int nameLength = 0;

      foreach (byte b in src)
      {
          if (!IsWordChar(b))
          {
              break;
          }

          nameLength++;
      }

      if (nameLength == 0)
      {
          throw new GrammarFormatException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}");
      }

      return src.Slice(nameLength);
  }
  /// <summary>
  /// Reads one character from the front of <paramref name="src"/>, either a backslash escape
  /// or a UTF-8 encoded character, and advances <paramref name="src"/> past it.
  /// </summary>
  /// <remarks>
  /// Supported escapes: <c>\xHH</c>, <c>\uHHHH</c>, <c>\UHHHHHHHH</c>, <c>\t</c>, <c>\r</c>,
  /// <c>\n</c>, <c>\\</c>, <c>\"</c>, <c>\[</c> and <c>\]</c>.
  /// </remarks>
  /// <param name="src">Remaining input, positioned at the character to read.</param>
  /// <returns>The value of the character.</returns>
  /// <exception cref="GrammarFormatException">
  /// Thrown if the input is empty, ends right after a backslash, or contains an unknown escape.
  /// </exception>
  private uint ParseChar(ref ReadOnlySpan<byte> src)
  {
      // Check for end of input BEFORE indexing, so that truncated grammars are reported as
      // format errors instead of surfacing as an IndexOutOfRangeException.
      if (src.IsEmpty)
      {
          throw new GrammarFormatException("Unexpected end of input");
      }

      if (src[0] != '\\')
      {
          return DecodeUTF8(ref src);
      }

      // A lone trailing backslash has no escape character after it.
      if (src.Length < 2)
      {
          throw new GrammarFormatException("Unexpected end of input");
      }

      // Remember where the escape starts so an error message can show the escape itself.
      var escapeStart = src;
      var chr = src[1];
      src = src.Slice(2);

      switch (chr)
      {
          case (byte)'x':
              return ParseHex(ref src, 2);
          case (byte)'u':
              return ParseHex(ref src, 4);
          case (byte)'U':
              return ParseHex(ref src, 8);
          case (byte)'t':
              return '\t';
          case (byte)'r':
              return '\r';
          case (byte)'n':
              return '\n';
          case (byte)'\\':
          case (byte)'"':
          case (byte)'[':
          case (byte)']':
              return chr;
          default:
              throw new GrammarFormatException("Unknown escape at " + Encoding.UTF8.GetString(escapeStart.ToArray()));
      }
  }
  /// <summary>
  /// Parses one sequence (a run of literals, character ranges, rule references, groups and
  /// repetition operators) and appends the resulting elements to <paramref name="outElements"/>.
  /// </summary>
  /// <remarks>
  /// Groups and repetition operators are rewritten into synthesized rules that are added to
  /// <paramref name="state"/>; the sequence itself only receives a reference to them.
  /// </remarks>
  /// <param name="state">Parse state receiving symbol ids and synthesized rules.</param>
  /// <param name="pos">Input positioned at the start of the sequence.</param>
  /// <param name="ruleName">Name of the enclosing rule, used to name synthesized rules.</param>
  /// <param name="outElements">List the parsed elements are appended to.</param>
  /// <param name="isNested">Whether the sequence is inside parentheses, in which case line breaks are skipped as whitespace.</param>
  /// <returns>The input starting at the first byte that does not continue the sequence.</returns>
  /// <exception cref="GrammarFormatException">
  /// Thrown for an unterminated literal or character range, a missing ')', or a repetition
  /// operator with nothing before it.
  /// </exception>
  private ReadOnlySpan<byte> ParseSequence(
      ParseState state,
      ReadOnlySpan<byte> pos,
      string ruleName,
      List<LLamaGrammarElement> outElements,
      bool isNested)
  {
      // Index in outElements where the most recent symbol starts; a following */+/? applies to it.
      int lastSymStart = outElements.Count;

      while (!pos.IsEmpty)
      {
          if (pos[0] == '"') // literal string
          {
              pos = pos.Slice(1);
              lastSymStart = outElements.Count;
              while (!pos.IsEmpty && pos[0] != '"')
              {
                  var charPair = ParseChar(ref pos);
                  outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair));
              }

              // Input ran out before the closing quote: report it instead of slicing past the end.
              if (pos.IsEmpty)
              {
                  throw new GrammarFormatException("Unexpected end of input");
              }

              pos = ParseSpace(pos.Slice(1), isNested);
          }
          else if (pos[0] == '[') // char range(s)
          {
              pos = pos.Slice(1);
              var startType = LLamaGrammarElementType.CHAR;
              if (!pos.IsEmpty && pos[0] == '^')
              {
                  pos = pos.Slice(1);
                  startType = LLamaGrammarElementType.CHAR_NOT;
              }

              lastSymStart = outElements.Count;
              while (!pos.IsEmpty && pos[0] != ']')
              {
                  var charPair = ParseChar(ref pos);

                  // The first entry carries the range type; later entries are alternatives to it.
                  var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;
                  outElements.Add(new LLamaGrammarElement(type, charPair));

                  // "a-z" style range; a '-' directly before ']' is a plain character instead.
                  // The length check keeps truncated input from indexing past the end.
                  if (pos.Length >= 2 && pos[0] == '-' && pos[1] != ']')
                  {
                      pos = pos.Slice(1);
                      var endCharPair = ParseChar(ref pos);
                      outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair));
                  }
              }

              // Input ran out before the closing bracket.
              if (pos.IsEmpty)
              {
                  throw new GrammarFormatException("Unexpected end of input");
              }

              pos = ParseSpace(pos.Slice(1), isNested);
          }
          else if (IsWordChar(pos[0])) // rule reference
          {
              var nameEnd = ParseName(pos);
              uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
              pos = ParseSpace(nameEnd, isNested);
              lastSymStart = outElements.Count;
              outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
          }
          else if (pos[0] == '(') // grouping
          {
              // parse nested alternates into synthesized rule
              pos = ParseSpace(pos.Slice(1), true);
              uint subRuleId = GenerateSymbolId(state, ruleName);
              pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
              lastSymStart = outElements.Count;

              // output reference to synthesized rule
              outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));

              // End of input counts as a missing ')' as well.
              if (pos.IsEmpty || pos[0] != ')')
              {
                  throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
              }

              pos = ParseSpace(pos.Slice(1), isNested);
          }
          else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator
          {
              if (lastSymStart == outElements.Count)
              {
                  throw new GrammarFormatException($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}");
              }

              byte op = pos[0];

              // apply transformation to previous symbol (lastSymStart to end) according to
              // rewrite rules:
              // S* --> S' ::= S S' |
              // S+ --> S' ::= S S' | S
              // S? --> S' ::= S |
              uint subRuleId = GenerateSymbolId(state, ruleName);
              var precedingSymbol = outElements.GetRange(lastSymStart, outElements.Count - lastSymStart);
              var subRule = new List<LLamaGrammarElement>();

              // add preceding symbol to generated rule
              subRule.AddRange(precedingSymbol);
              if (op == '*' || op == '+')
              {
                  // cause generated rule to recurse
                  subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
              }

              // mark start of alternate def
              subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
              if (op == '+')
              {
                  // add preceding symbol as alternate only for '+' (otherwise empty)
                  subRule.AddRange(precedingSymbol);
              }

              subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
              AddRule(state, subRuleId, subRule);

              // in original rule, replace previous symbol with reference to generated rule
              outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
              outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
              pos = ParseSpace(pos.Slice(1), isNested);
          }
          else
          {
              // Anything else (e.g. '|', ')', a line break) ends the sequence.
              break;
          }
      }

      return pos;
  }
  /// <summary>
  /// Parses one or more sequences separated by '|' and stores them as the rule with id
  /// <paramref name="ruleId"/>.
  /// </summary>
  /// <param name="state">Parse state receiving the finished rule.</param>
  /// <param name="src">Input positioned at the first sequence.</param>
  /// <param name="ruleName">Name of the enclosing rule, passed on for naming synthesized rules.</param>
  /// <param name="ruleId">Id under which the parsed rule is stored.</param>
  /// <param name="isNested">Whether the alternates are inside parentheses.</param>
  /// <returns>The input remaining after the last alternate.</returns>
  private ReadOnlySpan<byte> ParseAlternates(
      ParseState state,
      ReadOnlySpan<byte> src,
      string ruleName,
      uint ruleId,
      bool isNested)
  {
      var elements = new List<LLamaGrammarElement>();
      var remaining = ParseSequence(state, src, ruleName, elements, isNested);

      while (true)
      {
          if (remaining.IsEmpty || remaining[0] != '|')
          {
              break;
          }

          // Separate this alternate from the previous one, then parse what follows the '|'.
          elements.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
          var afterBar = ParseSpace(remaining.Slice(1), true);
          remaining = ParseSequence(state, afterBar, ruleName, elements, isNested);
      }

      elements.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
      AddRule(state, ruleId, elements);

      return remaining;
  }
  /// <summary>
  /// Parses one rule definition of the form <c>name ::= alternates</c>, followed by a line
  /// break or the end of input, and stores it in <paramref name="state"/>.
  /// </summary>
  /// <param name="state">Parse state receiving the rule and its symbol id.</param>
  /// <param name="src">Input positioned at the rule name.</param>
  /// <returns>The input after the rule and any whitespace, comments and line breaks following it.</returns>
  /// <exception cref="GrammarFormatException">
  /// Thrown if the name or <c>::=</c> is missing, or if the rule is followed by anything other
  /// than a line break or the end of input.
  /// </exception>
  private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
  {
      ReadOnlySpan<byte> nameEnd = ParseName(src);
      ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
      int nameLen = src.Length - nameEnd.Length;
      uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0);
      string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());

      // Check the length first: input that ends right after the name must be reported as a
      // format error rather than indexing past the end of the span.
      if (pos.Length < 3 || pos[0] != ':' || pos[1] != ':' || pos[2] != '=')
      {
          throw new GrammarFormatException($"Expecting ::= at {Encoding.UTF8.GetString(pos.ToArray())}");
      }

      pos = ParseSpace(pos.Slice(3), true);
      pos = ParseAlternates(state, pos, name, ruleId, false);

      if (!pos.IsEmpty)
      {
          if (pos[0] == '\r')
          {
              // Consume "\r\n" as one line break; a '\r' that ends the input stands alone.
              pos = pos.Slice(pos.Length > 1 && pos[1] == '\n' ? 2 : 1);
          }
          else if (pos[0] == '\n')
          {
              pos = pos.Slice(1);
          }
          else
          {
              throw new GrammarFormatException($"Expecting newline or end at {Encoding.UTF8.GetString(pos.ToArray())}");
          }
      }

      return ParseSpace(pos, true);
  }
  /// <summary>
  /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a>
  /// </summary>
  /// <param name="input">The string to parse</param>
  /// <param name="startRule">The name of the root rule of this grammar</param>
  /// <exception cref="GrammarFormatException">
  /// Thrown if input is malformed, or if <paramref name="startRule"/> does not name a symbol of the grammar
  /// </exception>
  /// <returns>A <see cref="Grammar"/> built from the parsed rules, with <paramref name="startRule"/> as its start rule</returns>
  public Grammar Parse(string input, string startRule)
  {
      var byteArray = Encoding.UTF8.GetBytes(input);
      var state = new ParseState();

      // Skip leading whitespace/comments, then parse rules until the input is used up.
      var pos = ParseSpace(byteArray, true);
      while (!pos.IsEmpty)
      {
          pos = ParseRule(state, pos);
      }

      // Report an unknown start rule as a grammar error instead of letting the dictionary
      // indexer throw KeyNotFoundException.
      if (!state.SymbolIds.TryGetValue(startRule, out var startRuleIndex))
      {
          throw new GrammarFormatException($"Start rule '{startRule}' is not defined in the grammar");
      }

      // Invert the name -> id table so each rule can be labelled with its name.
      var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key);

      var rules = new List<GrammarRule>(state.Rules.Count);
      for (var i = 0; i < state.Rules.Count; i++)
      {
          var elements = state.Rules[i];
          var name = names[(uint)i];
          rules.Add(new GrammarRule(name, elements));
      }

      return new Grammar(rules, startRuleIndex);
  }
  /// <summary>
  /// State accumulated while parsing a grammar: the symbol table and the rules parsed so far.
  /// </summary>
  private record ParseState
  {
      /// <summary>
      /// Maps each symbol name to its numeric id. Names are identifiers, so they are compared
      /// ordinally rather than with the culture-sensitive default string comparer.
      /// </summary>
      public SortedDictionary<string, uint> SymbolIds { get; } = new(StringComparer.Ordinal);

      /// <summary>
      /// Rules indexed by symbol id; each rule is its list of grammar elements.
      /// </summary>
      public List<List<LLamaGrammarElement>> Rules { get; } = new();
  }
  358. }
  359. }