Browse Source

- Created a higher level `Grammar` class which is immutable and contains a list of grammar rules. This is the main "entry point" to the grammar system.

- Made all the mechanics of grammar parsing (GBNFGrammarParser, ParseState) internal. Just call `Grammar.Parse("whatever")`.
 - Added a `GrammarRule` class which validates elements on construction (this allows constructing grammar without parsing GBNF).
   - It should be impossible for a `GrammarRule` to represent an invalid rule.
tags/v0.5.1
Martin Evans 2 years ago
parent
commit
a70c7170dd
12 changed files with 385 additions and 278 deletions
  1. +5
    -7
      LLama.Examples/NewVersion/GrammarJsonResponse.cs
  2. +22
    -55
      LLama.Unittest/GrammarParserTest.cs
  3. +7
    -6
      LLama.Unittest/GrammarTest.cs
  4. +3
    -0
      LLama/AssemblyAttributes.cs
  5. +20
    -0
      LLama/Extensions/IReadOnlyListExtensions.cs
  6. +2
    -0
      LLama/GlobalSuppressions.cs
  7. +0
    -181
      LLama/Grammar/ParseState.cs
  8. +45
    -22
      LLama/Grammars/GBNFGrammarParser.cs
  9. +151
    -0
      LLama/Grammars/Grammar.cs
  10. +75
    -0
      LLama/Grammars/GrammarRule.cs
  11. +51
    -4
      LLama/Native/LLamaGrammarElement.cs
  12. +4
    -3
      LLama/Native/SafeLLamaGrammarHandle.cs

+ 5
- 7
LLama.Examples/NewVersion/GrammarJsonResponse.cs View File

@@ -1,6 +1,5 @@
using LLama.Common; using LLama.Common;
using LLama.Grammar;
using LLama.Native;
using LLama.Grammars;


namespace LLama.Examples.NewVersion namespace LLama.Examples.NewVersion
{ {
@@ -8,8 +7,8 @@ namespace LLama.Examples.NewVersion
{ {
public static void Run() public static void Run()
{ {
var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim();
var parsedGrammar = new GrammarParser();
var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
var grammar = Grammar.Parse(gbnf, "root");


Console.Write("Please input your model path: "); Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine(); var modelPath = Console.ReadLine();
@@ -22,19 +21,18 @@ namespace LLama.Examples.NewVersion
}; };
using var model = LLamaWeights.LoadFromFile(parameters); using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters); var ex = new StatelessExecutor(model, parameters);
ParseState state = parsedGrammar.Parse(grammarBytes);
using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0);


Console.ForegroundColor = ConsoleColor.Yellow; Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\""); Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\"");
Console.ForegroundColor = ConsoleColor.White; Console.ForegroundColor = ConsoleColor.White;


using var grammarInstance = grammar.CreateInstance();
var inferenceParams = new InferenceParams() var inferenceParams = new InferenceParams()
{ {
Temperature = 0.6f, Temperature = 0.6f,
AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" }, AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" },
MaxTokens = 50, MaxTokens = 50,
Grammar = grammar
Grammar = grammarInstance
}; };


while (true) while (true)


+ 22
- 55
LLama.Unittest/GrammarParserTest.cs View File

@@ -1,8 +1,5 @@
using LLama.Common;
using LLama.Native;
using System.Diagnostics;
using LLama.Grammar;
using Newtonsoft.Json.Linq;
using LLama.Native;
using LLama.Grammars;


namespace LLama.Unittest namespace LLama.Unittest
{ {
@@ -17,14 +14,15 @@ namespace LLama.Unittest
[Fact] [Fact]
public void ParseComplexGrammar() public void ParseComplexGrammar()
{ {
GrammarParser parsedGrammar = new GrammarParser();
GBNFGrammarParser parsedGrammar = new GBNFGrammarParser();
string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ string grammarBytes = @"root ::= (expr ""="" term ""\n"")+
expr ::= term ([-+*/] term)* expr ::= term ([-+*/] term)*
term ::= [0-9]+"; term ::= [0-9]+";


ParseState state = parsedGrammar.Parse(grammarBytes);
var state = parsedGrammar.Parse(grammarBytes, "root");
Assert.Equal(0ul, state.StartRuleIndex);


List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
var expected = new List<KeyValuePair<string, uint>>
{ {
new KeyValuePair<string, uint>("expr", 2), new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_5", 5), new KeyValuePair<string, uint>("expr_5", 5),
@@ -36,27 +34,11 @@ namespace LLama.Unittest
new KeyValuePair<string, uint>("term_7", 7), new KeyValuePair<string, uint>("term_7", 7),
}; };


uint index = 0;
foreach (var it in state.SymbolIds)
foreach (var symbol in expected)
{ {
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
var rule = state.Rules[(int)symbol.Value];
Assert.Equal(symbol.Key, rule.Name);
} }
Assert.NotEmpty(state.SymbolIds);



var expectedRules = new List<LLamaGrammarElement> var expectedRules = new List<LLamaGrammarElement>
{ {
@@ -96,13 +78,13 @@ namespace LLama.Unittest
new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
}; };


index = 0;
uint index = 0;
foreach (var rule in state.Rules) foreach (var rule in state.Rules)
{ {
// compare rule to expected rule // compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
for (uint i = 0; i < rule.Elements.Count; i++)
{ {
var element = rule[(int)i];
var element = rule.Elements[(int)i];
var expectedElement = expectedRules[(int)index]; var expectedElement = expectedRules[(int)index];


// Pretty print error message before asserting // Pretty print error message before asserting
@@ -124,7 +106,7 @@ namespace LLama.Unittest
[Fact] [Fact]
public void ParseExtraComplexGrammar() public void ParseExtraComplexGrammar()
{ {
GrammarParser parsedGrammar = new GrammarParser();
GBNFGrammarParser parsedGrammar = new GBNFGrammarParser();
string grammarBytes = @" string grammarBytes = @"
root ::= (expr ""="" ws term ""\n"")+ root ::= (expr ""="" ws term ""\n"")+
expr ::= term ([-+*/] term)* expr ::= term ([-+*/] term)*
@@ -134,9 +116,10 @@ namespace LLama.Unittest
ws ::= [ \t\n]* ws ::= [ \t\n]*
"; ";


ParseState state = parsedGrammar.Parse(grammarBytes);
var state = parsedGrammar.Parse(grammarBytes, "root");
Assert.Equal(0ul, state.StartRuleIndex);


List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
var expected = new List<KeyValuePair<string, uint>>
{ {
new KeyValuePair<string, uint>("expr", 2), new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_6", 6), new KeyValuePair<string, uint>("expr_6", 6),
@@ -153,27 +136,11 @@ namespace LLama.Unittest
new KeyValuePair<string, uint>("ws_12", 12), new KeyValuePair<string, uint>("ws_12", 12),
}; };


uint index = 0;
foreach (var it in state.SymbolIds)
foreach (var symbol in expected)
{ {
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
var rule = state.Rules[(int)symbol.Value];
Assert.Equal(symbol.Key, rule.Name);
} }
Assert.NotEmpty(state.SymbolIds);



var expectedRules = new List<LLamaGrammarElement> var expectedRules = new List<LLamaGrammarElement>
{ {
@@ -246,13 +213,13 @@ namespace LLama.Unittest
new LLamaGrammarElement(LLamaGrammarElementType.END, 0) new LLamaGrammarElement(LLamaGrammarElementType.END, 0)
}; };


index = 0;
uint index = 0;
foreach (var rule in state.Rules) foreach (var rule in state.Rules)
{ {
// compare rule to expected rule // compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
for (uint i = 0; i < rule.Elements.Count; i++)
{ {
var element = rule[(int)i];
var element = rule.Elements[(int)i];
var expectedElement = expectedRules[(int)index]; var expectedElement = expectedRules[(int)index];


// Pretty print error message before asserting // Pretty print error message before asserting


+ 7
- 6
LLama.Unittest/GrammarTest.cs View File

@@ -1,4 +1,5 @@
using LLama.Common; using LLama.Common;
using LLama.Grammars;
using LLama.Native; using LLama.Native;


namespace LLama.Unittest namespace LLama.Unittest
@@ -26,14 +27,14 @@ namespace LLama.Unittest
[Fact] [Fact]
public void CreateBasicGrammar() public void CreateBasicGrammar()
{ {
var rules = new List<List<LLamaGrammarElement>>
var rules = new List<GrammarRule>
{ {
new()
new GrammarRule("alpha", new[]
{ {
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
},
}),
}; };


using var handle = SafeLLamaGrammarHandle.Create(rules, 0); using var handle = SafeLLamaGrammarHandle.Create(rules, 0);
@@ -44,15 +45,15 @@ namespace LLama.Unittest
{ {
// Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
// we can be confident it's not what the LLM would say if not constrained by the grammar! // we can be confident it's not what the LLM would say if not constrained by the grammar!
var rules = new List<List<LLamaGrammarElement>>
var rules = new List<GrammarRule>
{ {
new()
new GrammarRule("feline", new []
{ {
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
},
}),
}; };


using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); using var grammar = SafeLLamaGrammarHandle.Create(rules, 0);


+ 3
- 0
LLama/AssemblyAttributes.cs View File

@@ -0,0 +1,3 @@
using System.Runtime.CompilerServices;

[assembly: InternalsVisibleTo("LLama.Unittest")]

+ 20
- 0
LLama/Extensions/IReadOnlyListExtensions.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;

namespace LLama.Extensions
{
internal static class IReadOnlyListExtensions
{
    /// <summary>
    /// Find the index of the first element of a list which is equal to the given item
    /// </summary>
    /// <typeparam name="T">Type of the elements in the list</typeparam>
    /// <param name="list">The list to search, scanned from index 0 upwards</param>
    /// <param name="item">The item to search for</param>
    /// <returns>The index of the first matching element, or null if no element matches</returns>
    public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
        where T : IEquatable<T>
    {
        // The default comparer uses IEquatable<T>.Equals for non-null values, but unlike a direct
        // `list[i].Equals(item)` call it does not throw when the list contains a null element.
        var comparer = EqualityComparer<T>.Default;

        for (var i = 0; i < list.Count; i++)
        {
            if (comparer.Equals(list[i], item))
                return i;
        }

        return null;
    }
}
}

+ 2
- 0
LLama/GlobalSuppressions.cs View File

@@ -6,3 +6,5 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;


[assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")] [assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")]

[assembly: SuppressMessage("Style", "IDE0070:Use 'System.HashCode'", Justification = "Not compatible with netstandard2.0")]

+ 0
- 181
LLama/Grammar/ParseState.cs View File

@@ -1,181 +0,0 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;

namespace LLama.Grammar
{
/// <summary>
/// Source:
/// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h
///
/// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary>
public class ParseState
{
/// <summary>
/// Maps each symbol name to its numeric id. The printing methods below invert this map and
/// look names up by rule index, so an id is used as an index into <see cref="Rules"/>.
/// </summary>
public SortedDictionary<string, uint> SymbolIds { get; } = new SortedDictionary<string, uint>();
/// <summary>
/// The rules of the grammar, each rule being a list of grammar elements
/// </summary>
public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>();

/// <summary>
/// Lazily enumerate every rule in <see cref="Rules"/>, in list order
/// </summary>
/// <returns>Each rule in turn</returns>
public IEnumerable<List<LLamaGrammarElement>> CRules()
{
foreach (var rule in Rules)
{
yield return rule;
}
}

/// <summary>
/// Write every rule of the given state to a file in a "name ::= ..." text form.
/// Note that this prints <paramref name="state"/>, not the instance the method is called on.
/// </summary>
/// <param name="file">Writer which receives the output</param>
/// <param name="state">The parse state to print</param>
public void PrintGrammar(StreamWriter file, ParseState state)
{
try
{
// Invert the name -> id map so that rules can be named while printing
Dictionary<uint, string> symbolIdNames = new Dictionary<uint, string>();
foreach (var kv in state.SymbolIds)
{
symbolIdNames[kv.Value] = kv.Key;
}
for (int i = 0, end = state.Rules.Count; i < end; i++)
{
PrintRule(file, (uint)i, state.Rules[i], symbolIdNames);
}
}
catch(Exception err)
{
// Any failure is reported on stderr instead of being propagated to the caller
Console.Error.WriteLine($"\nError printing grammar: {err.Message}");
}
}

/// <summary>
/// Write a raw dump of a single rule: each element is written as its type name followed by
/// its value in parentheses, and the rule is terminated with a newline.
/// </summary>
/// <param name="file">Writer which receives the output</param>
/// <param name="rule">The rule to dump</param>
public void PrintRuleBinary(StreamWriter file, List<LLamaGrammarElement> rule)
{
foreach (var elem in rule)
{
// First the name of the element type...
switch (elem.Type)
{
case LLamaGrammarElementType.END: file.Write("END"); break;
case LLamaGrammarElementType.ALT: file.Write("ALT"); break;
case LLamaGrammarElementType.RULE_REF: file.Write("RULE_REF"); break;
case LLamaGrammarElementType.CHAR: file.Write("CHAR"); break;
case LLamaGrammarElementType.CHAR_NOT: file.Write("CHAR_NOT"); break;
case LLamaGrammarElementType.CHAR_RNG_UPPER: file.Write("CHAR_RNG_UPPER"); break;
case LLamaGrammarElementType.CHAR_ALT: file.Write("CHAR_ALT"); break;
}
// ...then the value: as a number for structural elements, as a quoted character for char elements
switch (elem.Type)
{
case LLamaGrammarElementType.END:
case LLamaGrammarElementType.ALT:
case LLamaGrammarElementType.RULE_REF:
file.Write($"({elem.Value}) ");
break;
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
case LLamaGrammarElementType.CHAR_ALT:
file.Write("(\"");
PrintGrammarChar(file, elem.Value);
file.Write("\") ");
break;
}
}
file.WriteLine();
}

// Write one rule as "name ::= elements". Throws GrammarFormatException if the rule is empty,
// does not end with END, contains END before the last element, or has a CHAR_RNG_UPPER or
// CHAR_ALT element which does not directly follow another char element.
private void PrintRule(
StreamWriter file,
uint ruleId,
List<LLamaGrammarElement> rule,
Dictionary<uint, string> symbolIdNames)
{
if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END)
{
throw new GrammarFormatException(
$"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}");
}

file.Write($"{symbolIdNames[ruleId]} ::= ");

// The final element is the END marker checked above, so the loop stops one short of it
for (int i = 0, end = rule.Count - 1; i < end; i++)
{
var elem = rule[i];
switch (elem.Type)
{
case LLamaGrammarElementType.END:
throw new GrammarFormatException(
$"Unexpected end of rule: {ruleId}, {i}");
case LLamaGrammarElementType.ALT:
file.Write("| ");
break;
case LLamaGrammarElementType.RULE_REF:
file.Write($"{symbolIdNames[elem.Value]} ");
break;
case LLamaGrammarElementType.CHAR:
file.Write("[");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_NOT:
file.Write("[^");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_RNG_UPPER:
if (i == 0 || !IsCharElement(rule[i - 1]))
{
throw new GrammarFormatException(
$"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}");
}
file.Write("-");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_ALT:
if (i == 0 || !IsCharElement(rule[i - 1]))
{
throw new GrammarFormatException(
$"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}");
}
PrintGrammarChar(file, elem.Value);
break;

}

// Close the bracket opened by CHAR/CHAR_NOT, unless the next element continues the same character set
if (IsCharElement(elem))
{
switch (rule[i + 1].Type)
{
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
break;
default:
file.Write("] ");
break;
}
}
}
file.WriteLine();
}

// Write a single code point: values from 0x20 to 0x7F inclusive are written as that character,
// anything else is written in the form <U+XXXX>
private void PrintGrammarChar(StreamWriter file, uint c)
{
if (c >= 0x20 && c <= 0x7F)
{
file.Write((char)c);
}
else
{
// cop out of encoding UTF-8
file.Write($"<U+{c:X4}>");
}
}

// True for the four character element types (CHAR, CHAR_NOT, CHAR_ALT, CHAR_RNG_UPPER)
private bool IsCharElement(LLamaGrammarElement elem)
{
switch (elem.Type)
{
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
return true;
default:
return false;
}
}
}
}

LLama/Grammar/GrammarParser.cs → LLama/Grammars/GBNFGrammarParser.cs View File

@@ -1,10 +1,11 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using LLama.Exceptions;
using LLama.Native;


namespace LLama.Grammar
namespace LLama.Grammars
{ {
/// <summary> /// <summary>
/// Source: /// Source:
@@ -12,7 +13,7 @@ namespace LLama.Grammar
/// ///
/// The commit hash from URL is the actual commit hash that reflects current C# code. /// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary> /// </summary>
public class GrammarParser
internal sealed class GBNFGrammarParser
{ {
// NOTE: assumes valid utf8 (but checks for overrun) // NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp // copied from llama.cpp
@@ -206,7 +207,7 @@ namespace LLama.Grammar
while (!pos.IsEmpty && pos[0] != '"') while (!pos.IsEmpty && pos[0] != '"')
{ {
var charPair = ParseChar(ref pos); var charPair = ParseChar(ref pos);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair));
} }
pos = ParseSpace(pos.Slice(1), isNested); pos = ParseSpace(pos.Slice(1), isNested);
} }
@@ -228,13 +229,13 @@ namespace LLama.Grammar
var charPair = ParseChar(ref pos); var charPair = ParseChar(ref pos);
var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;


outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair });
outElements.Add(new LLamaGrammarElement(type, charPair));


if (pos[0] == '-' && pos[1] != ']') if (pos[0] == '-' && pos[1] != ']')
{ {
pos = pos.Slice(1); pos = pos.Slice(1);
var endCharPair = ParseChar(ref pos); var endCharPair = ParseChar(ref pos);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair));
} }
} }
pos = ParseSpace(pos.Slice(1), isNested); pos = ParseSpace(pos.Slice(1), isNested);
@@ -245,7 +246,7 @@ namespace LLama.Grammar
uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
pos = ParseSpace(nameEnd, isNested); pos = ParseSpace(nameEnd, isNested);
lastSymStart = outElements.Count; lastSymStart = outElements.Count;
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
} }
else if (pos[0] == '(') // grouping else if (pos[0] == '(') // grouping
{ {
@@ -255,7 +256,7 @@ namespace LLama.Grammar
pos = ParseAlternates(state, pos, ruleName, subRuleId, true); pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
lastSymStart = outElements.Count; lastSymStart = outElements.Count;
// output reference to synthesized rule // output reference to synthesized rule
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
if (pos[0] != ')') if (pos[0] != ')')
{ {
throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}"); throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
@@ -284,11 +285,11 @@ namespace LLama.Grammar
if (pos[0] == '*' || pos[0] == '+') if (pos[0] == '*' || pos[0] == '+')
{ {
// cause generated rule to recurse // cause generated rule to recurse
subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
} }


// mark start of alternate def // mark start of alternate def
subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));


if (pos[0] == '+') if (pos[0] == '+')
{ {
@@ -296,13 +297,13 @@ namespace LLama.Grammar
subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart));
} }


subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));


AddRule(state, subRuleId, subRule); AddRule(state, subRuleId, subRule);


// in original rule, replace previous symbol with reference to generated rule // in original rule, replace previous symbol with reference to generated rule
outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));


pos = ParseSpace(pos.Slice(1), isNested); pos = ParseSpace(pos.Slice(1), isNested);


@@ -328,12 +329,12 @@ namespace LLama.Grammar


while (!pos.IsEmpty && pos[0] == '|') while (!pos.IsEmpty && pos[0] == '|')
{ {
rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
pos = ParseSpace(pos.Slice(1), true); pos = ParseSpace(pos.Slice(1), true);
pos = ParseSequence(state, pos, ruleName, rule, isNested); pos = ParseSequence(state, pos, ruleName, rule, isNested);
} }


rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });
rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
AddRule(state, ruleId, rule); AddRule(state, ruleId, rule);


return pos; return pos;
@@ -370,19 +371,41 @@ namespace LLama.Grammar
return ParseSpace(pos, true); return ParseSpace(pos, true);
} }


public ParseState Parse(string input)
/// <summary>
/// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a>
/// </summary>
/// <param name="input">The string to parse</param>
/// <param name="startRule">The name of the root rule of this grammar</param>
/// <exception cref="GrammarFormatException">Thrown if input is malformed</exception>
/// <returns>A ParseState that can be converted into a grammar for sampling</returns>
public Grammar Parse(string input, string startRule)
{ {
byte[] byteArray = Encoding.UTF8.GetBytes(input);
ReadOnlySpan<byte> src = new ReadOnlySpan<byte>(byteArray);
ParseState state = new ParseState();
ReadOnlySpan<byte> pos = ParseSpace(src, true);
var byteArray = Encoding.UTF8.GetBytes(input);
var state = new ParseState();
var pos = ParseSpace(byteArray, true);


while (!pos.IsEmpty) while (!pos.IsEmpty)
{ {
pos = ParseRule(state, pos); pos = ParseRule(state, pos);
} }


return state;
var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key);
var rules = new List<GrammarRule>();
for (var i = 0; i < state.Rules.Count; i++)
{
var elements = state.Rules[i];
var name = names[(uint)i];
rules.Add(new GrammarRule(name, elements));
}

var startRuleIndex = state.SymbolIds[startRule];
return new Grammar(rules, startRuleIndex);
}

private record ParseState
{
public SortedDictionary<string, uint> SymbolIds { get; } = new();
public List<List<LLamaGrammarElement>> Rules { get; } = new();
} }
} }
} }

+ 151
- 0
LLama/Grammars/Grammar.cs View File

@@ -0,0 +1,151 @@
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using LLama.Native;

namespace LLama.Grammars
{
/// <summary>
/// A grammar is a set of <see cref="GrammarRule"/>s for deciding which characters are valid next. Can be used to constrain
/// output to certain formats - e.g. force the model to output JSON
/// </summary>
public sealed class Grammar
{
    // Backing field for StartRuleIndex. Only ever assigned a value which is a valid index into Rules.
    private ulong _startRuleIndex;

    /// <summary>
    /// Index of the initial rule to start from
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if set to a value which is not a valid index into <see cref="Rules"/></exception>
    public ulong StartRuleIndex
    {
        get => _startRuleIndex;
        set
        {
            // Apply the same check as the constructor, so that an invalid index can never be
            // handed to SafeLLamaGrammarHandle.Create by CreateInstance.
            if (value >= (ulong)Rules.Count)
                throw new ArgumentOutOfRangeException(nameof(value), "startRule must be less than the number of rules");

            _startRuleIndex = value;
        }
    }

    /// <summary>
    /// The rules which make up this grammar
    /// </summary>
    public IReadOnlyList<GrammarRule> Rules { get; }

    /// <summary>
    /// Create a new grammar from a set of rules
    /// </summary>
    /// <param name="rules">The rules which make up this grammar</param>
    /// <param name="startRuleIndex">Index of the initial rule to start from</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="rules"/> is null</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="startRuleIndex"/> is not a valid index into <paramref name="rules"/></exception>
    public Grammar(IReadOnlyList<GrammarRule> rules, ulong startRuleIndex)
    {
        if (rules is null)
            throw new ArgumentNullException(nameof(rules));
        if (startRuleIndex >= (ulong)rules.Count)
            throw new ArgumentOutOfRangeException(nameof(startRuleIndex), "startRule must be less than the number of rules");

        Rules = rules;
        _startRuleIndex = startRuleIndex;
    }

    /// <summary>
    /// Create a `SafeLLamaGrammarHandle` instance to use for parsing
    /// </summary>
    /// <returns>A new handle created from the rules and start rule index of this grammar</returns>
    public SafeLLamaGrammarHandle CreateInstance()
    {
        return SafeLLamaGrammarHandle.Create(Rules, StartRuleIndex);
    }

    /// <summary>
    /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> into a Grammar
    /// </summary>
    /// <param name="gbnf">The string to parse</param>
    /// <param name="startRule">Name of the start rule of this grammar</param>
    /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception>
    /// <returns>A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling</returns>
    public static Grammar Parse(string gbnf, string startRule)
    {
        var parser = new GBNFGrammarParser();
        return parser.Parse(gbnf, startRule);
    }

    /// <summary>
    /// Print every rule of this grammar, one per line, in the form "name ::= elements"
    /// </summary>
    /// <exception cref="GrammarFormatException">Thrown if a rule contains a sequence of elements which cannot be printed</exception>
    public override string ToString()
    {
        var builder = new StringBuilder();
        PrintGrammar(builder);
        return builder.ToString();
    }

    // Append every rule to the output, in index order
    private void PrintGrammar(StringBuilder output)
    {
        for (var i = 0; i < Rules.Count; i++)
            PrintRule(output, (uint)i, Rules[i]);
    }

    // Append a single rule to the output, followed by a newline. Throws GrammarFormatException
    // if the rule contains something which cannot be printed.
    private void PrintRule(StringBuilder output, uint ruleId, GrammarRule rule)
    {
        output.Append($"{rule.Name} ::= ");

        // Stop one short of the end: the loop below looks ahead to element i + 1, and an END
        // element anywhere inside the loop range is reported as an error.
        for (int i = 0, end = rule.Elements.Count - 1; i < end; i++)
        {
            var elem = rule.Elements[i];
            switch (elem.Type)
            {
                case LLamaGrammarElementType.END:
                    throw new GrammarFormatException($"Unexpected end of rule: {ruleId}, {i}");
                case LLamaGrammarElementType.ALT:
                    output.Append("| ");
                    break;
                case LLamaGrammarElementType.RULE_REF:
                    // A reference is only an index, nothing guarantees it points at a rule of this
                    // grammar. Report a bad reference like any other malformed rule.
                    if (elem.Value >= (uint)Rules.Count)
                    {
                        throw new GrammarFormatException(
                            $"LLamaGrammarElementType.RULE_REF to unknown rule {elem.Value}: {ruleId},{i}");
                    }
                    output.Append($"{Rules[(int)elem.Value].Name} ");
                    break;
                case LLamaGrammarElementType.CHAR:
                    output.Append('[');
                    PrintGrammarChar(output, elem.Value);
                    break;
                case LLamaGrammarElementType.CHAR_NOT:
                    output.Append("[^");
                    PrintGrammarChar(output, elem.Value);
                    break;
                case LLamaGrammarElementType.CHAR_RNG_UPPER:
                    if (i == 0 || !rule.Elements[i - 1].IsCharElement())
                    {
                        throw new GrammarFormatException(
                            $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}");
                    }
                    output.Append('-');
                    PrintGrammarChar(output, elem.Value);
                    break;
                case LLamaGrammarElementType.CHAR_ALT:
                    if (i == 0 || !rule.Elements[i - 1].IsCharElement())
                    {
                        throw new GrammarFormatException(
                            $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}");
                    }
                    PrintGrammarChar(output, elem.Value);
                    break;
            }

            // Close the bracket opened by CHAR/CHAR_NOT, unless the next element continues the same character set
            if (elem.IsCharElement())
            {
                switch (rule.Elements[i + 1].Type)
                {
                    case LLamaGrammarElementType.CHAR_ALT:
                    case LLamaGrammarElementType.CHAR_RNG_UPPER:
                        break;
                    default:
                        output.Append("] ");
                        break;
                }
            }
        }

        output.AppendLine();
    }

    // Append a single code point: values from 0x20 to 0x7F inclusive are appended as that
    // character, anything else is appended in the form <U+XXXX>
    private static void PrintGrammarChar(StringBuilder output, uint c)
    {
        if (c >= 0x20 && c <= 0x7F)
        {
            output.Append((char)c);
        }
        else
        {
            // cop out of encoding UTF-8
            output.Append($"<U+{c:X4}>");
        }
    }
}
}

+ 75
- 0
LLama/Grammars/GrammarRule.cs View File

@@ -0,0 +1,75 @@
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Grammars
{
/// <summary>
/// A single rule in a <see cref="Grammar"/>
/// </summary>
public sealed record GrammarRule
{
    /// <summary>
    /// Name of this rule
    /// </summary>
    public string Name { get; }

    /// <summary>
    /// The elements of this grammar rule
    /// </summary>
    public IReadOnlyList<LLamaGrammarElement> Elements { get; }

    /// <summary>
    /// Create a new GrammarRule containing the given elements
    /// </summary>
    /// <param name="name">Name of this rule</param>
    /// <param name="elements">
    /// The elements of this rule. Must be non-empty and end with exactly one END element. The list
    /// is validated once here and then stored without being copied, so it must not be modified afterwards.
    /// </param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="name"/> or <paramref name="elements"/> is null</exception>
    /// <exception cref="ArgumentException">Thrown if the elements do not form a valid rule</exception>
    public GrammarRule(string name, IReadOnlyList<LLamaGrammarElement> elements)
    {
        if (name is null)
            throw new ArgumentNullException(nameof(name));
        if (elements is null)
            throw new ArgumentNullException(nameof(elements));

        Validate(elements, name);

        Name = name;
        Elements = elements;
    }

    // Check that the elements form a well formed rule, throwing an ArgumentException if they do not
    private static void Validate(IReadOnlyList<LLamaGrammarElement> elements, string name)
    {
        if (elements.Count == 0)
            throw new ArgumentException("Cannot create a GrammarRule with zero elements", nameof(elements));
        if (elements[elements.Count - 1].Type != LLamaGrammarElementType.END)
            throw new ArgumentException("Last grammar element must be END", nameof(elements));

        for (var i = 0; i < elements.Count; i++)
        {
            switch (elements[i].Type)
            {
                // The last element is already known to be END, any other END is an extra one
                case LLamaGrammarElementType.END:
                    if (i != elements.Count - 1)
                        throw new ArgumentException("Found more than one END grammar element", nameof(elements));
                    continue;

                // Range ends and alternative characters extend a character set, so they must directly follow a char element
                case LLamaGrammarElementType.CHAR_RNG_UPPER:
                    if (i == 0 || !elements[i - 1].IsCharElement())
                        throw new ArgumentException($"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {name},{i}", nameof(elements));
                    break;
                case LLamaGrammarElementType.CHAR_ALT:
                    if (i == 0 || !elements[i - 1].IsCharElement())
                        throw new ArgumentException($"LLamaGrammarElementType.CHAR_ALT without preceding char: {name},{i}", nameof(elements));
                    break;

                // These types have no positional constraints
                case LLamaGrammarElementType.ALT:
                case LLamaGrammarElementType.RULE_REF:
                case LLamaGrammarElementType.CHAR:
                case LLamaGrammarElementType.CHAR_NOT:
                    break;

                default:
                    throw new ArgumentException($"Unknown grammar element type: '{elements[i].Type}'", nameof(elements));
            }
        }
    }
}
}

+ 51
- 4
LLama/Native/LLamaGrammarElement.cs View File

@@ -1,4 +1,5 @@
using System.Diagnostics;
using System;
using System.Diagnostics;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;


namespace LLama.Native namespace LLama.Native
@@ -51,17 +52,18 @@ namespace LLama.Native
/// </summary> /// </summary>
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("{Type} {Value}")] [DebuggerDisplay("{Type} {Value}")]
public struct LLamaGrammarElement
public readonly struct LLamaGrammarElement
: IEquatable<LLamaGrammarElement>
{ {
/// <summary> /// <summary>
/// The type of this element /// The type of this element
/// </summary> /// </summary>
public LLamaGrammarElementType Type;
public readonly LLamaGrammarElementType Type;


/// <summary> /// <summary>
/// Unicode code point or rule ID /// Unicode code point or rule ID
/// </summary> /// </summary>
public uint Value;
public readonly uint Value;


/// <summary> /// <summary>
/// Construct a new LLamaGrammarElement /// Construct a new LLamaGrammarElement
@@ -73,5 +75,50 @@ namespace LLama.Native
Type = type; Type = type;
Value = value; Value = value;
} }

/// <inheritdoc />
public bool Equals(LLamaGrammarElement other)
{
    // Elements of different types are never equal
    if (Type != other.Type)
        return false;

    // Values are not compared for END elements, every other type must also match on value
    return Type == LLamaGrammarElementType.END || Value == other.Value;
}

/// <inheritdoc />
public override bool Equals(object? obj)
{
    // Anything which is not a grammar element (including null) is never equal to this one
    if (obj is LLamaGrammarElement other)
        return Equals(other);

    return false;
}

/// <inheritdoc />
public override int GetHashCode()
{
    unchecked
    {
        var hash = 2999;
        hash = hash * 7723 + (int)Type;

        // Equals ignores Value for END elements, so the hash has to ignore it as well. Otherwise
        // two END elements which compare equal could produce different hash codes.
        if (Type != LLamaGrammarElementType.END)
            hash = hash * 7723 + (int)Value;

        return hash;
    }
}

// True if this element is one of the four character element types
internal bool IsCharElement()
{
    return Type == LLamaGrammarElementType.CHAR
        || Type == LLamaGrammarElementType.CHAR_NOT
        || Type == LLamaGrammarElementType.CHAR_ALT
        || Type == LLamaGrammarElementType.CHAR_RNG_UPPER;
}
} }
} }

+ 4
- 3
LLama/Native/SafeLLamaGrammarHandle.cs View File

@@ -4,6 +4,7 @@ using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using LLama.Exceptions; using LLama.Exceptions;
using LLama.Grammars;


namespace LLama.Native namespace LLama.Native
{ {
@@ -38,11 +39,11 @@ namespace LLama.Native
/// <param name="start_rule_index">The index (in the outer list) of the start rule</param> /// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
/// <returns></returns> /// <returns></returns>
/// <exception cref="RuntimeError"></exception> /// <exception cref="RuntimeError"></exception>
public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
{ {
unsafe unsafe
{ {
var totalElements = rules.Sum(a => a.Count);
var totalElements = rules.Sum(a => a.Elements.Count);
var nrules = (ulong)rules.Count; var nrules = (ulong)rules.Count;


// Borrow an array large enough to hold every single element // Borrow an array large enough to hold every single element
@@ -61,7 +62,7 @@ namespace LLama.Native
pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex); pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);


// Copy all of the rule elements into the flat array // Copy all of the rule elements into the flat array
foreach (var element in rule)
foreach (var element in rule.Elements)
allElementsPtr[elementIndex++] = element; allElementsPtr[elementIndex++] = element;
} }




Loading…
Cancel
Save