From a70c7170dd55b414cdcca333f326f18e321607c0 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 31 Aug 2023 00:02:50 +0100 Subject: [PATCH] - Created a higher level `Grammar` class which is immutable and contains a list of grammar rules. This is the main "entry point" to the grammar system. - Made all the mechanics of grammar parsing (GBNFGrammarParser, ParseState) internal. Just call `Grammar.Parse("whatever")`. - Added a `GrammarRule` class which validates elements on construction (this allows constructing grammar without parsing GBNF). - It should be impossible for a `GrammarRule` to represent an invalid rule. --- .../NewVersion/GrammarJsonResponse.cs | 12 +- LLama.Unittest/GrammarParserTest.cs | 77 +++----- LLama.Unittest/GrammarTest.cs | 13 +- LLama/AssemblyAttributes.cs | 3 + LLama/Extensions/IReadOnlyListExtensions.cs | 20 ++ LLama/GlobalSuppressions.cs | 2 + LLama/Grammar/ParseState.cs | 181 ------------------ .../GBNFGrammarParser.cs} | 67 ++++--- LLama/Grammars/Grammar.cs | 151 +++++++++++++++ LLama/Grammars/GrammarRule.cs | 75 ++++++++ LLama/Native/LLamaGrammarElement.cs | 55 +++++- LLama/Native/SafeLLamaGrammarHandle.cs | 7 +- 12 files changed, 385 insertions(+), 278 deletions(-) create mode 100644 LLama/AssemblyAttributes.cs create mode 100644 LLama/Extensions/IReadOnlyListExtensions.cs delete mode 100644 LLama/Grammar/ParseState.cs rename LLama/{Grammar/GrammarParser.cs => Grammars/GBNFGrammarParser.cs} (83%) create mode 100644 LLama/Grammars/Grammar.cs create mode 100644 LLama/Grammars/GrammarRule.cs diff --git a/LLama.Examples/NewVersion/GrammarJsonResponse.cs b/LLama.Examples/NewVersion/GrammarJsonResponse.cs index 926aa82b..a3c147f5 100644 --- a/LLama.Examples/NewVersion/GrammarJsonResponse.cs +++ b/LLama.Examples/NewVersion/GrammarJsonResponse.cs @@ -1,6 +1,5 @@ using LLama.Common; -using LLama.Grammar; -using LLama.Native; +using LLama.Grammars; namespace LLama.Examples.NewVersion { @@ -8,8 +7,8 @@ namespace LLama.Examples.NewVersion { public static void Run() { - var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim(); - var parsedGrammar = new GrammarParser(); + var gbnf = File.ReadAllText("Assets/json.gbnf").Trim(); + var grammar = Grammar.Parse(gbnf, "root"); Console.Write("Please input your model path: "); var modelPath = Console.ReadLine(); @@ -22,19 +21,18 @@ namespace LLama.Examples.NewVersion }; using var model = LLamaWeights.LoadFromFile(parameters); var ex = new StatelessExecutor(model, parameters); - ParseState state = parsedGrammar.Parse(grammarBytes); - using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\""); Console.ForegroundColor = ConsoleColor.White; + using var grammarInstance = grammar.CreateInstance(); var inferenceParams = new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "Question:", "#", "Question: ", ".\n" }, MaxTokens = 50, - Grammar = grammar + Grammar = grammarInstance }; while (true) diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs index 496f502c..6d3adb82 100644 --- a/LLama.Unittest/GrammarParserTest.cs +++ b/LLama.Unittest/GrammarParserTest.cs @@ -1,8 +1,5 @@ -using LLama.Common; -using LLama.Native; -using System.Diagnostics; -using LLama.Grammar; -using Newtonsoft.Json.Linq; +using LLama.Native; +using LLama.Grammars; namespace LLama.Unittest { @@ -17,14 +14,15 @@ namespace LLama.Unittest [Fact] public void ParseComplexGrammar() { - GrammarParser parsedGrammar = new GrammarParser(); + GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ expr ::= term ([-+*/] term)* term ::= [0-9]+"; - ParseState state = parsedGrammar.Parse(grammarBytes); + var state = parsedGrammar.Parse(grammarBytes, "root"); + Assert.Equal(0ul, state.StartRuleIndex); - List> expected = new List> + var expected = new List> { new KeyValuePair("expr", 2), new KeyValuePair("expr_5", 5), @@ -36,27 +34,11 @@ namespace LLama.Unittest new KeyValuePair("term_7", 7), }; - uint index = 0; - foreach (var it in state.SymbolIds) + foreach (var symbol in expected) { - string key = it.Key; - uint value = it.Value; - var expectedPair = expected[(int)index]; - - // pretty print error message before asserting - if (expectedPair.Key != key || expectedPair.Value != value) - { - Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); - Console.Error.WriteLine($"actualPair: {key}, {value}"); - Console.Error.WriteLine("expectedPair != actualPair"); - } - Assert.Equal(expectedPair.Key, key); - Assert.Equal(expectedPair.Value, value); - - index++; + var rule = state.Rules[(int)symbol.Value]; + Assert.Equal(symbol.Key, rule.Name); } - Assert.NotEmpty(state.SymbolIds); - var expectedRules = new List { @@ -96,13 +78,13 @@ namespace LLama.Unittest new LLamaGrammarElement(LLamaGrammarElementType.END, 0), }; - index = 0; + uint index = 0; foreach (var rule in state.Rules) { // compare rule to expected rule - for (uint i = 0; i < rule.Count; i++) + for (uint i = 0; i < rule.Elements.Count; i++) { - var element = rule[(int)i]; + var element = rule.Elements[(int)i]; var expectedElement = expectedRules[(int)index]; // Pretty print error message before asserting @@ -124,7 +106,7 @@ namespace LLama.Unittest [Fact] public void ParseExtraComplexGrammar() { - GrammarParser parsedGrammar = new GrammarParser(); + GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); string grammarBytes = @" root ::= (expr ""="" ws term ""\n"")+ expr ::= term ([-+*/] term)* @@ -134,9 +116,10 @@ namespace LLama.Unittest ws ::= [ \t\n]* "; - ParseState state = parsedGrammar.Parse(grammarBytes); + var state = parsedGrammar.Parse(grammarBytes, "root"); + Assert.Equal(0ul, state.StartRuleIndex); - List> expected = new List> + var expected = new List> { new KeyValuePair("expr", 2), new KeyValuePair("expr_6", 6), @@ -153,27 +136,11 @@ namespace LLama.Unittest new KeyValuePair("ws_12", 12), }; - uint index = 0; - foreach (var it in state.SymbolIds) + foreach (var symbol in expected) { - string key = it.Key; - uint value = it.Value; - var expectedPair = expected[(int)index]; - - // pretty print error message before asserting - if (expectedPair.Key != key || expectedPair.Value != value) - { - Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); - Console.Error.WriteLine($"actualPair: {key}, {value}"); - Console.Error.WriteLine("expectedPair != actualPair"); - } - Assert.Equal(expectedPair.Key, key); - Assert.Equal(expectedPair.Value, value); - - index++; + var rule = state.Rules[(int)symbol.Value]; + Assert.Equal(symbol.Key, rule.Name); } - Assert.NotEmpty(state.SymbolIds); - var expectedRules = new List { @@ -246,13 +213,13 @@ namespace LLama.Unittest new LLamaGrammarElement(LLamaGrammarElementType.END, 0) }; - index = 0; + uint index = 0; foreach (var rule in state.Rules) { // compare rule to expected rule - for (uint i = 0; i < rule.Count; i++) + for (uint i = 0; i < rule.Elements.Count; i++) { - var element = rule[(int)i]; + var element = rule.Elements[(int)i]; var expectedElement = expectedRules[(int)index]; // Pretty print error message before asserting diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs index 482268ea..dc2d3e3a 100644 --- a/LLama.Unittest/GrammarTest.cs +++ b/LLama.Unittest/GrammarTest.cs @@ -1,4 +1,5 @@ using LLama.Common; +using LLama.Grammars; using LLama.Native; namespace LLama.Unittest @@ -26,14 +27,14 @@ namespace LLama.Unittest [Fact] public void CreateBasicGrammar() { - var rules = new List> + var rules = new List { - new() + new GrammarRule("alpha", new[] { new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), new LLamaGrammarElement(LLamaGrammarElementType.END, 0), - }, + }), }; using var handle = SafeLLamaGrammarHandle.Create(rules, 0); @@ -44,15 +45,15 @@ namespace LLama.Unittest { // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so // we can be confident it's not what the LLM would say if not constrained by the grammar! - var rules = new List> + var rules = new List { - new() + new GrammarRule("feline", new [] { new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), new LLamaGrammarElement(LLamaGrammarElementType.END, 0), - }, + }), }; using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); diff --git a/LLama/AssemblyAttributes.cs b/LLama/AssemblyAttributes.cs new file mode 100644 index 00000000..dab58d12 --- /dev/null +++ b/LLama/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("LLama.Unittest")] \ No newline at end of file diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs new file mode 100644 index 00000000..51b365be --- /dev/null +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class IReadOnlyListExtensions + { + public static int? IndexOf(this IReadOnlyList list, T item) + where T : IEquatable + { + for (var i = 0; i < list.Count; i++) + { + if (list[i].Equals(item)) + return i; + } + + return null; + } + } +} diff --git a/LLama/GlobalSuppressions.cs b/LLama/GlobalSuppressions.cs index 4d4915ff..2053bc25 100644 --- a/LLama/GlobalSuppressions.cs +++ b/LLama/GlobalSuppressions.cs @@ -6,3 +6,5 @@ using System.Diagnostics.CodeAnalysis; [assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] + +[assembly: SuppressMessage("Style", "IDE0070:Use 'System.HashCode'", Justification = "Not compatible with netstandard2.0")] diff --git a/LLama/Grammar/ParseState.cs b/LLama/Grammar/ParseState.cs deleted file mode 100644 index 0c75a8a0..00000000 --- a/LLama/Grammar/ParseState.cs +++ /dev/null @@ -1,181 +0,0 @@ -using LLama.Exceptions; -using LLama.Native; -using System; -using System.Collections.Generic; -using System.IO; - -namespace LLama.Grammar -{ - /// - /// Source: - /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h - /// - /// The commit hash from URL is the actual commit hash that reflects current C# code. - /// - public class ParseState - { - public SortedDictionary SymbolIds { get; } = new SortedDictionary(); - public List> Rules { get; } = new List>(); - - public IEnumerable> CRules() - { - foreach (var rule in Rules) - { - yield return rule; - } - } - - public void PrintGrammar(StreamWriter file, ParseState state) - { - try - { - Dictionary symbolIdNames = new Dictionary(); - foreach (var kv in state.SymbolIds) - { - symbolIdNames[kv.Value] = kv.Key; - } - for (int i = 0, end = state.Rules.Count; i < end; i++) - { - PrintRule(file, (uint)i, state.Rules[i], symbolIdNames); - } - } - catch(Exception err) - { - Console.Error.WriteLine($"\nError printing grammar: {err.Message}"); - } - } - - public void PrintRuleBinary(StreamWriter file, List rule) - { - foreach (var elem in rule) - { - switch (elem.Type) - { - case LLamaGrammarElementType.END: file.Write("END"); break; - case LLamaGrammarElementType.ALT: file.Write("ALT"); break; - case LLamaGrammarElementType.RULE_REF: file.Write("RULE_REF"); break; - case LLamaGrammarElementType.CHAR: file.Write("CHAR"); break; - case LLamaGrammarElementType.CHAR_NOT: file.Write("CHAR_NOT"); break; - case LLamaGrammarElementType.CHAR_RNG_UPPER: file.Write("CHAR_RNG_UPPER"); break; - case LLamaGrammarElementType.CHAR_ALT: file.Write("CHAR_ALT"); break; - } - switch (elem.Type) - { - case LLamaGrammarElementType.END: - case LLamaGrammarElementType.ALT: - case LLamaGrammarElementType.RULE_REF: - file.Write($"({elem.Value}) "); - break; - case LLamaGrammarElementType.CHAR: - case LLamaGrammarElementType.CHAR_NOT: - case LLamaGrammarElementType.CHAR_RNG_UPPER: - case LLamaGrammarElementType.CHAR_ALT: - file.Write("(\""); - PrintGrammarChar(file, elem.Value); - file.Write("\") "); - break; - } - } - file.WriteLine(); - } - - private void PrintRule( - StreamWriter file, - uint ruleId, - List rule, - Dictionary symbolIdNames) - { - if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END) - { - throw new GrammarFormatException( - $"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}"); - } - - file.Write($"{symbolIdNames[ruleId]} ::= "); - - for (int i = 0, end = rule.Count - 1; i < end; i++) - { - var elem = rule[i]; - switch (elem.Type) - { - case LLamaGrammarElementType.END: - throw new GrammarFormatException( - $"Unexpected end of rule: {ruleId}, {i}"); - case LLamaGrammarElementType.ALT: - file.Write("| "); - break; - case LLamaGrammarElementType.RULE_REF: - file.Write($"{symbolIdNames[elem.Value]} "); - break; - case LLamaGrammarElementType.CHAR: - file.Write("["); - PrintGrammarChar(file, elem.Value); - break; - case LLamaGrammarElementType.CHAR_NOT: - file.Write("[^"); - PrintGrammarChar(file, elem.Value); - break; - case LLamaGrammarElementType.CHAR_RNG_UPPER: - if (i == 0 || !IsCharElement(rule[i - 1])) - { - throw new GrammarFormatException( - $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}"); - } - file.Write("-"); - PrintGrammarChar(file, elem.Value); - break; - case LLamaGrammarElementType.CHAR_ALT: - if (i == 0 || !IsCharElement(rule[i - 1])) - { - throw new GrammarFormatException( - $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}"); - } - PrintGrammarChar(file, elem.Value); - break; - - } - - if (IsCharElement(elem)) - { - switch (rule[i + 1].Type) - { - case LLamaGrammarElementType.CHAR_ALT: - case LLamaGrammarElementType.CHAR_RNG_UPPER: - break; - default: - file.Write("] "); - break; - } - } - } - file.WriteLine(); - } - - private void PrintGrammarChar(StreamWriter file, uint c) - { - if (c >= 0x20 && c <= 0x7F) - { - file.Write((char)c); - } - else - { - // cop out of encoding UTF-8 - file.Write($""); - } - } - - private bool IsCharElement(LLamaGrammarElement elem) - { - switch (elem.Type) - { - case LLamaGrammarElementType.CHAR: - case LLamaGrammarElementType.CHAR_NOT: - case LLamaGrammarElementType.CHAR_ALT: - case LLamaGrammarElementType.CHAR_RNG_UPPER: - return true; - default: - return false; - } - } - } -} diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammars/GBNFGrammarParser.cs similarity index 83% rename from LLama/Grammar/GrammarParser.cs rename to LLama/Grammars/GBNFGrammarParser.cs index 4122e58f..aec58c7a 100644 --- a/LLama/Grammar/GrammarParser.cs +++ b/LLama/Grammars/GBNFGrammarParser.cs @@ -1,10 +1,11 @@ -using LLama.Exceptions; -using LLama.Native; -using System; +using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using LLama.Exceptions; +using LLama.Native; -namespace LLama.Grammar +namespace LLama.Grammars { /// /// Source: @@ -12,7 +13,7 @@ namespace LLama.Grammar /// /// The commit hash from URL is the actual commit hash that reflects current C# code. /// - public class GrammarParser + internal sealed class GBNFGrammarParser { // NOTE: assumes valid utf8 (but checks for overrun) // copied from llama.cpp @@ -206,7 +207,7 @@ namespace LLama.Grammar while (!pos.IsEmpty && pos[0] != '"') { var charPair = ParseChar(ref pos); - outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair)); } pos = ParseSpace(pos.Slice(1), isNested); } @@ -228,13 +229,13 @@ namespace LLama.Grammar var charPair = ParseChar(ref pos); var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; - outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair }); + outElements.Add(new LLamaGrammarElement(type, charPair)); if (pos[0] == '-' && pos[1] != ']') { pos = pos.Slice(1); var endCharPair = ParseChar(ref pos); - outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair }); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair)); } } pos = ParseSpace(pos.Slice(1), isNested); @@ -245,7 +246,7 @@ namespace LLama.Grammar uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); pos = ParseSpace(nameEnd, isNested); lastSymStart = outElements.Count; - outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId }); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId)); } else if (pos[0] == '(') // grouping { @@ -255,7 +256,7 @@ namespace LLama.Grammar pos = ParseAlternates(state, pos, ruleName, subRuleId, true); lastSymStart = outElements.Count; // output reference to synthesized rule - outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); if (pos[0] != ')') { throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}"); @@ -284,11 +285,11 @@ namespace LLama.Grammar if (pos[0] == '*' || pos[0] == '+') { // cause generated rule to recurse - subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); } // mark start of alternate def - subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); + subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); if (pos[0] == '+') { @@ -296,13 +297,13 @@ namespace LLama.Grammar subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); } - subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); + subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); AddRule(state, subRuleId, subRule); // in original rule, replace previous symbol with reference to generated rule outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); - outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); pos = ParseSpace(pos.Slice(1), isNested); @@ -328,12 +329,12 @@ namespace LLama.Grammar while (!pos.IsEmpty && pos[0] == '|') { - rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); + rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); pos = ParseSpace(pos.Slice(1), true); pos = ParseSequence(state, pos, ruleName, rule, isNested); } - rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); + rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); AddRule(state, ruleId, rule); return pos; @@ -370,19 +371,41 @@ namespace LLama.Grammar return ParseSpace(pos, true); } - public ParseState Parse(string input) + /// + /// Parse a string of GGML BNF + /// + /// The string to parse + /// The name of the root rule of this grammar + /// Thrown if input is malformed + /// A ParseState that can be converted into a grammar for sampling + public Grammar Parse(string input, string startRule) { - byte[] byteArray = Encoding.UTF8.GetBytes(input); - ReadOnlySpan src = new ReadOnlySpan(byteArray); - ParseState state = new ParseState(); - ReadOnlySpan pos = ParseSpace(src, true); + var byteArray = Encoding.UTF8.GetBytes(input); + var state = new ParseState(); + var pos = ParseSpace(byteArray, true); while (!pos.IsEmpty) { pos = ParseRule(state, pos); } - return state; + var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key); + var rules = new List(); + for (var i = 0; i < state.Rules.Count; i++) + { + var elements = state.Rules[i]; + var name = names[(uint)i]; + rules.Add(new GrammarRule(name, elements)); + } + + var startRuleIndex = state.SymbolIds[startRule]; + return new Grammar(rules, startRuleIndex); + } + + private record ParseState + { + public SortedDictionary SymbolIds { get; } = new(); + public List> Rules { get; } = new(); } } } diff --git a/LLama/Grammars/Grammar.cs b/LLama/Grammars/Grammar.cs new file mode 100644 index 00000000..f8bd9052 --- /dev/null +++ b/LLama/Grammars/Grammar.cs @@ -0,0 +1,151 @@ +using System; +using System.Collections.Generic; +using System.Text; +using LLama.Exceptions; +using LLama.Native; + +namespace LLama.Grammars +{ + /// + /// A grammar is a set of s for deciding which characters are valid next. Can be used to constrain + /// output to certain formats - e.g. force the model to output JSON + /// + public sealed class Grammar + { + /// + /// Index of the initial rule to start from + /// + public ulong StartRuleIndex { get; set; } + + /// + /// The rules which make up this grammar + /// + public IReadOnlyList Rules { get; } + + /// + /// Create a new grammar from a set of rules + /// + /// The rules which make up this grammar + /// Index of the initial rule to start from + /// + public Grammar(IReadOnlyList rules, ulong startRuleIndex) + { + if (startRuleIndex >= (uint)rules.Count) + throw new ArgumentOutOfRangeException(nameof(startRuleIndex), "startRule must be less than the number of rules"); + + StartRuleIndex = startRuleIndex; + Rules = rules; + } + + /// + /// Create a `SafeLLamaGrammarHandle` instance to use for parsing + /// + /// + public SafeLLamaGrammarHandle CreateInstance() + { + return SafeLLamaGrammarHandle.Create(Rules, StartRuleIndex); + } + + /// + /// Parse a string of GGML BNF into a Grammar + /// + /// The string to parse + /// Name of the start rule of this grammar + /// Thrown if input is malformed + /// A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling + public static Grammar Parse(string gbnf, string startRule) + { + var parser = new GBNFGrammarParser(); + return parser.Parse(gbnf, startRule); + } + + /// + public override string ToString() + { + var builder = new StringBuilder(); + PrintGrammar(builder); + return builder.ToString(); + } + + private void PrintGrammar(StringBuilder output) + { + for (var i = 0; i < Rules.Count; i++) + PrintRule(output, (uint)i, Rules[i]); + } + + private void PrintRule(StringBuilder output, uint ruleId, GrammarRule rule) + { + output.Append($"{rule.Name} ::= "); + + for (int i = 0, end = rule.Elements.Count - 1; i < end; i++) + { + var elem = rule.Elements[i]; + switch (elem.Type) + { + case LLamaGrammarElementType.END: + throw new GrammarFormatException($"Unexpected end of rule: {ruleId}, {i}"); + case LLamaGrammarElementType.ALT: + output.Append("| "); + break; + case LLamaGrammarElementType.RULE_REF: + output.Append($"{Rules[(int)elem.Value].Name} "); + break; + case LLamaGrammarElementType.CHAR: + output.Append('['); + PrintGrammarChar(output, elem.Value); + break; + case LLamaGrammarElementType.CHAR_NOT: + output.Append("[^"); + PrintGrammarChar(output, elem.Value); + break; + case LLamaGrammarElementType.CHAR_RNG_UPPER: + if (i == 0 || !rule.Elements[i - 1].IsCharElement()) + { + throw new GrammarFormatException( + $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}"); + } + output.Append('-'); + PrintGrammarChar(output, elem.Value); + break; + case LLamaGrammarElementType.CHAR_ALT: + if (i == 0 || !rule.Elements[i - 1].IsCharElement()) + { + throw new GrammarFormatException( + $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}"); + } + PrintGrammarChar(output, elem.Value); + break; + + } + + if (elem.IsCharElement()) + { + switch (rule.Elements[i + 1].Type) + { + case LLamaGrammarElementType.CHAR_ALT: + case LLamaGrammarElementType.CHAR_RNG_UPPER: + break; + default: + output.Append("] "); + break; + } + } + } + + output.AppendLine(); + } + + private static void PrintGrammarChar(StringBuilder output, uint c) + { + if (c >= 0x20 && c <= 0x7F) + { + output.Append((char)c); + } + else + { + // cop out of encoding UTF-8 + output.Append($""); + } + } + } +} diff --git a/LLama/Grammars/GrammarRule.cs b/LLama/Grammars/GrammarRule.cs new file mode 100644 index 00000000..beab1078 --- /dev/null +++ b/LLama/Grammars/GrammarRule.cs @@ -0,0 +1,75 @@ +using System; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Grammars +{ + /// + /// A single rule in a + /// + public sealed record GrammarRule + { + /// + /// Name of this rule + /// + public string Name { get; } + + /// + /// The elements of this grammar rule + /// + public IReadOnlyList Elements { get; } + + /// + /// Create a new GrammarRule containing the given elements + /// + /// + /// + /// + public GrammarRule(string name, IReadOnlyList elements) + { + Validate(elements, name); + + Name = name; + Elements = elements; + } + + private static void Validate(IReadOnlyList elements, string name) + { + if (elements.Count == 0) + throw new ArgumentException("Cannot create a GrammaRule with zero elements", nameof(elements)); + if (elements[elements.Count - 1].Type != LLamaGrammarElementType.END) + throw new ArgumentException("Last grammar element must be END", nameof(elements)); + + for (var i = 0; i < elements.Count; i++) + { + switch (elements[i].Type) + { + case LLamaGrammarElementType.END: + if (i != elements.Count - 1) + throw new ArgumentException("Found more than one END grammar element", nameof(elements)); + continue; + + case LLamaGrammarElementType.CHAR_RNG_UPPER: + if (i == 0 || !elements[i - 1].IsCharElement()) + throw new ArgumentException($"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {name},{i}", nameof(elements)); + break; + case LLamaGrammarElementType.CHAR_ALT: + if (i == 0 || !elements[i - 1].IsCharElement()) + { + throw new ArgumentException($"LLamaGrammarElementType.CHAR_ALT without preceding char: {name},{i}", nameof(elements)); + } + break; + + case LLamaGrammarElementType.ALT: + case LLamaGrammarElementType.RULE_REF: + case LLamaGrammarElementType.CHAR: + case LLamaGrammarElementType.CHAR_NOT: + break; + + default: + throw new ArgumentException($"Unknown grammar element type: '{elements[i].Type}'"); + } + } + } + } +} diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs index 7c321c5d..688f5ccb 100644 --- a/LLama/Native/LLamaGrammarElement.cs +++ b/LLama/Native/LLamaGrammarElement.cs @@ -1,4 +1,5 @@ -using System.Diagnostics; +using System; +using System.Diagnostics; using System.Runtime.InteropServices; namespace LLama.Native @@ -51,17 +52,18 @@ namespace LLama.Native /// [StructLayout(LayoutKind.Sequential)] [DebuggerDisplay("{Type} {Value}")] - public struct LLamaGrammarElement + public readonly struct LLamaGrammarElement + : IEquatable { /// /// The type of this element /// - public LLamaGrammarElementType Type; + public readonly LLamaGrammarElementType Type; /// /// Unicode code point or rule ID /// - public uint Value; + public readonly uint Value; /// /// Construct a new LLamaGrammarElement @@ -73,5 +75,50 @@ namespace LLama.Native Type = type; Value = value; } + + /// + public bool Equals(LLamaGrammarElement other) + { + if (Type != other.Type) + return false; + + // No need to compare values for the END rule + if (Type == LLamaGrammarElementType.END) + return true; + + return Value == other.Value; + } + + /// + public override bool Equals(object? obj) + { + return obj is LLamaGrammarElement other && Equals(other); + } + + /// + public override int GetHashCode() + { + unchecked + { + var hash = 2999; + hash = hash * 7723 + (int)Type; + hash = hash * 7723 + (int)Value; + return hash; + } + } + + internal bool IsCharElement() + { + switch (Type) + { + case LLamaGrammarElementType.CHAR: + case LLamaGrammarElementType.CHAR_NOT: + case LLamaGrammarElementType.CHAR_ALT: + case LLamaGrammarElementType.CHAR_RNG_UPPER: + return true; + default: + return false; + } + } } } diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs index 0b4eda9d..ed1c15c8 100644 --- a/LLama/Native/SafeLLamaGrammarHandle.cs +++ b/LLama/Native/SafeLLamaGrammarHandle.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using LLama.Exceptions; +using LLama.Grammars; namespace LLama.Native { @@ -38,11 +39,11 @@ namespace LLama.Native /// The index (in the outer list) of the start rule /// /// - public static SafeLLamaGrammarHandle Create(IReadOnlyList> rules, ulong start_rule_index) + public static SafeLLamaGrammarHandle Create(IReadOnlyList rules, ulong start_rule_index) { unsafe { - var totalElements = rules.Sum(a => a.Count); + var totalElements = rules.Sum(a => a.Elements.Count); var nrules = (ulong)rules.Count; // Borrow an array large enough to hold every single element @@ -61,7 +62,7 @@ namespace LLama.Native pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex); // Copy all of the rule elements into the flat array - foreach (var element in rule) + foreach (var element in rule.Elements) allElementsPtr[elementIndex++] = element; }