You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

GBNFGrammarParser.cs 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using LLama.Exceptions;
  6. using LLama.Native;
  namespace LLama.Grammars
  {
      /// <summary>
      /// Parser for GBNF (GGML BNF) grammars, ported from llama.cpp.
      /// Source:
      /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp
      ///
      /// The commit hash in the URL is the llama.cpp revision whose parsing rules this C# code reflects.
      /// The C++ original peeks past the current character and relies on its input being NUL-terminated
      /// to stay in bounds; this port bounds-checks those reads instead, so truncated input is reported
      /// through the grammar format exceptions rather than IndexOutOfRangeException/ArgumentOutOfRangeException.
      /// </summary>
      internal sealed class GBNFGrammarParser
      {
          // Length of a UTF-8 sequence, indexed by the high nibble of its first byte.
          // Hoisted out of DecodeUTF8 so it is not re-allocated for every decoded character.
          // NOTE: nibbles 0x8-0xB (continuation bytes) map to 1 here, whereas llama.cpp uses 0. Parse only
          // feeds this parser the output of Encoding.UTF8.GetBytes and always decodes on sequence
          // boundaries, so a sequence should never start with a continuation byte.
          private static readonly int[] Utf8SequenceLengths = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

          // NOTE: assumes valid utf8 (but checks for overrun)
          // copied from llama.cpp
          // Decodes one code point from the start of `src` and advances `src` past the bytes consumed.
          // The caller must ensure `src` is not empty.
          private static uint DecodeUTF8(ref ReadOnlySpan<byte> src)
          {
              byte firstByte = src[0];
              int len = Utf8SequenceLengths[firstByte >> 4];
              byte mask = (byte)((1 << (8 - len)) - 1);
              uint value = (uint)(firstByte & mask);
              int pos = 1;
              // `pos < src.Length` stops early if the sequence is cut off by the end of the input.
              for (; pos < len && pos < src.Length; pos++)
              {
                  value = (value << 6) + (uint)(src[pos] & 0x3F);
              }
              src = src.Slice(pos);
              return value;
          }

          // Rule names consist of ASCII letters, digits and '-'. '_' is deliberately not a word char, which
          // keeps the "{name}_{id}" keys produced by ParseState.GenerateSymbolId from colliding with user rules.
          private static bool IsWordChar(byte c)
          {
              return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
          }

          // Parses exactly `size` hex digits from the start of `src` and advances `src` past them.
          // Throws GrammarUnexpectedHexCharsCount if fewer than `size` hex digits are available.
          private static uint ParseHex(ref ReadOnlySpan<byte> src, int size)
          {
              int pos = 0;
              uint value = 0;
              for (; pos < size && pos < src.Length; pos++)
              {
                  value <<= 4;
                  byte c = src[pos];
                  if ('a' <= c && c <= 'f')
                  {
                      value += (uint)(c - 'a' + 10);
                  }
                  else if ('A' <= c && c <= 'F')
                  {
                      value += (uint)(c - 'A' + 10);
                  }
                  else if ('0' <= c && c <= '9')
                  {
                      value += (uint)(c - '0');
                  }
                  else
                  {
                      break;
                  }
              }

              if (pos != size)
              {
                  throw new GrammarUnexpectedHexCharsCount(size, Encoding.UTF8.GetString(src.ToArray()));
              }

              src = src.Slice(pos);
              return value;
          }

          // Skips spaces, tabs and '#' comments (to end of line). Newlines are skipped only when
          // `newlineOk` is true; returns the remaining input.
          private static ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
          {
              int pos = 0;
              while (pos < src.Length &&
                     (src[pos] == ' ' || src[pos] == '\t' || src[pos] == '#' ||
                      (newlineOk && (src[pos] == '\r' || src[pos] == '\n'))))
              {
                  if (src[pos] == '#')
                  {
                      while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n')
                      {
                          pos++;
                      }
                  }
                  else
                  {
                      pos++;
                  }
              }

              return src.Slice(pos);
          }

          // Consumes a rule name (one or more word chars) and returns the input that follows it.
          // Throws GrammarExpectedName if `src` does not start with a word char.
          private static ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
          {
              int pos = 0;
              while (pos < src.Length && IsWordChar(src[pos]))
              {
                  pos++;
              }

              if (pos == 0)
              {
                  throw new GrammarExpectedName(Encoding.UTF8.GetString(src.ToArray()));
              }

              return src.Slice(pos);
          }

          // Parses one (possibly escaped) character and advances `src` past it.
          // Throws GrammarUnexpectedEndOfInput on empty/truncated input and GrammarUnknownEscapeCharacter
          // for an unsupported escape.
          private static uint ParseChar(ref ReadOnlySpan<byte> src)
          {
              // Checked first so callers that consumed the last byte (e.g. "[a-") get a format error
              // instead of an IndexOutOfRangeException.
              if (src.IsEmpty)
                  throw new GrammarUnexpectedEndOfInput();

              if (src[0] != '\\')
                  return DecodeUTF8(ref src);

              if (src.Length < 2)
                  throw new GrammarUnexpectedEndOfInput();

              // Keep the text from the backslash so the error message shows the offending escape.
              var escapeStart = src;
              var chr = src[1];
              src = src.Slice(2);
              return (char)chr switch
              {
                  'x' => ParseHex(ref src, 2),
                  'u' => ParseHex(ref src, 4),
                  'U' => ParseHex(ref src, 8),
                  't' => '\t',
                  'r' => '\r',
                  'n' => '\n',
                  '\\' or '"' or '[' or ']' => chr,
                  _ => throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(escapeStart.ToArray())),
              };
          }

          // Parses a sequence of symbols (literals, char classes, rule refs, groups, repetitions) into
          // `outElements`, stopping at the first byte that cannot start or modify a symbol
          // (e.g. '|', ')' or a newline when not nested). Returns the remaining input.
          private static ReadOnlySpan<byte> ParseSequence(
              ParseState state,
              ReadOnlySpan<byte> pos,
              string ruleName,
              List<LLamaGrammarElement> outElements,
              bool isNested)
          {
              // Index in outElements where the most recent symbol starts; a repetition operator
              // rewrites everything from here to the end of outElements.
              int lastSymStart = outElements.Count;
              while (!pos.IsEmpty)
              {
                  if (pos[0] == '"') // literal string
                  {
                      pos = pos.Slice(1);
                      lastSymStart = outElements.Count;
                      while (!pos.IsEmpty && pos[0] != '"')
                      {
                          var charPair = ParseChar(ref pos);
                          outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair));
                      }

                      // Unterminated string literal.
                      if (pos.IsEmpty)
                          throw new GrammarUnexpectedEndOfInput();

                      pos = ParseSpace(pos.Slice(1), isNested);
                  }
                  else if (pos[0] == '[') // char range(s)
                  {
                      pos = pos.Slice(1);
                      var startType = LLamaGrammarElementType.CHAR;
                      if (!pos.IsEmpty && pos[0] == '^')
                      {
                          pos = pos.Slice(1);
                          startType = LLamaGrammarElementType.CHAR_NOT;
                      }

                      lastSymStart = outElements.Count;
                      while (!pos.IsEmpty && pos[0] != ']')
                      {
                          var charPair = ParseChar(ref pos);
                          var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;
                          outElements.Add(new LLamaGrammarElement(type, charPair));

                          // "a-z" is a range; a '-' immediately followed by ']' is a literal dash instead.
                          // If the input ends right after '-', ParseChar reports end of input.
                          if (!pos.IsEmpty && pos[0] == '-' && !(pos.Length > 1 && pos[1] == ']'))
                          {
                              pos = pos.Slice(1);
                              var endCharPair = ParseChar(ref pos);
                              outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair));
                          }
                      }

                      // Unterminated character class.
                      if (pos.IsEmpty)
                          throw new GrammarUnexpectedEndOfInput();

                      pos = ParseSpace(pos.Slice(1), isNested);
                  }
                  else if (IsWordChar(pos[0])) // rule reference
                  {
                      var nameEnd = ParseName(pos);
                      uint refRuleId = state.GetSymbolId(pos.Slice(0, pos.Length - nameEnd.Length));
                      pos = ParseSpace(nameEnd, isNested);
                      lastSymStart = outElements.Count;
                      outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
                  }
                  else if (pos[0] == '(') // grouping
                  {
                      // parse nested alternates into synthesized rule
                      pos = ParseSpace(pos.Slice(1), true);
                      uint subRuleId = state.GenerateSymbolId(ruleName);
                      pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
                      lastSymStart = outElements.Count;
                      // output reference to synthesized rule
                      outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
                      if (pos.IsEmpty || pos[0] != ')')
                          throw new GrammarExpectedNext(")", Encoding.UTF8.GetString(pos.ToArray()));
                      pos = ParseSpace(pos.Slice(1), isNested);
                  }
                  else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator
                  {
                      if (lastSymStart == outElements.Count)
                          throw new GrammarExpectedPrevious("*/+/?", Encoding.UTF8.GetString(pos.ToArray()));

                      // apply transformation to previous symbol (lastSymStart to end) according to
                      // rewrite rules:
                      // S* --> S' ::= S S' |
                      // S+ --> S' ::= S S' | S
                      // S? --> S' ::= S |
                      byte op = pos[0];
                      uint subRuleId = state.GenerateSymbolId(ruleName);
                      int symbolLength = outElements.Count - lastSymStart;
                      var subRule = new List<LLamaGrammarElement>();

                      // add preceding symbol to generated rule
                      subRule.AddRange(outElements.GetRange(lastSymStart, symbolLength));
                      if (op == '*' || op == '+')
                      {
                          // cause generated rule to recurse
                          subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
                      }

                      // mark start of alternate def
                      subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
                      if (op == '+')
                      {
                          // add preceding symbol as alternate only for '+' (otherwise empty)
                          subRule.AddRange(outElements.GetRange(lastSymStart, symbolLength));
                      }

                      subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
                      state.AddRule(subRuleId, subRule);

                      // in original rule, replace previous symbol with reference to generated rule
                      outElements.RemoveRange(lastSymStart, symbolLength);
                      outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
                      pos = ParseSpace(pos.Slice(1), isNested);
                  }
                  else
                  {
                      break;
                  }
              }

              return pos;
          }

          // Parses `seq ('|' seq)*`, terminates it with END and stores it as rule `ruleId`.
          // Returns the remaining input.
          private static ReadOnlySpan<byte> ParseAlternates(
              ParseState state,
              ReadOnlySpan<byte> src,
              string ruleName,
              uint ruleId,
              bool isNested)
          {
              var rule = new List<LLamaGrammarElement>();
              ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested);
              while (!pos.IsEmpty && pos[0] == '|')
              {
                  rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
                  pos = ParseSpace(pos.Slice(1), true);
                  pos = ParseSequence(state, pos, ruleName, rule, isNested);
              }

              rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
              state.AddRule(ruleId, rule);
              return pos;
          }

          // Parses one `name ::= alternates` definition followed by a newline (\n, \r\n or \r) or end
          // of input. Returns the remaining input with leading whitespace/comments skipped.
          private static ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
          {
              ReadOnlySpan<byte> nameEnd = ParseName(src);
              ReadOnlySpan<byte> nameBytes = src.Slice(0, src.Length - nameEnd.Length);
              ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
              uint ruleId = state.GetSymbolId(nameBytes);
              string name = Encoding.UTF8.GetString(nameBytes.ToArray());

              // Length check first: input such as "root" or "root ::" must not index past the end.
              if (pos.Length < 3 || pos[0] != ':' || pos[1] != ':' || pos[2] != '=')
                  throw new GrammarExpectedNext("::=", Encoding.UTF8.GetString(pos.ToArray()));

              pos = ParseSpace(pos.Slice(3), true);
              pos = ParseAlternates(state, pos, name, ruleId, false);

              if (!pos.IsEmpty && pos[0] == '\r')
              {
                  // "\r\n" or a lone "\r" (which may be the very last byte of the input).
                  pos = pos.Slice(pos.Length > 1 && pos[1] == '\n' ? 2 : 1);
              }
              else if (!pos.IsEmpty && pos[0] == '\n')
              {
                  pos = pos.Slice(1);
              }
              else if (!pos.IsEmpty)
              {
                  throw new GrammarExpectedNext("newline or EOF", Encoding.UTF8.GetString(pos.ToArray()));
              }

              return ParseSpace(pos, true);
          }

          /// <summary>
          /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a>
          /// </summary>
          /// <param name="input">The string to parse</param>
          /// <param name="startRule">The name of the root rule of this grammar</param>
          /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception>
          /// <exception cref="KeyNotFoundException">Thrown if <paramref name="startRule"/> is not a rule name used in <paramref name="input"/></exception>
          /// <returns>A <see cref="Grammar"/> that can be used for sampling</returns>
          public Grammar Parse(string input, string startRule)
          {
              var byteArray = Encoding.UTF8.GetBytes(input);
              var state = new ParseState();
              var pos = ParseSpace(byteArray, true);
              while (!pos.IsEmpty)
              {
                  pos = ParseRule(state, pos);
              }

              // Invert the symbol table so each rule (indexed by id) can be labelled with its name.
              // NOTE(review): a rule that is referenced but never defined is not detected here; it ends up
              // either missing from `rules` or as an empty element list.
              var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key);
              var rules = new List<GrammarRule>(state.Rules.Count);
              for (var i = 0; i < state.Rules.Count; i++)
              {
                  rules.Add(new GrammarRule(names[(uint)i], state.Rules[i]));
              }

              var startRuleIndex = state.SymbolIds[startRule];
              return new Grammar(rules, startRuleIndex);
          }

          // Mutable parser state: symbol table (name -> id) and rule bodies indexed by id.
          // A plain class rather than a record: it is mutated in place and value equality is meaningless.
          private sealed class ParseState
          {
              public SortedDictionary<string, uint> SymbolIds { get; } = new();
              public List<List<LLamaGrammarElement>> Rules { get; } = new();

              // Returns the id for the rule name in `name` (UTF-8), assigning the next free id if the
              // name has not been seen before.
              public uint GetSymbolId(ReadOnlySpan<byte> name)
              {
                  var key = Encoding.UTF8.GetString(name.ToArray());
                  if (SymbolIds.TryGetValue(key, out uint existingId))
                  {
                      return existingId;
                  }

                  var nextId = (uint)SymbolIds.Count;
                  SymbolIds[key] = nextId;
                  return nextId;
              }

              // Allocates a fresh id for a synthesized rule named "{baseName}_{id}".
              public uint GenerateSymbolId(string baseName)
              {
                  var nextId = (uint)SymbolIds.Count;
                  SymbolIds[$"{baseName}_{nextId}"] = nextId;
                  return nextId;
              }

              // Stores `rule` at index `ruleId`, padding Rules with empty lists if needed.
              public void AddRule(uint ruleId, List<LLamaGrammarElement> rule)
              {
                  while (Rules.Count <= ruleId)
                  {
                      Rules.Add(new List<LLamaGrammarElement>());
                  }

                  Rules[(int)ruleId] = rule;
              }
          }
      }
  }