- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using LLama.Exceptions;
- using LLama.Native;
-
- namespace LLama.Grammars
- {
- /// <summary>
- /// Source:
- /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp
- ///
- /// The commit hash from URL is the actual commit hash that reflects current C# code.
- /// </summary>
- internal sealed class GBNFGrammarParser
- {
- // NOTE: assumes valid utf8 (but checks for overrun)
- // copied from llama.cpp
- private static uint DecodeUTF8(ref ReadOnlySpan<byte> src)
- {
- int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-
- byte firstByte = src[0];
- byte highbits = (byte)(firstByte >> 4);
- int len = lookup[highbits];
- byte mask = (byte)((1 << (8 - len)) - 1);
- uint value = (uint)(firstByte & mask);
-
- int end = len;
- int pos = 1;
-
- for (; pos < end && pos < src.Length; pos++)
- {
- value = (uint)((value << 6) + (src[pos] & 0x3F));
- }
-
- src = src.Slice(pos);
-
- return value;
- }
-
- private static bool IsWordChar(byte c)
- {
- return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
- }
-
- private static uint ParseHex(ref ReadOnlySpan<byte> src, int size)
- {
- int pos = 0;
- int end = size;
- uint value = 0;
-
- for (; pos < end && pos < src.Length; pos++)
- {
- value <<= 4;
- byte c = src[pos];
- if ('a' <= c && c <= 'f')
- {
- value += (uint)(c - 'a' + 10);
- }
- else if ('A' <= c && c <= 'F')
- {
- value += (uint)(c - 'A' + 10);
- }
- else if ('0' <= c && c <= '9')
- {
- value += (uint)(c - '0');
- }
- else
- {
- break;
- }
- }
-
- if (pos != end)
- {
- throw new GrammarUnexpectedHexCharsCount(size, Encoding.UTF8.GetString(src.ToArray()));
- }
- src = src.Slice(pos);
- return value;
- }
-
- private static ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
- {
- int pos = 0;
- while (pos < src.Length &&
- (src[pos] == ' ' || src[pos] == '\t' || src[pos] == '#' ||
- (newlineOk && (src[pos] == '\r' || src[pos] == '\n'))))
- {
- if (src[pos] == '#')
- {
- while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n')
- {
- pos++;
- }
- }
- else
- {
- pos++;
- }
- }
- return src.Slice(pos);
- }
-
- private static ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
- {
- int pos = 0;
- while (pos < src.Length && IsWordChar(src[pos]))
- {
- pos++;
- }
- if (pos == 0)
- {
- throw new GrammarExpectedName(Encoding.UTF8.GetString(src.ToArray()));
- }
- return src.Slice(pos);
- }
-
- private static uint ParseChar(ref ReadOnlySpan<byte> src)
- {
- if (src[0] == '\\')
- {
- if (src.Length < 2)
- throw new GrammarUnexpectedEndOfInput();
-
- var chr = src[1];
- src = src.Slice(2);
-
- return (char)chr switch
- {
- 'x' => ParseHex(ref src, 2),
- 'u' => ParseHex(ref src, 4),
- 'U' => ParseHex(ref src, 8),
- 't' => '\t',
- 'r' => '\r',
- 'n' => '\n',
- '\\' or '"' or '[' or ']' => chr,
- _ => throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())),
- };
- }
-
- if (!src.IsEmpty)
- return DecodeUTF8(ref src);
-
- throw new GrammarUnexpectedEndOfInput();
- }
-
- private ReadOnlySpan<byte> ParseSequence(
- ParseState state,
- ReadOnlySpan<byte> pos,
- string ruleName,
- List<LLamaGrammarElement> outElements,
- bool isNested)
- {
- int lastSymStart = outElements.Count;
-
- while (!pos.IsEmpty)
- {
- if (pos[0] == '"') // literal string
- {
- pos = pos.Slice(1);
- lastSymStart = outElements.Count;
-
- while (!pos.IsEmpty && pos[0] != '"')
- {
- var charPair = ParseChar(ref pos);
- outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair));
- }
- pos = ParseSpace(pos.Slice(1), isNested);
- }
- else if (pos[0] == '[') // char range(s)
- {
- pos = pos.Slice(1);
- var startType = LLamaGrammarElementType.CHAR;
-
- if (pos[0] == '^')
- {
- pos = pos.Slice(1);
- startType = LLamaGrammarElementType.CHAR_NOT;
- }
-
- lastSymStart = outElements.Count;
-
- while (!pos.IsEmpty && pos[0] != ']')
- {
- var charPair = ParseChar(ref pos);
- var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;
-
- outElements.Add(new LLamaGrammarElement(type, charPair));
-
- if (pos[0] == '-' && pos[1] != ']')
- {
- pos = pos.Slice(1);
- var endCharPair = ParseChar(ref pos);
- outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair));
- }
- }
- pos = ParseSpace(pos.Slice(1), isNested);
- }
- else if (IsWordChar(pos[0])) // rule reference
- {
- var nameEnd = ParseName(pos);
- uint refRuleId = state.GetSymbolId(pos, nameEnd.Length);
- pos = ParseSpace(nameEnd, isNested);
- lastSymStart = outElements.Count;
- outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
- }
- else if (pos[0] == '(') // grouping
- {
- // parse nested alternates into synthesized rule
- pos = ParseSpace(pos.Slice(1), true);
- uint subRuleId = state.GenerateSymbolId(ruleName);
- pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
- lastSymStart = outElements.Count;
- // output reference to synthesized rule
- outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
- if (pos[0] != ')')
- throw new GrammarExpectedNext(")", Encoding.UTF8.GetString(pos.ToArray()));
- pos = ParseSpace(pos.Slice(1), isNested);
- }
- else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator
- {
- if (lastSymStart == outElements.Count)
- throw new GrammarExpectedPrevious("*/+/?", Encoding.UTF8.GetString(pos.ToArray()));
-
- // apply transformation to previous symbol (lastSymStart to end) according to
- // rewrite rules:
- // S* --> S' ::= S S' |
- // S+ --> S' ::= S S' | S
- // S? --> S' ::= S |
- uint subRuleId = state.GenerateSymbolId(ruleName);
-
- List<LLamaGrammarElement> subRule = new List<LLamaGrammarElement>();
-
- // add preceding symbol to generated rule
- subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart));
-
- if (pos[0] == '*' || pos[0] == '+')
- {
- // cause generated rule to recurse
- subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
- }
-
- // mark start of alternate def
- subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
-
- if (pos[0] == '+')
- {
- // add preceding symbol as alternate only for '+' (otherwise empty)
- subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart));
- }
-
- subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
-
- state.AddRule(subRuleId, subRule);
-
- // in original rule, replace previous symbol with reference to generated rule
- outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
- outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
-
- pos = ParseSpace(pos.Slice(1), isNested);
-
- }
- else
- {
- break;
- }
- }
-
- return pos;
- }
-
- private ReadOnlySpan<byte> ParseAlternates(
- ParseState state,
- ReadOnlySpan<byte> src,
- string ruleName,
- uint ruleId,
- bool isNested)
- {
- var rule = new List<LLamaGrammarElement>();
- ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested);
-
- while (!pos.IsEmpty && pos[0] == '|')
- {
- rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
- pos = ParseSpace(pos.Slice(1), true);
- pos = ParseSequence(state, pos, ruleName, rule, isNested);
- }
-
- rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
- state.AddRule(ruleId, rule);
-
- return pos;
- }
-
- private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
- {
- ReadOnlySpan<byte> nameEnd = ParseName(src);
- ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
- int nameLen = src.Length - nameEnd.Length;
- uint ruleId = state.GetSymbolId(src.Slice(0, nameLen), 0);
- string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());
-
- if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '='))
- throw new GrammarExpectedNext("::=", Encoding.UTF8.GetString(pos.ToArray()));
-
- pos = ParseSpace(pos.Slice(3), true);
-
- pos = ParseAlternates(state, pos, name, ruleId, false);
-
- if (!pos.IsEmpty && pos[0] == '\r')
- {
- pos = pos.Slice(pos[1] == '\n' ? 2 : 1);
- }
- else if (!pos.IsEmpty && pos[0] == '\n')
- {
- pos = pos.Slice(1);
- }
- else if (!pos.IsEmpty)
- {
- throw new GrammarExpectedNext("newline or EOF", Encoding.UTF8.GetString(pos.ToArray()));
- }
- return ParseSpace(pos, true);
- }
-
- /// <summary>
- /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a>
- /// </summary>
- /// <param name="input">The string to parse</param>
- /// <param name="startRule">The name of the root rule of this grammar</param>
- /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception>
- /// <returns>A ParseState that can be converted into a grammar for sampling</returns>
- public Grammar Parse(string input, string startRule)
- {
- var byteArray = Encoding.UTF8.GetBytes(input);
- var state = new ParseState();
- var pos = ParseSpace(byteArray, true);
-
- while (!pos.IsEmpty)
- {
- pos = ParseRule(state, pos);
- }
-
- var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key);
- var rules = new List<GrammarRule>();
- for (var i = 0; i < state.Rules.Count; i++)
- {
- var elements = state.Rules[i];
- var name = names[(uint)i];
- rules.Add(new GrammarRule(name, elements));
- }
-
- var startRuleIndex = state.SymbolIds[startRule];
- return new Grammar(rules, startRuleIndex);
- }
-
- private record ParseState
- {
- public SortedDictionary<string, uint> SymbolIds { get; } = new();
- public List<List<LLamaGrammarElement>> Rules { get; } = new();
-
- public uint GetSymbolId(ReadOnlySpan<byte> src, int len)
- {
- var nextId = (uint)SymbolIds.Count;
- var key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
-
- if (SymbolIds.TryGetValue(key, out uint existingId))
- {
- return existingId;
- }
- else
- {
- SymbolIds[key] = nextId;
- return nextId;
- }
- }
-
- public uint GenerateSymbolId(string baseName)
- {
- var nextId = (uint)SymbolIds.Count;
- var key = $"{baseName}_{nextId}";
- SymbolIds[key] = nextId;
- return nextId;
- }
-
- public void AddRule(uint ruleId, List<LLamaGrammarElement> rule)
- {
- while (Rules.Count <= ruleId)
- {
- Rules.Add(new List<LLamaGrammarElement>());
- }
-
- Rules[(int)ruleId] = rule;
- }
- }
- }
- }
|