- Made all the mechanics of grammar parsing (GBNFGrammarParser, ParseState) internal. Just call `Grammar.Parse("whatever")`.
- Added a `GrammarRule` class which validates elements on construction (this allows constructing grammar without parsing GBNF).
- It should be impossible for a `GrammarRule` to represent an invalid rule.
tags/v0.5.1
| @@ -1,6 +1,5 @@ | |||
| using LLama.Common; | |||
| using LLama.Grammar; | |||
| using LLama.Native; | |||
| using LLama.Grammars; | |||
| namespace LLama.Examples.NewVersion | |||
| { | |||
| @@ -8,8 +7,8 @@ namespace LLama.Examples.NewVersion | |||
| { | |||
| public static void Run() | |||
| { | |||
| var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim(); | |||
| var parsedGrammar = new GrammarParser(); | |||
| var gbnf = File.ReadAllText("Assets/json.gbnf").Trim(); | |||
| var grammar = Grammar.Parse(gbnf, "root"); | |||
| Console.Write("Please input your model path: "); | |||
| var modelPath = Console.ReadLine(); | |||
| @@ -22,19 +21,18 @@ namespace LLama.Examples.NewVersion | |||
| }; | |||
| using var model = LLamaWeights.LoadFromFile(parameters); | |||
| var ex = new StatelessExecutor(model, parameters); | |||
| ParseState state = parsedGrammar.Parse(grammarBytes); | |||
| using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0); | |||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||
| Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\""); | |||
| Console.ForegroundColor = ConsoleColor.White; | |||
| using var grammarInstance = grammar.CreateInstance(); | |||
| var inferenceParams = new InferenceParams() | |||
| { | |||
| Temperature = 0.6f, | |||
| AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" }, | |||
| MaxTokens = 50, | |||
| Grammar = grammar | |||
| Grammar = grammarInstance | |||
| }; | |||
| while (true) | |||
| @@ -1,8 +1,5 @@ | |||
| using LLama.Common; | |||
| using LLama.Native; | |||
| using System.Diagnostics; | |||
| using LLama.Grammar; | |||
| using Newtonsoft.Json.Linq; | |||
| using LLama.Native; | |||
| using LLama.Grammars; | |||
| namespace LLama.Unittest | |||
| { | |||
| @@ -17,14 +14,15 @@ namespace LLama.Unittest | |||
| [Fact] | |||
| public void ParseComplexGrammar() | |||
| { | |||
| GrammarParser parsedGrammar = new GrammarParser(); | |||
| GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); | |||
| string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ | |||
| expr ::= term ([-+*/] term)* | |||
| term ::= [0-9]+"; | |||
| ParseState state = parsedGrammar.Parse(grammarBytes); | |||
| var state = parsedGrammar.Parse(grammarBytes, "root"); | |||
| Assert.Equal(0ul, state.StartRuleIndex); | |||
| List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>> | |||
| var expected = new List<KeyValuePair<string, uint>> | |||
| { | |||
| new KeyValuePair<string, uint>("expr", 2), | |||
| new KeyValuePair<string, uint>("expr_5", 5), | |||
| @@ -36,27 +34,11 @@ namespace LLama.Unittest | |||
| new KeyValuePair<string, uint>("term_7", 7), | |||
| }; | |||
| uint index = 0; | |||
| foreach (var it in state.SymbolIds) | |||
| foreach (var symbol in expected) | |||
| { | |||
| string key = it.Key; | |||
| uint value = it.Value; | |||
| var expectedPair = expected[(int)index]; | |||
| // pretty print error message before asserting | |||
| if (expectedPair.Key != key || expectedPair.Value != value) | |||
| { | |||
| Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); | |||
| Console.Error.WriteLine($"actualPair: {key}, {value}"); | |||
| Console.Error.WriteLine("expectedPair != actualPair"); | |||
| } | |||
| Assert.Equal(expectedPair.Key, key); | |||
| Assert.Equal(expectedPair.Value, value); | |||
| index++; | |||
| var rule = state.Rules[(int)symbol.Value]; | |||
| Assert.Equal(symbol.Key, rule.Name); | |||
| } | |||
| Assert.NotEmpty(state.SymbolIds); | |||
| var expectedRules = new List<LLamaGrammarElement> | |||
| { | |||
| @@ -96,13 +78,13 @@ namespace LLama.Unittest | |||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||
| }; | |||
| index = 0; | |||
| uint index = 0; | |||
| foreach (var rule in state.Rules) | |||
| { | |||
| // compare rule to expected rule | |||
| for (uint i = 0; i < rule.Count; i++) | |||
| for (uint i = 0; i < rule.Elements.Count; i++) | |||
| { | |||
| var element = rule[(int)i]; | |||
| var element = rule.Elements[(int)i]; | |||
| var expectedElement = expectedRules[(int)index]; | |||
| // Pretty print error message before asserting | |||
| @@ -124,7 +106,7 @@ namespace LLama.Unittest | |||
| [Fact] | |||
| public void ParseExtraComplexGrammar() | |||
| { | |||
| GrammarParser parsedGrammar = new GrammarParser(); | |||
| GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); | |||
| string grammarBytes = @" | |||
| root ::= (expr ""="" ws term ""\n"")+ | |||
| expr ::= term ([-+*/] term)* | |||
| @@ -134,9 +116,10 @@ namespace LLama.Unittest | |||
| ws ::= [ \t\n]* | |||
| "; | |||
| ParseState state = parsedGrammar.Parse(grammarBytes); | |||
| var state = parsedGrammar.Parse(grammarBytes, "root"); | |||
| Assert.Equal(0ul, state.StartRuleIndex); | |||
| List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>> | |||
| var expected = new List<KeyValuePair<string, uint>> | |||
| { | |||
| new KeyValuePair<string, uint>("expr", 2), | |||
| new KeyValuePair<string, uint>("expr_6", 6), | |||
| @@ -153,27 +136,11 @@ namespace LLama.Unittest | |||
| new KeyValuePair<string, uint>("ws_12", 12), | |||
| }; | |||
| uint index = 0; | |||
| foreach (var it in state.SymbolIds) | |||
| foreach (var symbol in expected) | |||
| { | |||
| string key = it.Key; | |||
| uint value = it.Value; | |||
| var expectedPair = expected[(int)index]; | |||
| // pretty print error message before asserting | |||
| if (expectedPair.Key != key || expectedPair.Value != value) | |||
| { | |||
| Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); | |||
| Console.Error.WriteLine($"actualPair: {key}, {value}"); | |||
| Console.Error.WriteLine("expectedPair != actualPair"); | |||
| } | |||
| Assert.Equal(expectedPair.Key, key); | |||
| Assert.Equal(expectedPair.Value, value); | |||
| index++; | |||
| var rule = state.Rules[(int)symbol.Value]; | |||
| Assert.Equal(symbol.Key, rule.Name); | |||
| } | |||
| Assert.NotEmpty(state.SymbolIds); | |||
| var expectedRules = new List<LLamaGrammarElement> | |||
| { | |||
| @@ -246,13 +213,13 @@ namespace LLama.Unittest | |||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0) | |||
| }; | |||
| index = 0; | |||
| uint index = 0; | |||
| foreach (var rule in state.Rules) | |||
| { | |||
| // compare rule to expected rule | |||
| for (uint i = 0; i < rule.Count; i++) | |||
| for (uint i = 0; i < rule.Elements.Count; i++) | |||
| { | |||
| var element = rule[(int)i]; | |||
| var element = rule.Elements[(int)i]; | |||
| var expectedElement = expectedRules[(int)index]; | |||
| // Pretty print error message before asserting | |||
| @@ -1,4 +1,5 @@ | |||
| using LLama.Common; | |||
| using LLama.Grammars; | |||
| using LLama.Native; | |||
| namespace LLama.Unittest | |||
| @@ -26,14 +27,14 @@ namespace LLama.Unittest | |||
| [Fact] | |||
| public void CreateBasicGrammar() | |||
| { | |||
| var rules = new List<List<LLamaGrammarElement>> | |||
| var rules = new List<GrammarRule> | |||
| { | |||
| new() | |||
| new GrammarRule("alpha", new[] | |||
| { | |||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||
| }, | |||
| }), | |||
| }; | |||
| using var handle = SafeLLamaGrammarHandle.Create(rules, 0); | |||
| @@ -44,15 +45,15 @@ namespace LLama.Unittest | |||
| { | |||
| // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so | |||
| // we can be confident it's not what the LLM would say if not constrained by the grammar! | |||
| var rules = new List<List<LLamaGrammarElement>> | |||
| var rules = new List<GrammarRule> | |||
| { | |||
| new() | |||
| new GrammarRule("feline", new [] | |||
| { | |||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||
| }, | |||
| }), | |||
| }; | |||
| using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); | |||
| @@ -0,0 +1,3 @@ | |||
| using System.Runtime.CompilerServices; | |||
| [assembly: InternalsVisibleTo("LLama.Unittest")] | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| namespace LLama.Extensions | |||
| { | |||
| /// <summary> | |||
| /// Extension methods for <see cref="IReadOnlyList{T}"/> | |||
| /// </summary> | |||
| internal static class IReadOnlyListExtensions | |||
| { | |||
| /// <summary> | |||
| /// Find the index of the first element of the list which is equal to the given item | |||
| /// </summary> | |||
| /// <typeparam name="T">Element type, compared through <see cref="IEquatable{T}"/></typeparam> | |||
| /// <param name="list">The list to search</param> | |||
| /// <param name="item">The item to search for</param> | |||
| /// <returns>Zero based index of the first equal element, or null if no element is equal</returns> | |||
| public static int? IndexOf<T>(this IReadOnlyList<T> list, T item) | |||
| where T : IEquatable<T> | |||
| { | |||
| // The default comparer uses IEquatable<T>.Equals but also tolerates null elements, | |||
| // calling list[i].Equals(item) directly would throw on a null element. | |||
| var comparer = EqualityComparer<T>.Default; | |||
| for (var i = 0; i < list.Count; i++) | |||
| { | |||
| if (comparer.Equals(list[i], item)) | |||
| return i; | |||
| } | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,3 +6,5 @@ | |||
| using System.Diagnostics.CodeAnalysis; | |||
| [assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] | |||
| [assembly: SuppressMessage("Style", "IDE0070:Use 'System.HashCode'", Justification = "Not compatible with netstandard2.0")] | |||
| @@ -1,181 +0,0 @@ | |||
| using LLama.Exceptions; | |||
| using LLama.Native; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| namespace LLama.Grammar | |||
| { | |||
| /// <summary> | |||
| /// Source: | |||
| /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h | |||
| /// | |||
| /// The commit hash from URL is the actual commit hash that reflects current C# code. | |||
| /// </summary> | |||
| public class ParseState | |||
| { | |||
| /// <summary> | |||
| /// Maps a symbol (rule) name to its numeric id. The printing code below uses the id as the index of the rule in <see cref="Rules"/>. | |||
| /// </summary> | |||
| public SortedDictionary<string, uint> SymbolIds { get; } = new SortedDictionary<string, uint>(); | |||
| /// <summary> | |||
| /// The rules of the grammar, each rule being a list of grammar elements | |||
| /// </summary> | |||
| public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>(); | |||
| /// <summary> | |||
| /// Enumerate every rule in <see cref="Rules"/>, in order | |||
| /// </summary> | |||
| /// <returns>A lazily evaluated sequence of the rules</returns> | |||
| public IEnumerable<List<LLamaGrammarElement>> CRules() | |||
| { | |||
| foreach (var rule in Rules) | |||
| { | |||
| yield return rule; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Write every rule of <paramref name="state"/> to <paramref name="file"/> in "name ::= ..." form. | |||
| /// Any exception raised while printing is caught and its message is written to standard error instead. | |||
| /// </summary> | |||
| /// <param name="file">The writer that receives the text</param> | |||
| /// <param name="state">The state to print (note: this argument is printed, not the instance the method is called on)</param> | |||
| public void PrintGrammar(StreamWriter file, ParseState state) | |||
| { | |||
| try | |||
| { | |||
| // Invert the name -> id map so that rule references can be printed by name | |||
| Dictionary<uint, string> symbolIdNames = new Dictionary<uint, string>(); | |||
| foreach (var kv in state.SymbolIds) | |||
| { | |||
| symbolIdNames[kv.Value] = kv.Key; | |||
| } | |||
| // The index of a rule in the list is used as its rule id | |||
| for (int i = 0, end = state.Rules.Count; i < end; i++) | |||
| { | |||
| PrintRule(file, (uint)i, state.Rules[i], symbolIdNames); | |||
| } | |||
| } | |||
| catch(Exception err) | |||
| { | |||
| Console.Error.WriteLine($"\nError printing grammar: {err.Message}"); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Write a raw dump of a rule: the type name of each element followed by its value in parentheses, | |||
| /// terminated by a newline. Values of character elements are printed as quoted characters. | |||
| /// </summary> | |||
| /// <param name="file">The writer that receives the text</param> | |||
| /// <param name="rule">The elements to dump</param> | |||
| public void PrintRuleBinary(StreamWriter file, List<LLamaGrammarElement> rule) | |||
| { | |||
| foreach (var elem in rule) | |||
| { | |||
| // First the name of the element type... | |||
| switch (elem.Type) | |||
| { | |||
| case LLamaGrammarElementType.END: file.Write("END"); break; | |||
| case LLamaGrammarElementType.ALT: file.Write("ALT"); break; | |||
| case LLamaGrammarElementType.RULE_REF: file.Write("RULE_REF"); break; | |||
| case LLamaGrammarElementType.CHAR: file.Write("CHAR"); break; | |||
| case LLamaGrammarElementType.CHAR_NOT: file.Write("CHAR_NOT"); break; | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: file.Write("CHAR_RNG_UPPER"); break; | |||
| case LLamaGrammarElementType.CHAR_ALT: file.Write("CHAR_ALT"); break; | |||
| } | |||
| // ...then the value: numeric for structural elements, a quoted character for char elements | |||
| switch (elem.Type) | |||
| { | |||
| case LLamaGrammarElementType.END: | |||
| case LLamaGrammarElementType.ALT: | |||
| case LLamaGrammarElementType.RULE_REF: | |||
| file.Write($"({elem.Value}) "); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR: | |||
| case LLamaGrammarElementType.CHAR_NOT: | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||
| case LLamaGrammarElementType.CHAR_ALT: | |||
| file.Write("(\""); | |||
| PrintGrammarChar(file, elem.Value); | |||
| file.Write("\") "); | |||
| break; | |||
| } | |||
| } | |||
| file.WriteLine(); | |||
| } | |||
| /// <summary> | |||
| /// Write a single rule as "name ::= elements" followed by a newline | |||
| /// </summary> | |||
| /// <param name="file">The writer that receives the text</param> | |||
| /// <param name="ruleId">Id of the rule, used to look up its name and in error messages</param> | |||
| /// <param name="rule">The elements of the rule, which must end with an END element</param> | |||
| /// <param name="symbolIdNames">Map from rule id to rule name</param> | |||
| /// <exception cref="GrammarFormatException"> | |||
| /// Thrown if the rule is empty, does not end with END, contains END before the last element, | |||
| /// or contains CHAR_RNG_UPPER / CHAR_ALT which is not preceded by a character element | |||
| /// </exception> | |||
| private void PrintRule( | |||
| StreamWriter file, | |||
| uint ruleId, | |||
| List<LLamaGrammarElement> rule, | |||
| Dictionary<uint, string> symbolIdNames) | |||
| { | |||
| if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END) | |||
| { | |||
| throw new GrammarFormatException( | |||
| $"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}"); | |||
| } | |||
| file.Write($"{symbolIdNames[ruleId]} ::= "); | |||
| // The final element is the END marker checked above, so it is not visited by this loop | |||
| for (int i = 0, end = rule.Count - 1; i < end; i++) | |||
| { | |||
| var elem = rule[i]; | |||
| switch (elem.Type) | |||
| { | |||
| case LLamaGrammarElementType.END: | |||
| throw new GrammarFormatException( | |||
| $"Unexpected end of rule: {ruleId}, {i}"); | |||
| case LLamaGrammarElementType.ALT: | |||
| file.Write("| "); | |||
| break; | |||
| case LLamaGrammarElementType.RULE_REF: | |||
| file.Write($"{symbolIdNames[elem.Value]} "); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR: | |||
| // Opens a character class, the closing bracket is written after the switch | |||
| file.Write("["); | |||
| PrintGrammarChar(file, elem.Value); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR_NOT: | |||
| file.Write("[^"); | |||
| PrintGrammarChar(file, elem.Value); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||
| if (i == 0 || !IsCharElement(rule[i - 1])) | |||
| { | |||
| throw new GrammarFormatException( | |||
| $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}"); | |||
| } | |||
| file.Write("-"); | |||
| PrintGrammarChar(file, elem.Value); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR_ALT: | |||
| if (i == 0 || !IsCharElement(rule[i - 1])) | |||
| { | |||
| throw new GrammarFormatException( | |||
| $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}"); | |||
| } | |||
| PrintGrammarChar(file, elem.Value); | |||
| break; | |||
| } | |||
| // Close the character class unless the next element continues it. | |||
| // rule[i + 1] is always in range because the loop stops before the last element. | |||
| if (IsCharElement(elem)) | |||
| { | |||
| switch (rule[i + 1].Type) | |||
| { | |||
| case LLamaGrammarElementType.CHAR_ALT: | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||
| break; | |||
| default: | |||
| file.Write("] "); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| file.WriteLine(); | |||
| } | |||
| /// <summary> | |||
| /// Write a single code point: values from 0x20 to 0x7F inclusive are written as a character, | |||
| /// everything else is written in the form &lt;U+XXXX&gt; | |||
| /// </summary> | |||
| /// <param name="file">The writer that receives the text</param> | |||
| /// <param name="c">The code point to write</param> | |||
| private void PrintGrammarChar(StreamWriter file, uint c) | |||
| { | |||
| if (c >= 0x20 && c <= 0x7F) | |||
| { | |||
| file.Write((char)c); | |||
| } | |||
| else | |||
| { | |||
| // cop out of encoding UTF-8 | |||
| file.Write($"<U+{c:X4}>"); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Check if an element is one of the character types (CHAR, CHAR_NOT, CHAR_ALT or CHAR_RNG_UPPER) | |||
| /// </summary> | |||
| /// <param name="elem">The element to check</param> | |||
| /// <returns>true for a character element, otherwise false</returns> | |||
| private bool IsCharElement(LLamaGrammarElement elem) | |||
| { | |||
| switch (elem.Type) | |||
| { | |||
| case LLamaGrammarElementType.CHAR: | |||
| case LLamaGrammarElementType.CHAR_NOT: | |||
| case LLamaGrammarElementType.CHAR_ALT: | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||
| return true; | |||
| default: | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,10 +1,11 @@ | |||
| using LLama.Exceptions; | |||
| using LLama.Native; | |||
| using System; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| using LLama.Native; | |||
| namespace LLama.Grammar | |||
| namespace LLama.Grammars | |||
| { | |||
| /// <summary> | |||
| /// Source: | |||
| @@ -12,7 +13,7 @@ namespace LLama.Grammar | |||
| /// | |||
| /// The commit hash from URL is the actual commit hash that reflects current C# code. | |||
| /// </summary> | |||
| public class GrammarParser | |||
| internal sealed class GBNFGrammarParser | |||
| { | |||
| // NOTE: assumes valid utf8 (but checks for overrun) | |||
| // copied from llama.cpp | |||
| @@ -206,7 +207,7 @@ namespace LLama.Grammar | |||
| while (!pos.IsEmpty && pos[0] != '"') | |||
| { | |||
| var charPair = ParseChar(ref pos); | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); | |||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair)); | |||
| } | |||
| pos = ParseSpace(pos.Slice(1), isNested); | |||
| } | |||
| @@ -228,13 +229,13 @@ namespace LLama.Grammar | |||
| var charPair = ParseChar(ref pos); | |||
| var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; | |||
| outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair }); | |||
| outElements.Add(new LLamaGrammarElement(type, charPair)); | |||
| if (pos[0] == '-' && pos[1] != ']') | |||
| { | |||
| pos = pos.Slice(1); | |||
| var endCharPair = ParseChar(ref pos); | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair }); | |||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair)); | |||
| } | |||
| } | |||
| pos = ParseSpace(pos.Slice(1), isNested); | |||
| @@ -245,7 +246,7 @@ namespace LLama.Grammar | |||
| uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); | |||
| pos = ParseSpace(nameEnd, isNested); | |||
| lastSymStart = outElements.Count; | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId }); | |||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId)); | |||
| } | |||
| else if (pos[0] == '(') // grouping | |||
| { | |||
| @@ -255,7 +256,7 @@ namespace LLama.Grammar | |||
| pos = ParseAlternates(state, pos, ruleName, subRuleId, true); | |||
| lastSymStart = outElements.Count; | |||
| // output reference to synthesized rule | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); | |||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||
| if (pos[0] != ')') | |||
| { | |||
| throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}"); | |||
| @@ -284,11 +285,11 @@ namespace LLama.Grammar | |||
| if (pos[0] == '*' || pos[0] == '+') | |||
| { | |||
| // cause generated rule to recurse | |||
| subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); | |||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||
| } | |||
| // mark start of alternate def | |||
| subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); | |||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); | |||
| if (pos[0] == '+') | |||
| { | |||
| @@ -296,13 +297,13 @@ namespace LLama.Grammar | |||
| subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); | |||
| } | |||
| subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); | |||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); | |||
| AddRule(state, subRuleId, subRule); | |||
| // in original rule, replace previous symbol with reference to generated rule | |||
| outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); | |||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||
| pos = ParseSpace(pos.Slice(1), isNested); | |||
| @@ -328,12 +329,12 @@ namespace LLama.Grammar | |||
| while (!pos.IsEmpty && pos[0] == '|') | |||
| { | |||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); | |||
| rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); | |||
| pos = ParseSpace(pos.Slice(1), true); | |||
| pos = ParseSequence(state, pos, ruleName, rule, isNested); | |||
| } | |||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); | |||
| rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); | |||
| AddRule(state, ruleId, rule); | |||
| return pos; | |||
| @@ -370,19 +371,41 @@ namespace LLama.Grammar | |||
| return ParseSpace(pos, true); | |||
| } | |||
| public ParseState Parse(string input) | |||
| /// <summary> | |||
| /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> | |||
| /// </summary> | |||
| /// <param name="input">The string to parse</param> | |||
| /// <param name="startRule">The name of the root rule of this grammar</param> | |||
| /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception> | |||
| /// <returns>A Grammar that can be converted into a SafeLLamaGrammarHandle for sampling</returns> | |||
| public Grammar Parse(string input, string startRule) | |||
| { | |||
| byte[] byteArray = Encoding.UTF8.GetBytes(input); | |||
| ReadOnlySpan<byte> src = new ReadOnlySpan<byte>(byteArray); | |||
| ParseState state = new ParseState(); | |||
| ReadOnlySpan<byte> pos = ParseSpace(src, true); | |||
| var byteArray = Encoding.UTF8.GetBytes(input); | |||
| var state = new ParseState(); | |||
| var pos = ParseSpace(byteArray, true); | |||
| while (!pos.IsEmpty) | |||
| { | |||
| pos = ParseRule(state, pos); | |||
| } | |||
| return state; | |||
| var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key); | |||
| var rules = new List<GrammarRule>(); | |||
| for (var i = 0; i < state.Rules.Count; i++) | |||
| { | |||
| var elements = state.Rules[i]; | |||
| var name = names[(uint)i]; | |||
| rules.Add(new GrammarRule(name, elements)); | |||
| } | |||
| var startRuleIndex = state.SymbolIds[startRule]; | |||
| return new Grammar(rules, startRuleIndex); | |||
| } | |||
| /// <summary> | |||
| /// Mutable scratch state which is filled in while the GBNF text is being parsed | |||
| /// </summary> | |||
| private record ParseState | |||
| { | |||
| /// <summary> | |||
| /// Maps a rule name to its id. The id is the index of that rule in <see cref="Rules"/>. | |||
| /// </summary> | |||
| public SortedDictionary<string, uint> SymbolIds { get; } = new SortedDictionary<string, uint>(); | |||
| /// <summary> | |||
| /// The raw element lists of every rule, indexed by rule id | |||
| /// </summary> | |||
| public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,151 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| using LLama.Native; | |||
| namespace LLama.Grammars | |||
| { | |||
| /// <summary> | |||
| /// A grammar is a set of <see cref="GrammarRule"/>s for deciding which characters are valid next. Can be used to constrain | |||
| /// output to certain formats - e.g. force the model to output JSON | |||
| /// </summary> | |||
| public sealed class Grammar | |||
| { | |||
| /// <summary> | |||
| /// Index of the initial rule to start from. The constructor checks this against the number of rules, | |||
| /// the setter performs no check. | |||
| /// </summary> | |||
| public ulong StartRuleIndex { get; set; } | |||
| /// <summary> | |||
| /// The rules which make up this grammar | |||
| /// </summary> | |||
| public IReadOnlyList<GrammarRule> Rules { get; } | |||
| /// <summary> | |||
| /// Create a new grammar from a set of rules | |||
| /// </summary> | |||
| /// <param name="rules">The rules which make up this grammar</param> | |||
| /// <param name="startRuleIndex">Index of the initial rule to start from</param> | |||
| /// <exception cref="ArgumentNullException">Thrown if <paramref name="rules"/> is null</exception> | |||
| /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="startRuleIndex"/> is not a valid index into <paramref name="rules"/></exception> | |||
| public Grammar(IReadOnlyList<GrammarRule> rules, ulong startRuleIndex) | |||
| { | |||
| // Fail with a descriptive exception rather than a NullReferenceException from rules.Count | |||
| if (rules is null) | |||
| throw new ArgumentNullException(nameof(rules)); | |||
| if (startRuleIndex >= (uint)rules.Count) | |||
| throw new ArgumentOutOfRangeException(nameof(startRuleIndex), "startRule must be less than the number of rules"); | |||
| StartRuleIndex = startRuleIndex; | |||
| Rules = rules; | |||
| } | |||
| /// <summary> | |||
| /// Create a `SafeLLamaGrammarHandle` instance to use for sampling | |||
| /// </summary> | |||
| /// <returns>A new handle created from the rules and start rule index of this grammar</returns> | |||
| public SafeLLamaGrammarHandle CreateInstance() | |||
| { | |||
| return SafeLLamaGrammarHandle.Create(Rules, StartRuleIndex); | |||
| } | |||
| /// <summary> | |||
| /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> into a Grammar | |||
| /// </summary> | |||
| /// <param name="gbnf">The string to parse</param> | |||
| /// <param name="startRule">Name of the start rule of this grammar</param> | |||
| /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception> | |||
| /// <returns>A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling</returns> | |||
| public static Grammar Parse(string gbnf, string startRule) | |||
| { | |||
| var parser = new GBNFGrammarParser(); | |||
| return parser.Parse(gbnf, startRule); | |||
| } | |||
| /// <summary> | |||
| /// Format every rule of this grammar as a line of the form "name ::= elements" | |||
| /// </summary> | |||
| /// <exception cref="GrammarFormatException">Thrown if a rule cannot be printed because it is malformed</exception> | |||
| public override string ToString() | |||
| { | |||
| var builder = new StringBuilder(); | |||
| PrintGrammar(builder); | |||
| return builder.ToString(); | |||
| } | |||
| // Append every rule to the output, the index of a rule in the list is its rule id | |||
| private void PrintGrammar(StringBuilder output) | |||
| { | |||
| for (var i = 0; i < Rules.Count; i++) | |||
| PrintRule(output, (uint)i, Rules[i]); | |||
| } | |||
| // Append a single rule as "name ::= elements" followed by a newline | |||
| private void PrintRule(StringBuilder output, uint ruleId, GrammarRule rule) | |||
| { | |||
| output.Append($"{rule.Name} ::= "); | |||
| // The last element is not visited by this loop, so rule.Elements[i + 1] below is always in range | |||
| for (int i = 0, end = rule.Elements.Count - 1; i < end; i++) | |||
| { | |||
| var elem = rule.Elements[i]; | |||
| switch (elem.Type) | |||
| { | |||
| case LLamaGrammarElementType.END: | |||
| throw new GrammarFormatException($"Unexpected end of rule: {ruleId}, {i}"); | |||
| case LLamaGrammarElementType.ALT: | |||
| output.Append("| "); | |||
| break; | |||
| case LLamaGrammarElementType.RULE_REF: | |||
| // A reference to a rule which does not exist is a malformed grammar, report it in the | |||
| // same way as the other format problems instead of failing inside the list indexer | |||
| if (elem.Value >= (uint)Rules.Count) | |||
| throw new GrammarFormatException($"Reference to undefined rule {elem.Value}: {ruleId}, {i}"); | |||
| output.Append($"{Rules[(int)elem.Value].Name} "); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR: | |||
| // Opens a character class, the closing bracket is appended after the switch | |||
| output.Append('['); | |||
| PrintGrammarChar(output, elem.Value); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR_NOT: | |||
| output.Append("[^"); | |||
| PrintGrammarChar(output, elem.Value); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||
| if (i == 0 || !rule.Elements[i - 1].IsCharElement()) | |||
| { | |||
| throw new GrammarFormatException( | |||
| $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}"); | |||
| } | |||
| output.Append('-'); | |||
| PrintGrammarChar(output, elem.Value); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR_ALT: | |||
| if (i == 0 || !rule.Elements[i - 1].IsCharElement()) | |||
| { | |||
| throw new GrammarFormatException( | |||
| $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}"); | |||
| } | |||
| PrintGrammarChar(output, elem.Value); | |||
| break; | |||
| } | |||
| // Close the character class unless the next element continues it | |||
| if (elem.IsCharElement()) | |||
| { | |||
| switch (rule.Elements[i + 1].Type) | |||
| { | |||
| case LLamaGrammarElementType.CHAR_ALT: | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||
| break; | |||
| default: | |||
| output.Append("] "); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| output.AppendLine(); | |||
| } | |||
| // Append one code point: 0x20 to 0x7F inclusive as a character, anything else in <U+XXXX> form | |||
| private static void PrintGrammarChar(StringBuilder output, uint c) | |||
| { | |||
| if (c >= 0x20 && c <= 0x7F) | |||
| { | |||
| output.Append((char)c); | |||
| } | |||
| else | |||
| { | |||
| // cop out of encoding UTF-8 | |||
| output.Append($"<U+{c:X4}>"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,75 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using LLama.Native; | |||
| namespace LLama.Grammars | |||
| { | |||
| /// <summary> | |||
| /// A single rule in a <see cref="Grammar"/> | |||
| /// </summary> | |||
| public sealed record GrammarRule | |||
| { | |||
| /// <summary> | |||
| /// Name of this rule | |||
| /// </summary> | |||
| public string Name { get; } | |||
| /// <summary> | |||
| /// The elements of this grammar rule | |||
| /// </summary> | |||
| public IReadOnlyList<LLamaGrammarElement> Elements { get; } | |||
| /// <summary> | |||
| /// Create a new GrammarRule containing the given elements | |||
| /// </summary> | |||
| /// <param name="name">Name of this rule</param> | |||
| /// <param name="elements">The elements of this rule. Must be non-empty and end with a single END element.</param> | |||
| /// <exception cref="ArgumentNullException">Thrown if <paramref name="elements"/> is null</exception> | |||
| /// <exception cref="ArgumentException">Thrown if <paramref name="elements"/> does not form a valid rule</exception> | |||
| public GrammarRule(string name, IReadOnlyList<LLamaGrammarElement> elements) | |||
| { | |||
| Validate(elements, name); | |||
| Name = name; | |||
| // NOTE(review): the list is stored as passed in (not copied), so a caller which keeps a mutable | |||
| // reference to it could still change the elements after they were validated here. | |||
| Elements = elements; | |||
| } | |||
| // Throws unless the elements form a well formed rule, name is only used in the error messages | |||
| private static void Validate(IReadOnlyList<LLamaGrammarElement> elements, string name) | |||
| { | |||
| if (elements is null) | |||
| throw new ArgumentNullException(nameof(elements)); | |||
| if (elements.Count == 0) | |||
| throw new ArgumentException("Cannot create a GrammarRule with zero elements", nameof(elements)); | |||
| if (elements[elements.Count - 1].Type != LLamaGrammarElementType.END) | |||
| throw new ArgumentException("Last grammar element must be END", nameof(elements)); | |||
| for (var i = 0; i < elements.Count; i++) | |||
| { | |||
| switch (elements[i].Type) | |||
| { | |||
| case LLamaGrammarElementType.END: | |||
| // END is only allowed as the very last element | |||
| if (i != elements.Count - 1) | |||
| throw new ArgumentException("Found more than one END grammar element", nameof(elements)); | |||
| continue; | |||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||
| // The upper bound of a range needs a character element directly before it | |||
| if (i == 0 || !elements[i - 1].IsCharElement()) | |||
| throw new ArgumentException($"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {name},{i}", nameof(elements)); | |||
| break; | |||
| case LLamaGrammarElementType.CHAR_ALT: | |||
| // An alternate character needs a character element directly before it | |||
| if (i == 0 || !elements[i - 1].IsCharElement()) | |||
| { | |||
| throw new ArgumentException($"LLamaGrammarElementType.CHAR_ALT without preceding char: {name},{i}", nameof(elements)); | |||
| } | |||
| break; | |||
| case LLamaGrammarElementType.ALT: | |||
| case LLamaGrammarElementType.RULE_REF: | |||
| case LLamaGrammarElementType.CHAR: | |||
| case LLamaGrammarElementType.CHAR_NOT: | |||
| break; | |||
| default: | |||
| throw new ArgumentException($"Unknown grammar element type: '{elements[i].Type}'", nameof(elements)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System.Diagnostics; | |||
| using System; | |||
| using System.Diagnostics; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native | |||
| @@ -51,17 +52,18 @@ namespace LLama.Native | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| [DebuggerDisplay("{Type} {Value}")] | |||
| public struct LLamaGrammarElement | |||
| public readonly struct LLamaGrammarElement | |||
| : IEquatable<LLamaGrammarElement> | |||
| { | |||
| /// <summary> | |||
| /// The type of this element | |||
| /// </summary> | |||
| public LLamaGrammarElementType Type; | |||
| public readonly LLamaGrammarElementType Type; | |||
| /// <summary> | |||
| /// Unicode code point or rule ID | |||
| /// </summary> | |||
| public uint Value; | |||
| public readonly uint Value; | |||
| /// <summary> | |||
| /// Construct a new LLamaGrammarElement | |||
| @@ -73,5 +75,50 @@ namespace LLama.Native | |||
| Type = type; | |||
| Value = value; | |||
| } | |||
| /// <inheritdoc /> | |||
| public bool Equals(LLamaGrammarElement other) | |||
| { | |||
| // Elements of different types are never equal. For two END elements the value is | |||
| // not compared, for every other type the values must match as well. | |||
| return Type == other.Type | |||
| && (Type == LLamaGrammarElementType.END || Value == other.Value); | |||
| } | |||
| /// <inheritdoc /> | |||
| public override bool Equals(object? obj) | |||
| { | |||
| // Anything which is not a LLamaGrammarElement (null included) is never equal | |||
| if (obj is LLamaGrammarElement element) | |||
| return Equals(element); | |||
| return false; | |||
| } | |||
| /// <inheritdoc /> | |||
| public override int GetHashCode() | |||
| { | |||
| unchecked | |||
| { | |||
| // Equals does not compare the value of END elements, so the value must not contribute to | |||
| // their hash either: elements which are equal have to produce the same hash code. | |||
| var value = Type == LLamaGrammarElementType.END ? 0 : (int)Value; | |||
| var hash = 2999; | |||
| hash = hash * 7723 + (int)Type; | |||
| hash = hash * 7723 + value; | |||
| return hash; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Check if this element is one of the character types (CHAR, CHAR_NOT, CHAR_ALT or CHAR_RNG_UPPER) | |||
| /// </summary> | |||
| /// <returns>true for a character element, otherwise false</returns> | |||
| internal bool IsCharElement() | |||
| { | |||
| return Type == LLamaGrammarElementType.CHAR | |||
| || Type == LLamaGrammarElementType.CHAR_NOT | |||
| || Type == LLamaGrammarElementType.CHAR_ALT | |||
| || Type == LLamaGrammarElementType.CHAR_RNG_UPPER; | |||
| } | |||
| } | |||
| } | |||
| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using LLama.Exceptions; | |||
| using LLama.Grammars; | |||
| namespace LLama.Native | |||
| { | |||
| @@ -38,11 +39,11 @@ namespace LLama.Native | |||
| /// <param name="start_rule_index">The index (in the outer list) of the start rule</param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index) | |||
| public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index) | |||
| { | |||
| unsafe | |||
| { | |||
| var totalElements = rules.Sum(a => a.Count); | |||
| var totalElements = rules.Sum(a => a.Elements.Count); | |||
| var nrules = (ulong)rules.Count; | |||
| // Borrow an array large enough to hold every single element | |||
| @@ -61,7 +62,7 @@ namespace LLama.Native | |||
| pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex); | |||
| // Copy all of the rule elements into the flat array | |||
| foreach (var element in rule) | |||
| foreach (var element in rule.Elements) | |||
| allElementsPtr[elementIndex++] = element; | |||
| } | |||