using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using LLama.Exceptions;
using LLama.Native;

namespace LLama.Grammars
{
    /// <summary>
    /// Source:
    /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp
    /// </summary>
    /// <remarks>
    /// The commit hash from URL is the actual commit hash that reflects current C# code.
    /// Unlike the C++ original (which relies on a NUL terminator to stop lookahead), every lookahead
    /// here is bounds-checked so that truncated input raises a grammar exception instead of an
    /// index/range exception.
    /// </remarks>
    internal sealed class GBNFGrammarParser
    {
        // Length in bytes of a UTF-8 sequence, indexed by the high nibble of its first byte.
        private static readonly int[] Utf8LengthByHighNibble = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

        // NOTE: assumes valid utf8 (but checks for overrun)
        // copied from llama.cpp
        /// <summary>
        /// Decodes one code point from the start of <paramref name="src"/> and advances
        /// <paramref name="src"/> past the bytes consumed. The caller must ensure <paramref name="src"/> is non-empty.
        /// </summary>
        private static uint DecodeUTF8(ref ReadOnlySpan<byte> src)
        {
            byte firstByte = src[0];
            int len = Utf8LengthByHighNibble[firstByte >> 4];
            byte mask = (byte)((1 << (8 - len)) - 1);
            uint value = (uint)(firstByte & mask);

            // Stop early if the sequence is truncated by the end of the buffer.
            int pos = 1;
            for (; pos < len && pos < src.Length; pos++)
            {
                value = (value << 6) + (uint)(src[pos] & 0x3F);
            }

            src = src.Slice(pos);
            return value;
        }

        /// <summary>
        /// Returns true for characters allowed in rule names: ASCII letters, digits and '-'.
        /// </summary>
        private static bool IsWordChar(byte c)
        {
            return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
        }

        /// <summary>
        /// Parses exactly <paramref name="size"/> hex digits from <paramref name="src"/> and advances past them.
        /// </summary>
        /// <exception cref="GrammarUnexpectedHexCharsCount">Fewer than <paramref name="size"/> hex digits were available.</exception>
        private static uint ParseHex(ref ReadOnlySpan<byte> src, int size)
        {
            int pos = 0;
            int end = size;
            uint value = 0;
            for (; pos < end && pos < src.Length; pos++)
            {
                value <<= 4;
                byte c = src[pos];
                if ('a' <= c && c <= 'f')
                {
                    value += (uint)(c - 'a' + 10);
                }
                else if ('A' <= c && c <= 'F')
                {
                    value += (uint)(c - 'A' + 10);
                }
                else if ('0' <= c && c <= '9')
                {
                    value += (uint)(c - '0');
                }
                else
                {
                    break;
                }
            }

            if (pos != end)
            {
                throw new GrammarUnexpectedHexCharsCount(size, Encoding.UTF8.GetString(src.ToArray()));
            }

            src = src.Slice(pos);
            return value;
        }

        /// <summary>
        /// Skips spaces, tabs and '#' comments (to end of line). Newlines are skipped only when
        /// <paramref name="newlineOk"/> is true. Returns the remaining input.
        /// </summary>
        private static ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
        {
            int pos = 0;
            while (pos < src.Length && (src[pos] == ' ' || src[pos] == '\t' || src[pos] == '#' ||
                   (newlineOk && (src[pos] == '\r' || src[pos] == '\n'))))
            {
                if (src[pos] == '#')
                {
                    while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n')
                    {
                        pos++;
                    }
                }
                else
                {
                    pos++;
                }
            }

            return src.Slice(pos);
        }

        /// <summary>
        /// Consumes a rule name and returns the input following it.
        /// </summary>
        /// <exception cref="GrammarExpectedName">The input does not start with a word character.</exception>
        private static ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
        {
            int pos = 0;
            while (pos < src.Length && IsWordChar(src[pos]))
            {
                pos++;
            }

            if (pos == 0)
            {
                throw new GrammarExpectedName(Encoding.UTF8.GetString(src.ToArray()));
            }

            return src.Slice(pos);
        }

        /// <summary>
        /// Parses one (possibly escaped) character and advances <paramref name="src"/> past it.
        /// </summary>
        /// <exception cref="GrammarUnexpectedEndOfInput">The input is empty or ends mid-escape.</exception>
        /// <exception cref="GrammarUnknownEscapeCharacter">An unsupported escape sequence was found.</exception>
        private static uint ParseChar(ref ReadOnlySpan<byte> src)
        {
            // Check for empty input first; indexing src[0] on an empty span would throw an index exception.
            if (src.IsEmpty)
            {
                throw new GrammarUnexpectedEndOfInput();
            }

            if (src[0] == '\\')
            {
                if (src.Length < 2)
                {
                    throw new GrammarUnexpectedEndOfInput();
                }

                var chr = src[1];
                src = src.Slice(2);
                return (char)chr switch
                {
                    'x' => ParseHex(ref src, 2),
                    'u' => ParseHex(ref src, 4),
                    'U' => ParseHex(ref src, 8),
                    't' => '\t',
                    'r' => '\r',
                    'n' => '\n',
                    '\\' or '"' or '[' or ']' => chr,
                    _ => throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())),
                };
            }

            return DecodeUTF8(ref src);
        }

        /// <summary>
        /// Parses a sequence of symbols (literals, char classes, rule references, groups and
        /// repetition operators) and appends the resulting elements to <paramref name="outElements"/>.
        /// Stops at the first character that cannot start a symbol (e.g. '|', ')' or a newline).
        /// </summary>
        /// <returns>The input remaining after the sequence.</returns>
        private ReadOnlySpan<byte> ParseSequence(
            ParseState state,
            ReadOnlySpan<byte> pos,
            string ruleName,
            List<LLamaGrammarElement> outElements,
            bool isNested)
        {
            // Index in outElements where the most recent symbol starts; repetition operators apply to it.
            int lastSymStart = outElements.Count;

            while (!pos.IsEmpty)
            {
                if (pos[0] == '"') // literal string
                {
                    pos = pos.Slice(1);
                    lastSymStart = outElements.Count;
                    while (!pos.IsEmpty && pos[0] != '"')
                    {
                        var charPair = ParseChar(ref pos);
                        outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair));
                    }

                    // Unterminated literal: report it rather than letting Slice(1) throw.
                    if (pos.IsEmpty)
                    {
                        throw new GrammarUnexpectedEndOfInput();
                    }

                    pos = ParseSpace(pos.Slice(1), isNested);
                }
                else if (pos[0] == '[') // char range(s)
                {
                    pos = pos.Slice(1);
                    var startType = LLamaGrammarElementType.CHAR;
                    if (!pos.IsEmpty && pos[0] == '^')
                    {
                        pos = pos.Slice(1);
                        startType = LLamaGrammarElementType.CHAR_NOT;
                    }

                    lastSymStart = outElements.Count;
                    while (!pos.IsEmpty && pos[0] != ']')
                    {
                        var charPair = ParseChar(ref pos);

                        // The first element of the class carries CHAR/CHAR_NOT; later ones are alternatives.
                        var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;
                        outElements.Add(new LLamaGrammarElement(type, charPair));

                        // "a-z" denotes a range; a '-' directly before ']' is a literal dash (handled next iteration).
                        if (pos.Length >= 2 && pos[0] == '-' && pos[1] != ']')
                        {
                            pos = pos.Slice(1);
                            var endCharPair = ParseChar(ref pos);
                            outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair));
                        }
                    }

                    // Unterminated char class.
                    if (pos.IsEmpty)
                    {
                        throw new GrammarUnexpectedEndOfInput();
                    }

                    pos = ParseSpace(pos.Slice(1), isNested);
                }
                else if (IsWordChar(pos[0])) // rule reference
                {
                    var nameEnd = ParseName(pos);
                    uint refRuleId = state.GetSymbolId(pos, nameEnd.Length);
                    pos = ParseSpace(nameEnd, isNested);
                    lastSymStart = outElements.Count;
                    outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
                }
                else if (pos[0] == '(') // grouping
                {
                    // parse nested alternates into synthesized rule
                    pos = ParseSpace(pos.Slice(1), true);
                    uint subRuleId = state.GenerateSymbolId(ruleName);
                    pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
                    lastSymStart = outElements.Count;

                    // output reference to synthesized rule
                    outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));

                    if (pos.IsEmpty || pos[0] != ')')
                    {
                        throw new GrammarExpectedNext(")", Encoding.UTF8.GetString(pos.ToArray()));
                    }

                    pos = ParseSpace(pos.Slice(1), isNested);
                }
                else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator
                {
                    if (lastSymStart == outElements.Count)
                    {
                        throw new GrammarExpectedPrevious("*/+/?", Encoding.UTF8.GetString(pos.ToArray()));
                    }

                    // apply transformation to previous symbol (lastSymStart to end) according to
                    // rewrite rules:
                    // S* --> S' ::= S S' |
                    // S+ --> S' ::= S S' | S
                    // S? --> S' ::= S |
                    uint subRuleId = state.GenerateSymbolId(ruleName);
                    var previousSymbol = outElements.GetRange(lastSymStart, outElements.Count - lastSymStart);
                    var subRule = new List<LLamaGrammarElement>();

                    // add preceding symbol to generated rule
                    subRule.AddRange(previousSymbol);
                    if (pos[0] == '*' || pos[0] == '+')
                    {
                        // cause generated rule to recurse
                        subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
                    }

                    // mark start of alternate def
                    subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
                    if (pos[0] == '+')
                    {
                        // add preceding symbol as alternate only for '+' (otherwise empty)
                        subRule.AddRange(previousSymbol);
                    }

                    subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
                    state.AddRule(subRuleId, subRule);

                    // in original rule, replace previous symbol with reference to generated rule
                    outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
                    outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));

                    pos = ParseSpace(pos.Slice(1), isNested);
                }
                else
                {
                    break;
                }
            }

            return pos;
        }

        /// <summary>
        /// Parses '|'-separated sequences into a single rule (ALT between alternatives, END at the end)
        /// and registers it under <paramref name="ruleId"/>.
        /// </summary>
        /// <returns>The input remaining after the alternates.</returns>
        private ReadOnlySpan<byte> ParseAlternates(
            ParseState state,
            ReadOnlySpan<byte> src,
            string ruleName,
            uint ruleId,
            bool isNested)
        {
            var rule = new List<LLamaGrammarElement>();
            ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested);
            while (!pos.IsEmpty && pos[0] == '|')
            {
                rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
                pos = ParseSpace(pos.Slice(1), true);
                pos = ParseSequence(state, pos, ruleName, rule, isNested);
            }

            rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
            state.AddRule(ruleId, rule);

            return pos;
        }

        /// <summary>
        /// Parses one <c>name ::= alternates</c> rule terminated by a newline or end of input.
        /// </summary>
        /// <returns>The input remaining after the rule and any following whitespace/comments.</returns>
        private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
        {
            ReadOnlySpan<byte> nameEnd = ParseName(src);
            ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
            int nameLen = src.Length - nameEnd.Length;
            uint ruleId = state.GetSymbolId(src.Slice(0, nameLen), 0);
            string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());

            // Bounds-check before the 3-byte lookahead so short input reports "::=" instead of crashing.
            if (!(pos.Length >= 3 && pos[0] == ':' && pos[1] == ':' && pos[2] == '='))
            {
                throw new GrammarExpectedNext("::=", Encoding.UTF8.GetString(pos.ToArray()));
            }

            pos = ParseSpace(pos.Slice(3), true);

            pos = ParseAlternates(state, pos, name, ruleId, false);

            // Accept "\r\n", a lone "\r" (including one at end of input), "\n", or end of input.
            if (!pos.IsEmpty && pos[0] == '\r')
            {
                pos = pos.Slice(pos.Length > 1 && pos[1] == '\n' ? 2 : 1);
            }
            else if (!pos.IsEmpty && pos[0] == '\n')
            {
                pos = pos.Slice(1);
            }
            else if (!pos.IsEmpty)
            {
                throw new GrammarExpectedNext("newline or EOF", Encoding.UTF8.GetString(pos.ToArray()));
            }

            return ParseSpace(pos, true);
        }

        /// <summary>
        /// Parse a string of GGML BNF
        /// </summary>
        /// <param name="input">The string to parse</param>
        /// <param name="startRule">The name of the root rule of this grammar</param>
        /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception>
        /// <exception cref="KeyNotFoundException">Thrown if <paramref name="startRule"/> does not appear in the grammar</exception>
        /// <returns>A <see cref="Grammar"/> built from the parsed rules, which can be used for sampling</returns>
        public Grammar Parse(string input, string startRule)
        {
            var byteArray = Encoding.UTF8.GetBytes(input);
            var state = new ParseState();
            var pos = ParseSpace(byteArray, true);
            while (!pos.IsEmpty)
            {
                pos = ParseRule(state, pos);
            }

            // Invert name -> id so each rule index can be labelled with its name.
            var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key);

            var rules = new List<GrammarRule>();
            for (var i = 0; i < state.Rules.Count; i++)
            {
                var elements = state.Rules[i];
                var name = names[(uint)i];
                rules.Add(new GrammarRule(name, elements));
            }

            if (!state.SymbolIds.TryGetValue(startRule, out uint startRuleIndex))
            {
                throw new KeyNotFoundException($"Start rule '{startRule}' was not found in the grammar.");
            }

            return new Grammar(rules, startRuleIndex);
        }

        /// <summary>
        /// Mutable state accumulated while parsing: the symbol table and the rules indexed by symbol id.
        /// </summary>
        private record ParseState
        {
            // Rule names are identifiers, so compare them ordinally rather than with culture-sensitive collation.
            public SortedDictionary<string, uint> SymbolIds { get; } = new(StringComparer.Ordinal);

            public List<List<LLamaGrammarElement>> Rules { get; } = new();

            /// <summary>
            /// Returns the id for the name formed by <paramref name="src"/> minus its last
            /// <paramref name="len"/> bytes, assigning the next free id if the name is new.
            /// </summary>
            public uint GetSymbolId(ReadOnlySpan<byte> src, int len)
            {
                var nextId = (uint)SymbolIds.Count;
                var key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
                if (SymbolIds.TryGetValue(key, out uint existingId))
                {
                    return existingId;
                }

                SymbolIds[key] = nextId;
                return nextId;
            }

            /// <summary>
            /// Allocates a fresh id for a synthesized rule, named <c>{baseName}_{id}</c>.
            /// </summary>
            public uint GenerateSymbolId(string baseName)
            {
                var nextId = (uint)SymbolIds.Count;
                var key = $"{baseName}_{nextId}";
                SymbolIds[key] = nextId;
                return nextId;
            }

            /// <summary>
            /// Stores <paramref name="rule"/> at index <paramref name="ruleId"/>, padding the rule list
            /// with empty rules if needed.
            /// </summary>
            public void AddRule(uint ruleId, List<LLamaGrammarElement> rule)
            {
                while (Rules.Count <= ruleId)
                {
                    Rules.Add(new List<LLamaGrammarElement>());
                }

                Rules[(int)ruleId] = rule;
            }
        }
    }
}