- Made all the mechanics of grammar parsing (GBNFGrammarParser, ParseState) internal. Just call `Grammar.Parse(gbnf, startRule)`, e.g. `Grammar.Parse(gbnf, "root")`.
- Added a `GrammarRule` record which validates its elements on construction (this allows constructing a grammar without parsing GBNF).
- It should be impossible for a `GrammarRule` to represent an invalid rule.
tags/v0.5.1
| @@ -1,6 +1,5 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.Grammar; | |||||
| using LLama.Native; | |||||
| using LLama.Grammars; | |||||
| namespace LLama.Examples.NewVersion | namespace LLama.Examples.NewVersion | ||||
| { | { | ||||
| @@ -8,8 +7,8 @@ namespace LLama.Examples.NewVersion | |||||
| { | { | ||||
| public static void Run() | public static void Run() | ||||
| { | { | ||||
| var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim(); | |||||
| var parsedGrammar = new GrammarParser(); | |||||
| var gbnf = File.ReadAllText("Assets/json.gbnf").Trim(); | |||||
| var grammar = Grammar.Parse(gbnf, "root"); | |||||
| Console.Write("Please input your model path: "); | Console.Write("Please input your model path: "); | ||||
| var modelPath = Console.ReadLine(); | var modelPath = Console.ReadLine(); | ||||
| @@ -22,19 +21,18 @@ namespace LLama.Examples.NewVersion | |||||
| }; | }; | ||||
| using var model = LLamaWeights.LoadFromFile(parameters); | using var model = LLamaWeights.LoadFromFile(parameters); | ||||
| var ex = new StatelessExecutor(model, parameters); | var ex = new StatelessExecutor(model, parameters); | ||||
| ParseState state = parsedGrammar.Parse(grammarBytes); | |||||
| using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0); | |||||
| Console.ForegroundColor = ConsoleColor.Yellow; | Console.ForegroundColor = ConsoleColor.Yellow; | ||||
| Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\""); | Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\""); | ||||
| Console.ForegroundColor = ConsoleColor.White; | Console.ForegroundColor = ConsoleColor.White; | ||||
| using var grammarInstance = grammar.CreateInstance(); | |||||
| var inferenceParams = new InferenceParams() | var inferenceParams = new InferenceParams() | ||||
| { | { | ||||
| Temperature = 0.6f, | Temperature = 0.6f, | ||||
| AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" }, | AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" }, | ||||
| MaxTokens = 50, | MaxTokens = 50, | ||||
| Grammar = grammar | |||||
| Grammar = grammarInstance | |||||
| }; | }; | ||||
| while (true) | while (true) | ||||
| @@ -1,8 +1,5 @@ | |||||
| using LLama.Common; | |||||
| using LLama.Native; | |||||
| using System.Diagnostics; | |||||
| using LLama.Grammar; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using LLama.Native; | |||||
| using LLama.Grammars; | |||||
| namespace LLama.Unittest | namespace LLama.Unittest | ||||
| { | { | ||||
| @@ -17,14 +14,15 @@ namespace LLama.Unittest | |||||
| [Fact] | [Fact] | ||||
| public void ParseComplexGrammar() | public void ParseComplexGrammar() | ||||
| { | { | ||||
| GrammarParser parsedGrammar = new GrammarParser(); | |||||
| GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); | |||||
| string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ | string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ | ||||
| expr ::= term ([-+*/] term)* | expr ::= term ([-+*/] term)* | ||||
| term ::= [0-9]+"; | term ::= [0-9]+"; | ||||
| ParseState state = parsedGrammar.Parse(grammarBytes); | |||||
| var state = parsedGrammar.Parse(grammarBytes, "root"); | |||||
| Assert.Equal(0ul, state.StartRuleIndex); | |||||
| List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>> | |||||
| var expected = new List<KeyValuePair<string, uint>> | |||||
| { | { | ||||
| new KeyValuePair<string, uint>("expr", 2), | new KeyValuePair<string, uint>("expr", 2), | ||||
| new KeyValuePair<string, uint>("expr_5", 5), | new KeyValuePair<string, uint>("expr_5", 5), | ||||
| @@ -36,27 +34,11 @@ namespace LLama.Unittest | |||||
| new KeyValuePair<string, uint>("term_7", 7), | new KeyValuePair<string, uint>("term_7", 7), | ||||
| }; | }; | ||||
| uint index = 0; | |||||
| foreach (var it in state.SymbolIds) | |||||
| foreach (var symbol in expected) | |||||
| { | { | ||||
| string key = it.Key; | |||||
| uint value = it.Value; | |||||
| var expectedPair = expected[(int)index]; | |||||
| // pretty print error message before asserting | |||||
| if (expectedPair.Key != key || expectedPair.Value != value) | |||||
| { | |||||
| Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); | |||||
| Console.Error.WriteLine($"actualPair: {key}, {value}"); | |||||
| Console.Error.WriteLine("expectedPair != actualPair"); | |||||
| } | |||||
| Assert.Equal(expectedPair.Key, key); | |||||
| Assert.Equal(expectedPair.Value, value); | |||||
| index++; | |||||
| var rule = state.Rules[(int)symbol.Value]; | |||||
| Assert.Equal(symbol.Key, rule.Name); | |||||
| } | } | ||||
| Assert.NotEmpty(state.SymbolIds); | |||||
| var expectedRules = new List<LLamaGrammarElement> | var expectedRules = new List<LLamaGrammarElement> | ||||
| { | { | ||||
| @@ -96,13 +78,13 @@ namespace LLama.Unittest | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | ||||
| }; | }; | ||||
| index = 0; | |||||
| uint index = 0; | |||||
| foreach (var rule in state.Rules) | foreach (var rule in state.Rules) | ||||
| { | { | ||||
| // compare rule to expected rule | // compare rule to expected rule | ||||
| for (uint i = 0; i < rule.Count; i++) | |||||
| for (uint i = 0; i < rule.Elements.Count; i++) | |||||
| { | { | ||||
| var element = rule[(int)i]; | |||||
| var element = rule.Elements[(int)i]; | |||||
| var expectedElement = expectedRules[(int)index]; | var expectedElement = expectedRules[(int)index]; | ||||
| // Pretty print error message before asserting | // Pretty print error message before asserting | ||||
| @@ -124,7 +106,7 @@ namespace LLama.Unittest | |||||
| [Fact] | [Fact] | ||||
| public void ParseExtraComplexGrammar() | public void ParseExtraComplexGrammar() | ||||
| { | { | ||||
| GrammarParser parsedGrammar = new GrammarParser(); | |||||
| GBNFGrammarParser parsedGrammar = new GBNFGrammarParser(); | |||||
| string grammarBytes = @" | string grammarBytes = @" | ||||
| root ::= (expr ""="" ws term ""\n"")+ | root ::= (expr ""="" ws term ""\n"")+ | ||||
| expr ::= term ([-+*/] term)* | expr ::= term ([-+*/] term)* | ||||
| @@ -134,9 +116,10 @@ namespace LLama.Unittest | |||||
| ws ::= [ \t\n]* | ws ::= [ \t\n]* | ||||
| "; | "; | ||||
| ParseState state = parsedGrammar.Parse(grammarBytes); | |||||
| var state = parsedGrammar.Parse(grammarBytes, "root"); | |||||
| Assert.Equal(0ul, state.StartRuleIndex); | |||||
| List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>> | |||||
| var expected = new List<KeyValuePair<string, uint>> | |||||
| { | { | ||||
| new KeyValuePair<string, uint>("expr", 2), | new KeyValuePair<string, uint>("expr", 2), | ||||
| new KeyValuePair<string, uint>("expr_6", 6), | new KeyValuePair<string, uint>("expr_6", 6), | ||||
| @@ -153,27 +136,11 @@ namespace LLama.Unittest | |||||
| new KeyValuePair<string, uint>("ws_12", 12), | new KeyValuePair<string, uint>("ws_12", 12), | ||||
| }; | }; | ||||
| uint index = 0; | |||||
| foreach (var it in state.SymbolIds) | |||||
| foreach (var symbol in expected) | |||||
| { | { | ||||
| string key = it.Key; | |||||
| uint value = it.Value; | |||||
| var expectedPair = expected[(int)index]; | |||||
| // pretty print error message before asserting | |||||
| if (expectedPair.Key != key || expectedPair.Value != value) | |||||
| { | |||||
| Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); | |||||
| Console.Error.WriteLine($"actualPair: {key}, {value}"); | |||||
| Console.Error.WriteLine("expectedPair != actualPair"); | |||||
| } | |||||
| Assert.Equal(expectedPair.Key, key); | |||||
| Assert.Equal(expectedPair.Value, value); | |||||
| index++; | |||||
| var rule = state.Rules[(int)symbol.Value]; | |||||
| Assert.Equal(symbol.Key, rule.Name); | |||||
| } | } | ||||
| Assert.NotEmpty(state.SymbolIds); | |||||
| var expectedRules = new List<LLamaGrammarElement> | var expectedRules = new List<LLamaGrammarElement> | ||||
| { | { | ||||
| @@ -246,13 +213,13 @@ namespace LLama.Unittest | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0) | new LLamaGrammarElement(LLamaGrammarElementType.END, 0) | ||||
| }; | }; | ||||
| index = 0; | |||||
| uint index = 0; | |||||
| foreach (var rule in state.Rules) | foreach (var rule in state.Rules) | ||||
| { | { | ||||
| // compare rule to expected rule | // compare rule to expected rule | ||||
| for (uint i = 0; i < rule.Count; i++) | |||||
| for (uint i = 0; i < rule.Elements.Count; i++) | |||||
| { | { | ||||
| var element = rule[(int)i]; | |||||
| var element = rule.Elements[(int)i]; | |||||
| var expectedElement = expectedRules[(int)index]; | var expectedElement = expectedRules[(int)index]; | ||||
| // Pretty print error message before asserting | // Pretty print error message before asserting | ||||
| @@ -1,4 +1,5 @@ | |||||
| using LLama.Common; | using LLama.Common; | ||||
| using LLama.Grammars; | |||||
| using LLama.Native; | using LLama.Native; | ||||
| namespace LLama.Unittest | namespace LLama.Unittest | ||||
| @@ -26,14 +27,14 @@ namespace LLama.Unittest | |||||
| [Fact] | [Fact] | ||||
| public void CreateBasicGrammar() | public void CreateBasicGrammar() | ||||
| { | { | ||||
| var rules = new List<List<LLamaGrammarElement>> | |||||
| var rules = new List<GrammarRule> | |||||
| { | { | ||||
| new() | |||||
| new GrammarRule("alpha", new[] | |||||
| { | { | ||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | ||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), | new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), | ||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | ||||
| }, | |||||
| }), | |||||
| }; | }; | ||||
| using var handle = SafeLLamaGrammarHandle.Create(rules, 0); | using var handle = SafeLLamaGrammarHandle.Create(rules, 0); | ||||
| @@ -44,15 +45,15 @@ namespace LLama.Unittest | |||||
| { | { | ||||
| // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so | // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so | ||||
| // we can be confident it's not what the LLM would say if not constrained by the grammar! | // we can be confident it's not what the LLM would say if not constrained by the grammar! | ||||
| var rules = new List<List<LLamaGrammarElement>> | |||||
| var rules = new List<GrammarRule> | |||||
| { | { | ||||
| new() | |||||
| new GrammarRule("feline", new [] | |||||
| { | { | ||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), | new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), | ||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), | ||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), | new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), | ||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | ||||
| }, | |||||
| }), | |||||
| }; | }; | ||||
| using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); | using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); | ||||
| @@ -0,0 +1,3 @@ | |||||
| using System.Runtime.CompilerServices; | |||||
| [assembly: InternalsVisibleTo("LLama.Unittest")] | |||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| namespace LLama.Extensions | |||||
| { | |||||
| internal static class IReadOnlyListExtensions | |||||
| { | |||||
| public static int? IndexOf<T>(this IReadOnlyList<T> list, T item) | |||||
| where T : IEquatable<T> | |||||
| { | |||||
| for (var i = 0; i < list.Count; i++) | |||||
| { | |||||
| if (list[i].Equals(item)) | |||||
| return i; | |||||
| } | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -6,3 +6,5 @@ | |||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| [assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] | [assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] | ||||
| [assembly: SuppressMessage("Style", "IDE0070:Use 'System.HashCode'", Justification = "Not compatible with netstandard2.0")] | |||||
| @@ -1,181 +0,0 @@ | |||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| namespace LLama.Grammar | |||||
| { | |||||
| /// <summary> | |||||
| /// Source: | |||||
| /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h | |||||
| /// | |||||
| /// The commit hash from URL is the actual commit hash that reflects current C# code. | |||||
| /// </summary> | |||||
| public class ParseState | |||||
| { | |||||
| public SortedDictionary<string, uint> SymbolIds { get; } = new SortedDictionary<string, uint>(); | |||||
| public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>(); | |||||
| public IEnumerable<List<LLamaGrammarElement>> CRules() | |||||
| { | |||||
| foreach (var rule in Rules) | |||||
| { | |||||
| yield return rule; | |||||
| } | |||||
| } | |||||
| public void PrintGrammar(StreamWriter file, ParseState state) | |||||
| { | |||||
| try | |||||
| { | |||||
| Dictionary<uint, string> symbolIdNames = new Dictionary<uint, string>(); | |||||
| foreach (var kv in state.SymbolIds) | |||||
| { | |||||
| symbolIdNames[kv.Value] = kv.Key; | |||||
| } | |||||
| for (int i = 0, end = state.Rules.Count; i < end; i++) | |||||
| { | |||||
| PrintRule(file, (uint)i, state.Rules[i], symbolIdNames); | |||||
| } | |||||
| } | |||||
| catch(Exception err) | |||||
| { | |||||
| Console.Error.WriteLine($"\nError printing grammar: {err.Message}"); | |||||
| } | |||||
| } | |||||
| public void PrintRuleBinary(StreamWriter file, List<LLamaGrammarElement> rule) | |||||
| { | |||||
| foreach (var elem in rule) | |||||
| { | |||||
| switch (elem.Type) | |||||
| { | |||||
| case LLamaGrammarElementType.END: file.Write("END"); break; | |||||
| case LLamaGrammarElementType.ALT: file.Write("ALT"); break; | |||||
| case LLamaGrammarElementType.RULE_REF: file.Write("RULE_REF"); break; | |||||
| case LLamaGrammarElementType.CHAR: file.Write("CHAR"); break; | |||||
| case LLamaGrammarElementType.CHAR_NOT: file.Write("CHAR_NOT"); break; | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: file.Write("CHAR_RNG_UPPER"); break; | |||||
| case LLamaGrammarElementType.CHAR_ALT: file.Write("CHAR_ALT"); break; | |||||
| } | |||||
| switch (elem.Type) | |||||
| { | |||||
| case LLamaGrammarElementType.END: | |||||
| case LLamaGrammarElementType.ALT: | |||||
| case LLamaGrammarElementType.RULE_REF: | |||||
| file.Write($"({elem.Value}) "); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR: | |||||
| case LLamaGrammarElementType.CHAR_NOT: | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| file.Write("(\""); | |||||
| PrintGrammarChar(file, elem.Value); | |||||
| file.Write("\") "); | |||||
| break; | |||||
| } | |||||
| } | |||||
| file.WriteLine(); | |||||
| } | |||||
| private void PrintRule( | |||||
| StreamWriter file, | |||||
| uint ruleId, | |||||
| List<LLamaGrammarElement> rule, | |||||
| Dictionary<uint, string> symbolIdNames) | |||||
| { | |||||
| if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END) | |||||
| { | |||||
| throw new GrammarFormatException( | |||||
| $"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}"); | |||||
| } | |||||
| file.Write($"{symbolIdNames[ruleId]} ::= "); | |||||
| for (int i = 0, end = rule.Count - 1; i < end; i++) | |||||
| { | |||||
| var elem = rule[i]; | |||||
| switch (elem.Type) | |||||
| { | |||||
| case LLamaGrammarElementType.END: | |||||
| throw new GrammarFormatException( | |||||
| $"Unexpected end of rule: {ruleId}, {i}"); | |||||
| case LLamaGrammarElementType.ALT: | |||||
| file.Write("| "); | |||||
| break; | |||||
| case LLamaGrammarElementType.RULE_REF: | |||||
| file.Write($"{symbolIdNames[elem.Value]} "); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR: | |||||
| file.Write("["); | |||||
| PrintGrammarChar(file, elem.Value); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR_NOT: | |||||
| file.Write("[^"); | |||||
| PrintGrammarChar(file, elem.Value); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| if (i == 0 || !IsCharElement(rule[i - 1])) | |||||
| { | |||||
| throw new GrammarFormatException( | |||||
| $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}"); | |||||
| } | |||||
| file.Write("-"); | |||||
| PrintGrammarChar(file, elem.Value); | |||||
| break; | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| if (i == 0 || !IsCharElement(rule[i - 1])) | |||||
| { | |||||
| throw new GrammarFormatException( | |||||
| $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}"); | |||||
| } | |||||
| PrintGrammarChar(file, elem.Value); | |||||
| break; | |||||
| } | |||||
| if (IsCharElement(elem)) | |||||
| { | |||||
| switch (rule[i + 1].Type) | |||||
| { | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| break; | |||||
| default: | |||||
| file.Write("] "); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| file.WriteLine(); | |||||
| } | |||||
| private void PrintGrammarChar(StreamWriter file, uint c) | |||||
| { | |||||
| if (c >= 0x20 && c <= 0x7F) | |||||
| { | |||||
| file.Write((char)c); | |||||
| } | |||||
| else | |||||
| { | |||||
| // cop out of encoding UTF-8 | |||||
| file.Write($"<U+{c:X4}>"); | |||||
| } | |||||
| } | |||||
| private bool IsCharElement(LLamaGrammarElement elem) | |||||
| { | |||||
| switch (elem.Type) | |||||
| { | |||||
| case LLamaGrammarElementType.CHAR: | |||||
| case LLamaGrammarElementType.CHAR_NOT: | |||||
| case LLamaGrammarElementType.CHAR_ALT: | |||||
| case LLamaGrammarElementType.CHAR_RNG_UPPER: | |||||
| return true; | |||||
| default: | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,10 +1,11 @@ | |||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | |||||
| using System; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | |||||
| namespace LLama.Grammar | |||||
| namespace LLama.Grammars | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Source: | /// Source: | ||||
| @@ -12,7 +13,7 @@ namespace LLama.Grammar | |||||
| /// | /// | ||||
| /// The commit hash from URL is the actual commit hash that reflects current C# code. | /// The commit hash from URL is the actual commit hash that reflects current C# code. | ||||
| /// </summary> | /// </summary> | ||||
| public class GrammarParser | |||||
| internal sealed class GBNFGrammarParser | |||||
| { | { | ||||
| // NOTE: assumes valid utf8 (but checks for overrun) | // NOTE: assumes valid utf8 (but checks for overrun) | ||||
| // copied from llama.cpp | // copied from llama.cpp | ||||
| @@ -206,7 +207,7 @@ namespace LLama.Grammar | |||||
| while (!pos.IsEmpty && pos[0] != '"') | while (!pos.IsEmpty && pos[0] != '"') | ||||
| { | { | ||||
| var charPair = ParseChar(ref pos); | var charPair = ParseChar(ref pos); | ||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair)); | |||||
| } | } | ||||
| pos = ParseSpace(pos.Slice(1), isNested); | pos = ParseSpace(pos.Slice(1), isNested); | ||||
| } | } | ||||
| @@ -228,13 +229,13 @@ namespace LLama.Grammar | |||||
| var charPair = ParseChar(ref pos); | var charPair = ParseChar(ref pos); | ||||
| var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; | var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; | ||||
| outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair }); | |||||
| outElements.Add(new LLamaGrammarElement(type, charPair)); | |||||
| if (pos[0] == '-' && pos[1] != ']') | if (pos[0] == '-' && pos[1] != ']') | ||||
| { | { | ||||
| pos = pos.Slice(1); | pos = pos.Slice(1); | ||||
| var endCharPair = ParseChar(ref pos); | var endCharPair = ParseChar(ref pos); | ||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair }); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair)); | |||||
| } | } | ||||
| } | } | ||||
| pos = ParseSpace(pos.Slice(1), isNested); | pos = ParseSpace(pos.Slice(1), isNested); | ||||
| @@ -245,7 +246,7 @@ namespace LLama.Grammar | |||||
| uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); | uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); | ||||
| pos = ParseSpace(nameEnd, isNested); | pos = ParseSpace(nameEnd, isNested); | ||||
| lastSymStart = outElements.Count; | lastSymStart = outElements.Count; | ||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId }); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId)); | |||||
| } | } | ||||
| else if (pos[0] == '(') // grouping | else if (pos[0] == '(') // grouping | ||||
| { | { | ||||
| @@ -255,7 +256,7 @@ namespace LLama.Grammar | |||||
| pos = ParseAlternates(state, pos, ruleName, subRuleId, true); | pos = ParseAlternates(state, pos, ruleName, subRuleId, true); | ||||
| lastSymStart = outElements.Count; | lastSymStart = outElements.Count; | ||||
| // output reference to synthesized rule | // output reference to synthesized rule | ||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||||
| if (pos[0] != ')') | if (pos[0] != ')') | ||||
| { | { | ||||
| throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}"); | throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}"); | ||||
| @@ -284,11 +285,11 @@ namespace LLama.Grammar | |||||
| if (pos[0] == '*' || pos[0] == '+') | if (pos[0] == '*' || pos[0] == '+') | ||||
| { | { | ||||
| // cause generated rule to recurse | // cause generated rule to recurse | ||||
| subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); | |||||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||||
| } | } | ||||
| // mark start of alternate def | // mark start of alternate def | ||||
| subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); | |||||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); | |||||
| if (pos[0] == '+') | if (pos[0] == '+') | ||||
| { | { | ||||
| @@ -296,13 +297,13 @@ namespace LLama.Grammar | |||||
| subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); | subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); | ||||
| } | } | ||||
| subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); | |||||
| subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); | |||||
| AddRule(state, subRuleId, subRule); | AddRule(state, subRuleId, subRule); | ||||
| // in original rule, replace previous symbol with reference to generated rule | // in original rule, replace previous symbol with reference to generated rule | ||||
| outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); | outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); | ||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); | |||||
| outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId)); | |||||
| pos = ParseSpace(pos.Slice(1), isNested); | pos = ParseSpace(pos.Slice(1), isNested); | ||||
| @@ -328,12 +329,12 @@ namespace LLama.Grammar | |||||
| while (!pos.IsEmpty && pos[0] == '|') | while (!pos.IsEmpty && pos[0] == '|') | ||||
| { | { | ||||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); | |||||
| rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0)); | |||||
| pos = ParseSpace(pos.Slice(1), true); | pos = ParseSpace(pos.Slice(1), true); | ||||
| pos = ParseSequence(state, pos, ruleName, rule, isNested); | pos = ParseSequence(state, pos, ruleName, rule, isNested); | ||||
| } | } | ||||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); | |||||
| rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); | |||||
| AddRule(state, ruleId, rule); | AddRule(state, ruleId, rule); | ||||
| return pos; | return pos; | ||||
| @@ -370,19 +371,41 @@ namespace LLama.Grammar | |||||
| return ParseSpace(pos, true); | return ParseSpace(pos, true); | ||||
| } | } | ||||
| public ParseState Parse(string input) | |||||
| /// <summary> | |||||
| /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> | |||||
| /// </summary> | |||||
| /// <param name="input">The string to parse</param> | |||||
| /// <param name="startRule">The name of the root rule of this grammar</param> | |||||
| /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception> | |||||
| /// <returns>A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling</returns> | |||||
| public Grammar Parse(string input, string startRule) | |||||
| { | { | ||||
| byte[] byteArray = Encoding.UTF8.GetBytes(input); | |||||
| ReadOnlySpan<byte> src = new ReadOnlySpan<byte>(byteArray); | |||||
| ParseState state = new ParseState(); | |||||
| ReadOnlySpan<byte> pos = ParseSpace(src, true); | |||||
| var byteArray = Encoding.UTF8.GetBytes(input); | |||||
| var state = new ParseState(); | |||||
| var pos = ParseSpace(byteArray, true); | |||||
| while (!pos.IsEmpty) | while (!pos.IsEmpty) | ||||
| { | { | ||||
| pos = ParseRule(state, pos); | pos = ParseRule(state, pos); | ||||
| } | } | ||||
| return state; | |||||
| var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key); | |||||
| var rules = new List<GrammarRule>(); | |||||
| for (var i = 0; i < state.Rules.Count; i++) | |||||
| { | |||||
| var elements = state.Rules[i]; | |||||
| var name = names[(uint)i]; | |||||
| rules.Add(new GrammarRule(name, elements)); | |||||
| } | |||||
| var startRuleIndex = state.SymbolIds[startRule]; | |||||
| return new Grammar(rules, startRuleIndex); | |||||
| } | |||||
| private record ParseState | |||||
| { | |||||
| public SortedDictionary<string, uint> SymbolIds { get; } = new(); | |||||
| public List<List<LLamaGrammarElement>> Rules { get; } = new(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,151 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | |||||
| namespace LLama.Grammars | |||||
| { | |||||
| /// <summary> | |||||
| /// A grammar is a set of <see cref="GrammarRule"/>s for deciding which characters are valid next. Can be used to constrain | |||||
| /// output to certain formats - e.g. force the model to output JSON | |||||
| /// </summary> | |||||
| public sealed class Grammar | |||||
| { | |||||
| /// <summary> | |||||
| /// Index of the initial rule to start from | |||||
| /// </summary> | |||||
| public ulong StartRuleIndex { get; set; } | |||||
| /// <summary> | |||||
| /// The rules which make up this grammar | |||||
| /// </summary> | |||||
| public IReadOnlyList<GrammarRule> Rules { get; } | |||||
| /// <summary> | |||||
| /// Create a new grammar from a set of rules | |||||
| /// </summary> | |||||
| /// <param name="rules">The rules which make up this grammar</param> | |||||
| /// <param name="startRuleIndex">Index of the initial rule to start from</param> | |||||
| /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="startRuleIndex"/> is not less than the number of rules</exception> | |||||
| public Grammar(IReadOnlyList<GrammarRule> rules, ulong startRuleIndex) | |||||
| { | |||||
| if (startRuleIndex >= (uint)rules.Count) | |||||
| throw new ArgumentOutOfRangeException(nameof(startRuleIndex), "startRule must be less than the number of rules"); | |||||
| StartRuleIndex = startRuleIndex; | |||||
| Rules = rules; | |||||
| } | |||||
| /// <summary> | |||||
| /// Create a `SafeLLamaGrammarHandle` instance to use for sampling | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public SafeLLamaGrammarHandle CreateInstance() | |||||
| { | |||||
| return SafeLLamaGrammarHandle.Create(Rules, StartRuleIndex); | |||||
| } | |||||
| /// <summary> | |||||
| /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> into a Grammar | |||||
| /// </summary> | |||||
| /// <param name="gbnf">The string to parse</param> | |||||
| /// <param name="startRule">Name of the start rule of this grammar</param> | |||||
| /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception> | |||||
| /// <returns>A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling</returns> | |||||
| public static Grammar Parse(string gbnf, string startRule) | |||||
| { | |||||
| var parser = new GBNFGrammarParser(); | |||||
| return parser.Parse(gbnf, startRule); | |||||
| } | |||||
| /// <inheritdoc /> | |||||
| public override string ToString() | |||||
| { | |||||
| var builder = new StringBuilder(); | |||||
| PrintGrammar(builder); | |||||
| return builder.ToString(); | |||||
| } | |||||
| private void PrintGrammar(StringBuilder output) | |||||
| { | |||||
| for (var i = 0; i < Rules.Count; i++) | |||||
| PrintRule(output, (uint)i, Rules[i]); | |||||
| } | |||||
/// <summary>
/// Append a single rule to the output as "name ::= elements", followed by a line break
/// </summary>
/// <param name="output">Builder which receives the text</param>
/// <param name="ruleId">Id of the rule; only used in error messages</param>
/// <param name="rule">The rule to print</param>
/// <exception cref="GrammarFormatException">
/// Thrown if an END element is found before the final position, or if a CHAR_RNG_UPPER or
/// CHAR_ALT element is not preceded by a char element
/// </exception>
private void PrintRule(StringBuilder output, uint ruleId, GrammarRule rule)
{
    output.Append($"{rule.Name} ::= ");

    var elements = rule.Elements;

    // The final element is never visited, so the lookahead at index + 1 below is always in range
    var stop = elements.Count - 1;
    for (var index = 0; index < stop; index++)
    {
        var current = elements[index];
        var kind = current.Type;

        if (kind == LLamaGrammarElementType.END)
        {
            throw new GrammarFormatException($"Unexpected end of rule: {ruleId}, {index}");
        }
        else if (kind == LLamaGrammarElementType.ALT)
        {
            output.Append("| ");
        }
        else if (kind == LLamaGrammarElementType.RULE_REF)
        {
            // For a rule reference the value is used as an index into Rules
            output.Append($"{Rules[(int)current.Value].Name} ");
        }
        else if (kind == LLamaGrammarElementType.CHAR)
        {
            output.Append('[');
            PrintGrammarChar(output, current.Value);
        }
        else if (kind == LLamaGrammarElementType.CHAR_NOT)
        {
            output.Append("[^");
            PrintGrammarChar(output, current.Value);
        }
        else if (kind == LLamaGrammarElementType.CHAR_RNG_UPPER)
        {
            if (index == 0 || !elements[index - 1].IsCharElement())
            {
                throw new GrammarFormatException(
                    $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{index}");
            }

            output.Append('-');
            PrintGrammarChar(output, current.Value);
        }
        else if (kind == LLamaGrammarElementType.CHAR_ALT)
        {
            if (index == 0 || !elements[index - 1].IsCharElement())
            {
                throw new GrammarFormatException(
                    $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{index}");
            }

            PrintGrammarChar(output, current.Value);
        }

        // Only char elements can leave a bracket open
        if (!current.IsCharElement())
            continue;

        // Close the bracket unless the next element continues the same character set
        var following = elements[index + 1].Type;
        var setContinues = following == LLamaGrammarElementType.CHAR_ALT
                        || following == LLamaGrammarElementType.CHAR_RNG_UPPER;
        if (!setContinues)
            output.Append("] ");
    }

    output.AppendLine();
}
/// <summary>
/// Append a single code point to the output. Printable ASCII (0x20 to 0x7E) is written
/// as the character itself, everything else as a "&lt;U+XXXX&gt;" placeholder.
/// </summary>
/// <param name="output">Builder which receives the text</param>
/// <param name="c">The code point to print</param>
private static void PrintGrammarChar(StringBuilder output, uint c)
{
    // 0x7F (DEL) is a control character, so the printable range stops at 0x7E ('~')
    if (c >= 0x20 && c < 0x7F)
    {
        output.Append((char)c);
    }
    else
    {
        // cop out of encoding UTF-8
        output.Append($"<U+{c:X4}>");
    }
}
| } | |||||
| } | |||||
| @@ -0,0 +1,75 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using LLama.Native; | |||||
| namespace LLama.Grammars | |||||
| { | |||||
/// <summary>
/// A single rule in a <see cref="Grammar"/>
/// </summary>
/// <remarks>
/// The elements are validated and copied on construction, so changing the list which was passed
/// to the constructor afterwards cannot turn a validated rule into an invalid one.
/// Two rules are equal when they have the same name and the same sequence of elements.
/// </remarks>
public sealed record GrammarRule
{
    /// <summary>
    /// Name of this rule
    /// </summary>
    public string Name { get; }

    /// <summary>
    /// The elements of this grammar rule. The final element is always
    /// <see cref="LLamaGrammarElementType.END"/>.
    /// </summary>
    public IReadOnlyList<LLamaGrammarElement> Elements { get; }

    /// <summary>
    /// Create a new GrammarRule containing the given elements
    /// </summary>
    /// <param name="name">Name of the rule</param>
    /// <param name="elements">
    /// Elements of the rule. Must be non-empty, must end with the only END element, and every
    /// CHAR_RNG_UPPER or CHAR_ALT element must directly follow a char element.
    /// </param>
    /// <exception cref="ArgumentNullException">Thrown if name or elements is null</exception>
    /// <exception cref="ArgumentException">Thrown if the elements do not form a valid rule</exception>
    public GrammarRule(string name, IReadOnlyList<LLamaGrammarElement> elements)
    {
        if (name is null)
            throw new ArgumentNullException(nameof(name));
        if (elements is null)
            throw new ArgumentNullException(nameof(elements));

        // Validate and keep a private snapshot rather than the caller's list, which the
        // caller may still hold a mutable reference to
        var snapshot = new LLamaGrammarElement[elements.Count];
        for (var i = 0; i < snapshot.Length; i++)
            snapshot[i] = elements[i];

        Validate(snapshot, name);

        Name = name;
        Elements = Array.AsReadOnly(snapshot);
    }

    /// <summary>
    /// Check if another rule has the same name and the same sequence of elements
    /// </summary>
    /// <param name="other">The rule to compare with</param>
    /// <returns>True if both rules are equal</returns>
    public bool Equals(GrammarRule? other)
    {
        if (other is null)
            return false;
        if (ReferenceEquals(this, other))
            return true;

        if (!string.Equals(Name, other.Name, StringComparison.Ordinal))
            return false;
        if (Elements.Count != other.Elements.Count)
            return false;

        for (var i = 0; i < Elements.Count; i++)
        {
            if (!Elements[i].Equals(other.Elements[i]))
                return false;
        }

        return true;
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            var hash = StringComparer.Ordinal.GetHashCode(Name);
            hash = hash * 31 + Elements.Count;

            // The trailing END element is left out: END elements compare equal whatever their
            // value is, so it adds nothing that the element count does not already cover
            for (var i = 0; i < Elements.Count - 1; i++)
                hash = hash * 31 + Elements[i].GetHashCode();

            return hash;
        }
    }

    /// <summary>
    /// Throw if the given elements do not form a valid rule
    /// </summary>
    /// <param name="elements">The elements to check</param>
    /// <param name="name">Name of the rule; only used in error messages</param>
    /// <exception cref="ArgumentException">Thrown if the elements do not form a valid rule</exception>
    private static void Validate(IReadOnlyList<LLamaGrammarElement> elements, string name)
    {
        if (elements.Count == 0)
            throw new ArgumentException("Cannot create a GrammarRule with zero elements", nameof(elements));
        if (elements[elements.Count - 1].Type != LLamaGrammarElementType.END)
            throw new ArgumentException("Last grammar element must be END", nameof(elements));

        for (var i = 0; i < elements.Count; i++)
        {
            switch (elements[i].Type)
            {
                case LLamaGrammarElementType.END:
                    // END is only allowed in the final position, which was checked above
                    if (i != elements.Count - 1)
                        throw new ArgumentException("Found more than one END grammar element", nameof(elements));
                    break;

                case LLamaGrammarElementType.CHAR_RNG_UPPER:
                    if (i == 0 || !elements[i - 1].IsCharElement())
                        throw new ArgumentException($"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {name},{i}", nameof(elements));
                    break;

                case LLamaGrammarElementType.CHAR_ALT:
                    if (i == 0 || !elements[i - 1].IsCharElement())
                        throw new ArgumentException($"LLamaGrammarElementType.CHAR_ALT without preceding char: {name},{i}", nameof(elements));
                    break;

                case LLamaGrammarElementType.ALT:
                case LLamaGrammarElementType.RULE_REF:
                case LLamaGrammarElementType.CHAR:
                case LLamaGrammarElementType.CHAR_NOT:
                    // No constraints on the position of these elements
                    break;

                default:
                    throw new ArgumentException($"Unknown grammar element type: '{elements[i].Type}'", nameof(elements));
            }
        }
    }
}
| } | |||||
| @@ -1,4 +1,5 @@ | |||||
| using System.Diagnostics; | |||||
| using System; | |||||
| using System.Diagnostics; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| namespace LLama.Native | namespace LLama.Native | ||||
| @@ -51,17 +52,18 @@ namespace LLama.Native | |||||
| /// </summary> | /// </summary> | ||||
| [StructLayout(LayoutKind.Sequential)] | [StructLayout(LayoutKind.Sequential)] | ||||
| [DebuggerDisplay("{Type} {Value}")] | [DebuggerDisplay("{Type} {Value}")] | ||||
| public struct LLamaGrammarElement | |||||
| public readonly struct LLamaGrammarElement | |||||
| : IEquatable<LLamaGrammarElement> | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// The type of this element | /// The type of this element | ||||
| /// </summary> | /// </summary> | ||||
| public LLamaGrammarElementType Type; | |||||
| public readonly LLamaGrammarElementType Type; | |||||
| /// <summary> | /// <summary> | ||||
| /// Unicode code point or rule ID | /// Unicode code point or rule ID | ||||
| /// </summary> | /// </summary> | ||||
| public uint Value; | |||||
| public readonly uint Value; | |||||
| /// <summary> | /// <summary> | ||||
| /// Construct a new LLamaGrammarElement | /// Construct a new LLamaGrammarElement | ||||
| @@ -73,5 +75,50 @@ namespace LLama.Native | |||||
| Type = type; | Type = type; | ||||
| Value = value; | Value = value; | ||||
| } | } | ||||
/// <inheritdoc />
public bool Equals(LLamaGrammarElement other)
{
    // Value is not compared for END elements, only the type has to match
    var valueMatches = Type == LLamaGrammarElementType.END || Value == other.Value;
    return Type == other.Type && valueMatches;
}
/// <inheritdoc />
public override bool Equals(object? obj)
{
    // Anything which is not a LLamaGrammarElement (including null) is never equal
    if (obj is LLamaGrammarElement element)
        return Equals(element);

    return false;
}
/// <inheritdoc />
public override int GetHashCode()
{
    unchecked
    {
        var hash = 2999;
        hash = hash * 7723 + (int)Type;

        // Equals ignores Value for END elements, so Value must not feed the hash for them
        // either; otherwise two equal END elements could produce different hash codes
        if (Type != LLamaGrammarElementType.END)
            hash = hash * 7723 + (int)Value;

        return hash;
    }
}
/// <summary>
/// Check if this element is one of the character element types
/// (CHAR, CHAR_NOT, CHAR_ALT or CHAR_RNG_UPPER)
/// </summary>
/// <returns>True for the four character element types, false for everything else</returns>
internal bool IsCharElement()
{
    return Type == LLamaGrammarElementType.CHAR
        || Type == LLamaGrammarElementType.CHAR_NOT
        || Type == LLamaGrammarElementType.CHAR_ALT
        || Type == LLamaGrammarElementType.CHAR_RNG_UPPER;
}
| } | } | ||||
| } | } | ||||
| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| using LLama.Grammars; | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| @@ -38,11 +39,11 @@ namespace LLama.Native | |||||
| /// <param name="start_rule_index">The index (in the outer list) of the start rule</param> | /// <param name="start_rule_index">The index (in the outer list) of the start rule</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index) | |||||
| public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index) | |||||
| { | { | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| var totalElements = rules.Sum(a => a.Count); | |||||
| var totalElements = rules.Sum(a => a.Elements.Count); | |||||
| var nrules = (ulong)rules.Count; | var nrules = (ulong)rules.Count; | ||||
| // Borrow an array large enough to hold every single element | // Borrow an array large enough to hold every single element | ||||
| @@ -61,7 +62,7 @@ namespace LLama.Native | |||||
| pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex); | pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex); | ||||
| // Copy all of the rule elements into the flat array | // Copy all of the rule elements into the flat array | ||||
| foreach (var element in rule) | |||||
| foreach (var element in rule.Elements) | |||||
| allElementsPtr[elementIndex++] = element; | allElementsPtr[elementIndex++] = element; | ||||
| } | } | ||||