Browse Source

- Created a higher level `Grammar` class which is immutable and contains a list of grammar rules. This is the main "entry point" to the grammar system.

- Made all the mechanics of grammar parsing (GBNFGrammarParser, ParseState) internal. Just call `Grammar.Parse("whatever")`.
 - Added a `GrammarRule` class which validates elements on construction (this allows constructing grammar without parsing GBNF).
   - It should be impossible for a `GrammarRule` to represent an invalid rule.
tags/v0.5.1
Martin Evans 2 years ago
parent
commit
a70c7170dd
12 changed files with 385 additions and 278 deletions
  1. +5
    -7
      LLama.Examples/NewVersion/GrammarJsonResponse.cs
  2. +22
    -55
      LLama.Unittest/GrammarParserTest.cs
  3. +7
    -6
      LLama.Unittest/GrammarTest.cs
  4. +3
    -0
      LLama/AssemblyAttributes.cs
  5. +20
    -0
      LLama/Extensions/IReadOnlyListExtensions.cs
  6. +2
    -0
      LLama/GlobalSuppressions.cs
  7. +0
    -181
      LLama/Grammar/ParseState.cs
  8. +45
    -22
      LLama/Grammars/GBNFGrammarParser.cs
  9. +151
    -0
      LLama/Grammars/Grammar.cs
  10. +75
    -0
      LLama/Grammars/GrammarRule.cs
  11. +51
    -4
      LLama/Native/LLamaGrammarElement.cs
  12. +4
    -3
      LLama/Native/SafeLLamaGrammarHandle.cs

+ 5
- 7
LLama.Examples/NewVersion/GrammarJsonResponse.cs View File

@@ -1,6 +1,5 @@
using LLama.Common;
using LLama.Grammar;
using LLama.Native;
using LLama.Grammars;

namespace LLama.Examples.NewVersion
{
@@ -8,8 +7,8 @@ namespace LLama.Examples.NewVersion
{
public static void Run()
{
var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim();
var parsedGrammar = new GrammarParser();
var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
var grammar = Grammar.Parse(gbnf, "root");

Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -22,19 +21,18 @@ namespace LLama.Examples.NewVersion
};
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);
ParseState state = parsedGrammar.Parse(grammarBytes);
using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\"");
Console.ForegroundColor = ConsoleColor.White;

using var grammarInstance = grammar.CreateInstance();
var inferenceParams = new InferenceParams()
{
Temperature = 0.6f,
AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" },
MaxTokens = 50,
Grammar = grammar
Grammar = grammarInstance
};

while (true)


+ 22
- 55
LLama.Unittest/GrammarParserTest.cs View File

@@ -1,8 +1,5 @@
using LLama.Common;
using LLama.Native;
using System.Diagnostics;
using LLama.Grammar;
using Newtonsoft.Json.Linq;
using LLama.Native;
using LLama.Grammars;

namespace LLama.Unittest
{
@@ -17,14 +14,15 @@ namespace LLama.Unittest
[Fact]
public void ParseComplexGrammar()
{
GrammarParser parsedGrammar = new GrammarParser();
GBNFGrammarParser parsedGrammar = new GBNFGrammarParser();
string grammarBytes = @"root ::= (expr ""="" term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+";

ParseState state = parsedGrammar.Parse(grammarBytes);
var state = parsedGrammar.Parse(grammarBytes, "root");
Assert.Equal(0ul, state.StartRuleIndex);

List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
var expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_5", 5),
@@ -36,27 +34,11 @@ namespace LLama.Unittest
new KeyValuePair<string, uint>("term_7", 7),
};

uint index = 0;
foreach (var it in state.SymbolIds)
foreach (var symbol in expected)
{
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
var rule = state.Rules[(int)symbol.Value];
Assert.Equal(symbol.Key, rule.Name);
}
Assert.NotEmpty(state.SymbolIds);


var expectedRules = new List<LLamaGrammarElement>
{
@@ -96,13 +78,13 @@ namespace LLama.Unittest
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
};

index = 0;
uint index = 0;
foreach (var rule in state.Rules)
{
// compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
for (uint i = 0; i < rule.Elements.Count; i++)
{
var element = rule[(int)i];
var element = rule.Elements[(int)i];
var expectedElement = expectedRules[(int)index];

// Pretty print error message before asserting
@@ -124,7 +106,7 @@ namespace LLama.Unittest
[Fact]
public void ParseExtraComplexGrammar()
{
GrammarParser parsedGrammar = new GrammarParser();
GBNFGrammarParser parsedGrammar = new GBNFGrammarParser();
string grammarBytes = @"
root ::= (expr ""="" ws term ""\n"")+
expr ::= term ([-+*/] term)*
@@ -134,9 +116,10 @@ namespace LLama.Unittest
ws ::= [ \t\n]*
";

ParseState state = parsedGrammar.Parse(grammarBytes);
var state = parsedGrammar.Parse(grammarBytes, "root");
Assert.Equal(0ul, state.StartRuleIndex);

List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
var expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_6", 6),
@@ -153,27 +136,11 @@ namespace LLama.Unittest
new KeyValuePair<string, uint>("ws_12", 12),
};

uint index = 0;
foreach (var it in state.SymbolIds)
foreach (var symbol in expected)
{
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
var rule = state.Rules[(int)symbol.Value];
Assert.Equal(symbol.Key, rule.Name);
}
Assert.NotEmpty(state.SymbolIds);


var expectedRules = new List<LLamaGrammarElement>
{
@@ -246,13 +213,13 @@ namespace LLama.Unittest
new LLamaGrammarElement(LLamaGrammarElementType.END, 0)
};

index = 0;
uint index = 0;
foreach (var rule in state.Rules)
{
// compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
for (uint i = 0; i < rule.Elements.Count; i++)
{
var element = rule[(int)i];
var element = rule.Elements[(int)i];
var expectedElement = expectedRules[(int)index];

// Pretty print error message before asserting


+ 7
- 6
LLama.Unittest/GrammarTest.cs View File

@@ -1,4 +1,5 @@
using LLama.Common;
using LLama.Grammars;
using LLama.Native;

namespace LLama.Unittest
@@ -26,14 +27,14 @@ namespace LLama.Unittest
[Fact]
public void CreateBasicGrammar()
{
var rules = new List<List<LLamaGrammarElement>>
var rules = new List<GrammarRule>
{
new()
new GrammarRule("alpha", new[]
{
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
},
}),
};

using var handle = SafeLLamaGrammarHandle.Create(rules, 0);
@@ -44,15 +45,15 @@ namespace LLama.Unittest
{
// Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
// we can be confident it's not what the LLM would say if not constrained by the grammar!
var rules = new List<List<LLamaGrammarElement>>
var rules = new List<GrammarRule>
{
new()
new GrammarRule("feline", new []
{
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'c'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 't'),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
},
}),
};

using var grammar = SafeLLamaGrammarHandle.Create(rules, 0);


+ 3
- 0
LLama/AssemblyAttributes.cs View File

@@ -0,0 +1,3 @@
using System.Runtime.CompilerServices;

[assembly: InternalsVisibleTo("LLama.Unittest")]

+ 20
- 0
LLama/Extensions/IReadOnlyListExtensions.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;

namespace LLama.Extensions
{
    /// <summary>
    /// Extension methods for <see cref="IReadOnlyList{T}"/>
    /// </summary>
    internal static class IReadOnlyListExtensions
    {
        /// <summary>
        /// Find the index of the first element of <paramref name="list"/> which is equal to <paramref name="item"/>
        /// </summary>
        /// <typeparam name="T">Element type, compared through <see cref="IEquatable{T}"/></typeparam>
        /// <param name="list">The list to search, scanned from index 0 upwards</param>
        /// <param name="item">The value to look for</param>
        /// <returns>The zero-based index of the first match, or null if no element matches</returns>
        public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
            where T : IEquatable<T>
        {
            // The default comparer dispatches to IEquatable<T>.Equals, but unlike calling
            // list[i].Equals(item) directly it does not throw when an element is null.
            var comparer = EqualityComparer<T>.Default;

            for (var i = 0; i < list.Count; i++)
            {
                if (comparer.Equals(list[i], item))
                    return i;
            }

            return null;
        }
    }
}

+ 2
- 0
LLama/GlobalSuppressions.cs View File

@@ -6,3 +6,5 @@
using System.Diagnostics.CodeAnalysis;

[assembly: SuppressMessage("Interoperability", "CA1401:P/Invokes should not be visible", Justification = "LLamaSharp intentionally exports the native llama.cpp API")]

[assembly: SuppressMessage("Style", "IDE0070:Use 'System.HashCode'", Justification = "Not compatible with netstandard2.0")]

+ 0
- 181
LLama/Grammar/ParseState.cs View File

@@ -1,181 +0,0 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;

namespace LLama.Grammar
{
/// <summary>
/// Holds the output of grammar parsing (rule names and rule element lists) and provides
/// helpers for printing that output in textual form.
///
/// Source:
/// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h
///
/// The commit hash in the URL above identifies the llama.cpp revision that this C# code mirrors.
/// </summary>
public class ParseState
{
/// <summary>
/// Maps each rule name to its numeric id. PrintGrammar treats the id as an index into <see cref="Rules"/>.
/// </summary>
public SortedDictionary<string, uint> SymbolIds { get; } = new SortedDictionary<string, uint>();

/// <summary>
/// The rules of the grammar; each rule is a list of elements.
/// </summary>
public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>();

/// <summary>
/// Enumerate every rule in <see cref="Rules"/>, in order.
/// </summary>
public IEnumerable<List<LLamaGrammarElement>> CRules()
{
foreach (var rule in Rules)
{
yield return rule;
}
}

/// <summary>
/// Write every rule of <paramref name="state"/> to <paramref name="file"/> in "name ::= ..." form.
/// Any exception raised while printing (including the GrammarFormatException thrown by PrintRule)
/// is caught here and reported on standard error instead of propagating to the caller.
/// </summary>
/// <param name="file">Destination writer</param>
/// <param name="state">The state to print. Note: this parameter is used, not the current instance.</param>
public void PrintGrammar(StreamWriter file, ParseState state)
{
try
{
// Invert the name -> id map so that rule ids can be printed as names
Dictionary<uint, string> symbolIdNames = new Dictionary<uint, string>();
foreach (var kv in state.SymbolIds)
{
symbolIdNames[kv.Value] = kv.Key;
}
for (int i = 0, end = state.Rules.Count; i < end; i++)
{
PrintRule(file, (uint)i, state.Rules[i], symbolIdNames);
}
}
catch(Exception err)
{
Console.Error.WriteLine($"\nError printing grammar: {err.Message}");
}
}

/// <summary>
/// Write a raw dump of a single rule: for each element its type name followed by its value
/// in parentheses, then a line break. No validation of the rule is performed.
/// </summary>
/// <param name="file">Destination writer</param>
/// <param name="rule">The rule elements to dump</param>
public void PrintRuleBinary(StreamWriter file, List<LLamaGrammarElement> rule)
{
foreach (var elem in rule)
{
// First switch: the element type name
switch (elem.Type)
{
case LLamaGrammarElementType.END: file.Write("END"); break;
case LLamaGrammarElementType.ALT: file.Write("ALT"); break;
case LLamaGrammarElementType.RULE_REF: file.Write("RULE_REF"); break;
case LLamaGrammarElementType.CHAR: file.Write("CHAR"); break;
case LLamaGrammarElementType.CHAR_NOT: file.Write("CHAR_NOT"); break;
case LLamaGrammarElementType.CHAR_RNG_UPPER: file.Write("CHAR_RNG_UPPER"); break;
case LLamaGrammarElementType.CHAR_ALT: file.Write("CHAR_ALT"); break;
}
// Second switch: the element value, printed as a number for structural elements
// and as a quoted character for character elements
switch (elem.Type)
{
case LLamaGrammarElementType.END:
case LLamaGrammarElementType.ALT:
case LLamaGrammarElementType.RULE_REF:
file.Write($"({elem.Value}) ");
break;
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
case LLamaGrammarElementType.CHAR_ALT:
file.Write("(\"");
PrintGrammarChar(file, elem.Value);
file.Write("\") ");
break;
}
}
file.WriteLine();
}

/// <summary>
/// Write one rule as "name ::= elements" followed by a line break.
/// </summary>
/// <param name="file">Destination writer</param>
/// <param name="ruleId">Id of the rule, used to look up its name and in error messages</param>
/// <param name="rule">The rule elements; must be non-empty and terminated by END</param>
/// <param name="symbolIdNames">Map from rule id to rule name</param>
/// <exception cref="GrammarFormatException">
/// Thrown if the rule is empty or does not end with END, if END appears before the final element,
/// or if CHAR_RNG_UPPER / CHAR_ALT is not preceded by a character element.
/// </exception>
private void PrintRule(
StreamWriter file,
uint ruleId,
List<LLamaGrammarElement> rule,
Dictionary<uint, string> symbolIdNames)
{
if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END)
{
throw new GrammarFormatException(
$"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}");
}

file.Write($"{symbolIdNames[ruleId]} ::= ");

// The loop stops before the trailing END element, so rule[i + 1] below is always in range
for (int i = 0, end = rule.Count - 1; i < end; i++)
{
var elem = rule[i];
switch (elem.Type)
{
case LLamaGrammarElementType.END:
throw new GrammarFormatException(
$"Unexpected end of rule: {ruleId}, {i}");
case LLamaGrammarElementType.ALT:
file.Write("| ");
break;
case LLamaGrammarElementType.RULE_REF:
file.Write($"{symbolIdNames[elem.Value]} ");
break;
case LLamaGrammarElementType.CHAR:
// Opens a character class; it is closed further down once the class is complete
file.Write("[");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_NOT:
file.Write("[^");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_RNG_UPPER:
if (i == 0 || !IsCharElement(rule[i - 1]))
{
throw new GrammarFormatException(
$"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}");
}
file.Write("-");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_ALT:
if (i == 0 || !IsCharElement(rule[i - 1]))
{
throw new GrammarFormatException(
$"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}");
}
PrintGrammarChar(file, elem.Value);
break;

}

// Close the character class unless the next element continues it
if (IsCharElement(elem))
{
switch (rule[i + 1].Type)
{
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
break;
default:
file.Write("] ");
break;
}
}
}
file.WriteLine();
}

/// <summary>
/// Write a single code point: values in the range 0x20 to 0x7F inclusive are written as a
/// character, anything else is written as "&lt;U+XXXX&gt;" using hexadecimal digits.
/// </summary>
private void PrintGrammarChar(StreamWriter file, uint c)
{
if (c >= 0x20 && c <= 0x7F)
{
file.Write((char)c);
}
else
{
// cop out of encoding UTF-8
file.Write($"<U+{c:X4}>");
}
}

/// <summary>
/// Returns true if the element type is CHAR, CHAR_NOT, CHAR_ALT or CHAR_RNG_UPPER.
/// </summary>
private bool IsCharElement(LLamaGrammarElement elem)
{
switch (elem.Type)
{
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
return true;
default:
return false;
}
}
}
}

LLama/Grammar/GrammarParser.cs → LLama/Grammars/GBNFGrammarParser.cs View File

@@ -1,10 +1,11 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using LLama.Exceptions;
using LLama.Native;

namespace LLama.Grammar
namespace LLama.Grammars
{
/// <summary>
/// Source:
@@ -12,7 +13,7 @@ namespace LLama.Grammar
///
/// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary>
public class GrammarParser
internal sealed class GBNFGrammarParser
{
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp
@@ -206,7 +207,7 @@ namespace LLama.Grammar
while (!pos.IsEmpty && pos[0] != '"')
{
var charPair = ParseChar(ref pos);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR, charPair));
}
pos = ParseSpace(pos.Slice(1), isNested);
}
@@ -228,13 +229,13 @@ namespace LLama.Grammar
var charPair = ParseChar(ref pos);
var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;

outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair });
outElements.Add(new LLamaGrammarElement(type, charPair));

if (pos[0] == '-' && pos[1] != ']')
{
pos = pos.Slice(1);
var endCharPair = ParseChar(ref pos);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, endCharPair));
}
}
pos = ParseSpace(pos.Slice(1), isNested);
@@ -245,7 +246,7 @@ namespace LLama.Grammar
uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
pos = ParseSpace(nameEnd, isNested);
lastSymStart = outElements.Count;
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
}
else if (pos[0] == '(') // grouping
{
@@ -255,7 +256,7 @@ namespace LLama.Grammar
pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
lastSymStart = outElements.Count;
// output reference to synthesized rule
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
if (pos[0] != ')')
{
throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
@@ -284,11 +285,11 @@ namespace LLama.Grammar
if (pos[0] == '*' || pos[0] == '+')
{
// cause generated rule to recurse
subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));
}

// mark start of alternate def
subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));

if (pos[0] == '+')
{
@@ -296,13 +297,13 @@ namespace LLama.Grammar
subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart));
}

subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));

AddRule(state, subRuleId, subRule);

// in original rule, replace previous symbol with reference to generated rule
outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, subRuleId));

pos = ParseSpace(pos.Slice(1), isNested);

@@ -328,12 +329,12 @@ namespace LLama.Grammar

while (!pos.IsEmpty && pos[0] == '|')
{
rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0));
pos = ParseSpace(pos.Slice(1), true);
pos = ParseSequence(state, pos, ruleName, rule, isNested);
}

rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });
rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
AddRule(state, ruleId, rule);

return pos;
@@ -370,19 +371,41 @@ namespace LLama.Grammar
return ParseSpace(pos, true);
}

public ParseState Parse(string input)
/// <summary>
/// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a>
/// </summary>
/// <param name="input">The string to parse</param>
/// <param name="startRule">The name of the root rule of this grammar</param>
/// <exception cref="GrammarFormatException">Thrown if input is malformed</exception>
/// <returns>A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling</returns>
public Grammar Parse(string input, string startRule)
{
byte[] byteArray = Encoding.UTF8.GetBytes(input);
ReadOnlySpan<byte> src = new ReadOnlySpan<byte>(byteArray);
ParseState state = new ParseState();
ReadOnlySpan<byte> pos = ParseSpace(src, true);
var byteArray = Encoding.UTF8.GetBytes(input);
var state = new ParseState();
var pos = ParseSpace(byteArray, true);

while (!pos.IsEmpty)
{
pos = ParseRule(state, pos);
}

return state;
var names = state.SymbolIds.ToDictionary(a => a.Value, a => a.Key);
var rules = new List<GrammarRule>();
for (var i = 0; i < state.Rules.Count; i++)
{
var elements = state.Rules[i];
var name = names[(uint)i];
rules.Add(new GrammarRule(name, elements));
}

var startRuleIndex = state.SymbolIds[startRule];
return new Grammar(rules, startRuleIndex);
}

/// <summary>
/// Working state accumulated while parsing: the rule names seen so far and the elements of each rule.
/// </summary>
private record ParseState
{
/// <summary>
/// Maps each rule name to its numeric id. Parse uses the id as an index into <see cref="Rules"/>.
/// </summary>
public SortedDictionary<string, uint> SymbolIds { get; } = new();

/// <summary>
/// Element lists of the rules, one list per rule.
/// </summary>
public List<List<LLamaGrammarElement>> Rules { get; } = new();
}
}
}

+ 151
- 0
LLama/Grammars/Grammar.cs View File

@@ -0,0 +1,151 @@
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using LLama.Native;

namespace LLama.Grammars
{
    /// <summary>
    /// A grammar is a set of <see cref="GrammarRule"/>s for deciding which characters are valid next. Can be used to constrain
    /// output to certain formats - e.g. force the model to output JSON
    /// </summary>
    public sealed class Grammar
    {
        private ulong _startRuleIndex;

        /// <summary>
        /// Index of the initial rule to start from
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when set to a value which is not a valid index into <see cref="Rules"/></exception>
        public ulong StartRuleIndex
        {
            get => _startRuleIndex;
            set
            {
                // Enforce the same invariant as the constructor, otherwise an invalid index
                // could be passed on to SafeLLamaGrammarHandle.Create by CreateInstance.
                if (value >= (uint)Rules.Count)
                    throw new ArgumentOutOfRangeException(nameof(value), "startRule must be less than the number of rules");

                _startRuleIndex = value;
            }
        }

        /// <summary>
        /// The rules which make up this grammar
        /// </summary>
        public IReadOnlyList<GrammarRule> Rules { get; }

        /// <summary>
        /// Create a new grammar from a set of rules
        /// </summary>
        /// <param name="rules">The rules which make up this grammar</param>
        /// <param name="startRuleIndex">Index of the initial rule to start from</param>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="rules"/> is null</exception>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="startRuleIndex"/> is not a valid index into <paramref name="rules"/></exception>
        public Grammar(IReadOnlyList<GrammarRule> rules, ulong startRuleIndex)
        {
            if (rules is null)
                throw new ArgumentNullException(nameof(rules));
            if (startRuleIndex >= (uint)rules.Count)
                throw new ArgumentOutOfRangeException(nameof(startRuleIndex), "startRule must be less than the number of rules");

            Rules = rules;
            _startRuleIndex = startRuleIndex;
        }

        /// <summary>
        /// Create a <see cref="SafeLLamaGrammarHandle"/> from the rules and start rule index of this grammar
        /// </summary>
        /// <returns>A new handle created by <see cref="SafeLLamaGrammarHandle.Create"/></returns>
        public SafeLLamaGrammarHandle CreateInstance()
        {
            return SafeLLamaGrammarHandle.Create(Rules, StartRuleIndex);
        }

        /// <summary>
        /// Parse a string of <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">GGML BNF</a> into a Grammar
        /// </summary>
        /// <param name="gbnf">The string to parse</param>
        /// <param name="startRule">Name of the start rule of this grammar</param>
        /// <exception cref="GrammarFormatException">Thrown if input is malformed</exception>
        /// <returns>A Grammar which can be converted into a SafeLLamaGrammarHandle for sampling</returns>
        public static Grammar Parse(string gbnf, string startRule)
        {
            var parser = new GBNFGrammarParser();
            return parser.Parse(gbnf, startRule);
        }

        /// <summary>
        /// Render every rule of this grammar as text, one "name ::= ..." line per rule
        /// </summary>
        /// <exception cref="GrammarFormatException">Thrown if a rule is found to be malformed while printing</exception>
        public override string ToString()
        {
            var builder = new StringBuilder();
            PrintGrammar(builder);
            return builder.ToString();
        }

        // Append every rule, in index order, to the output.
        private void PrintGrammar(StringBuilder output)
        {
            for (var i = 0; i < Rules.Count; i++)
                PrintRule(output, (uint)i, Rules[i]);
        }

        // Append a single rule as "name ::= elements" followed by a line break.
        private void PrintRule(StringBuilder output, uint ruleId, GrammarRule rule)
        {
            output.Append($"{rule.Name} ::= ");

            // The loop stops before the final element, so Elements[i + 1] below is always in range
            for (int i = 0, end = rule.Elements.Count - 1; i < end; i++)
            {
                var elem = rule.Elements[i];
                switch (elem.Type)
                {
                    case LLamaGrammarElementType.END:
                        throw new GrammarFormatException($"Unexpected end of rule: {ruleId}, {i}");

                    case LLamaGrammarElementType.ALT:
                        output.Append("| ");
                        break;

                    case LLamaGrammarElementType.RULE_REF:
                        output.Append($"{Rules[(int)elem.Value].Name} ");
                        break;

                    case LLamaGrammarElementType.CHAR:
                        // Opens a character class; it is closed below once the class is complete
                        output.Append('[');
                        PrintGrammarChar(output, elem.Value);
                        break;

                    case LLamaGrammarElementType.CHAR_NOT:
                        output.Append("[^");
                        PrintGrammarChar(output, elem.Value);
                        break;

                    case LLamaGrammarElementType.CHAR_RNG_UPPER:
                        if (i == 0 || !rule.Elements[i - 1].IsCharElement())
                        {
                            throw new GrammarFormatException(
                                $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}");
                        }
                        output.Append('-');
                        PrintGrammarChar(output, elem.Value);
                        break;

                    case LLamaGrammarElementType.CHAR_ALT:
                        if (i == 0 || !rule.Elements[i - 1].IsCharElement())
                        {
                            throw new GrammarFormatException(
                                $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}");
                        }
                        PrintGrammarChar(output, elem.Value);
                        break;
                }

                // Close the character class unless the next element continues it
                if (elem.IsCharElement())
                {
                    switch (rule.Elements[i + 1].Type)
                    {
                        case LLamaGrammarElementType.CHAR_ALT:
                        case LLamaGrammarElementType.CHAR_RNG_UPPER:
                            break;
                        default:
                            output.Append("] ");
                            break;
                    }
                }
            }

            output.AppendLine();
        }

        // Append one code point: 0x20..0x7F as a character, anything else as "<U+XXXX>".
        private static void PrintGrammarChar(StringBuilder output, uint c)
        {
            if (c >= 0x20 && c <= 0x7F)
            {
                output.Append((char)c);
            }
            else
            {
                // cop out of encoding UTF-8
                output.Append($"<U+{c:X4}>");
            }
        }
    }
}

+ 75
- 0
LLama/Grammars/GrammarRule.cs View File

@@ -0,0 +1,75 @@
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Grammars
{
    /// <summary>
    /// A single rule in a <see cref="Grammar"/>
    /// </summary>
    public sealed record GrammarRule
    {
        /// <summary>
        /// Name of this rule
        /// </summary>
        public string Name { get; }

        /// <summary>
        /// The elements of this grammar rule
        /// </summary>
        public IReadOnlyList<LLamaGrammarElement> Elements { get; }

        /// <summary>
        /// Create a new GrammarRule containing the given elements
        /// </summary>
        /// <param name="name">Name of this rule</param>
        /// <param name="elements">The elements of this rule; must be non-empty and end with exactly one END element</param>
        /// <exception cref="ArgumentException">Thrown if <paramref name="elements"/> is null or does not form a valid rule</exception>
        public GrammarRule(string name, IReadOnlyList<LLamaGrammarElement> elements)
        {
            Validate(elements, name);

            Name = name;
            Elements = elements;
        }

        // Check the structural constraints of a rule; "name" is only used in error messages.
        private static void Validate(IReadOnlyList<LLamaGrammarElement> elements, string name)
        {
            // ArgumentNullException derives from ArgumentException, so callers catching the
            // documented exception type also see this failure (instead of a NullReferenceException).
            if (elements is null)
                throw new ArgumentNullException(nameof(elements));
            if (elements.Count == 0)
                throw new ArgumentException("Cannot create a GrammarRule with zero elements", nameof(elements));
            if (elements[elements.Count - 1].Type != LLamaGrammarElementType.END)
                throw new ArgumentException("Last grammar element must be END", nameof(elements));

            for (var i = 0; i < elements.Count; i++)
            {
                switch (elements[i].Type)
                {
                    case LLamaGrammarElementType.END:
                        // END is only legal as the final element
                        if (i != elements.Count - 1)
                            throw new ArgumentException("Found more than one END grammar element", nameof(elements));
                        continue;

                    case LLamaGrammarElementType.CHAR_RNG_UPPER:
                        if (i == 0 || !elements[i - 1].IsCharElement())
                            throw new ArgumentException($"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {name},{i}", nameof(elements));
                        break;

                    case LLamaGrammarElementType.CHAR_ALT:
                        if (i == 0 || !elements[i - 1].IsCharElement())
                            throw new ArgumentException($"LLamaGrammarElementType.CHAR_ALT without preceding char: {name},{i}", nameof(elements));
                        break;

                    case LLamaGrammarElementType.ALT:
                    case LLamaGrammarElementType.RULE_REF:
                    case LLamaGrammarElementType.CHAR:
                    case LLamaGrammarElementType.CHAR_NOT:
                        break;

                    default:
                        throw new ArgumentException($"Unknown grammar element type: '{elements[i].Type}'", nameof(elements));
                }
            }
        }
    }
}

+ 51
- 4
LLama/Native/LLamaGrammarElement.cs View File

@@ -1,4 +1,5 @@
using System.Diagnostics;
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace LLama.Native
@@ -51,17 +52,18 @@ namespace LLama.Native
/// </summary>
[StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("{Type} {Value}")]
public struct LLamaGrammarElement
public readonly struct LLamaGrammarElement
: IEquatable<LLamaGrammarElement>
{
/// <summary>
/// The type of this element
/// </summary>
public LLamaGrammarElementType Type;
public readonly LLamaGrammarElementType Type;

/// <summary>
/// Unicode code point or rule ID
/// </summary>
public uint Value;
public readonly uint Value;

/// <summary>
/// Construct a new LLamaGrammarElement
@@ -73,5 +75,50 @@ namespace LLama.Native
Type = type;
Value = value;
}

/// <inheritdoc />
public bool Equals(LLamaGrammarElement other)
{
    // Elements of different types are never equal. For END elements the value is
    // not compared, so two END elements are equal whatever their Value is.
    return Type == other.Type
        && (Type == LLamaGrammarElementType.END || Value == other.Value);
}

/// <inheritdoc />
public override bool Equals(object? obj)
{
    // Anything that is not a LLamaGrammarElement (including null) is never equal
    if (obj is LLamaGrammarElement element)
        return Equals(element);

    return false;
}

/// <inheritdoc />
public override int GetHashCode()
{
    unchecked
    {
        // Equals ignores Value for END elements, so the hash must ignore it too:
        // values which compare equal are required to produce the same hash code.
        var value = Type == LLamaGrammarElementType.END ? 0u : Value;

        var hash = 2999;
        hash = hash * 7723 + (int)Type;
        hash = hash * 7723 + (int)value;
        return hash;
    }
}

/// <summary>
/// Check whether the type of this element is CHAR, CHAR_NOT, CHAR_ALT or CHAR_RNG_UPPER
/// </summary>
/// <returns>true for those four types, false for every other type</returns>
internal bool IsCharElement()
{
    return Type is LLamaGrammarElementType.CHAR
        or LLamaGrammarElementType.CHAR_NOT
        or LLamaGrammarElementType.CHAR_ALT
        or LLamaGrammarElementType.CHAR_RNG_UPPER;
}
}
}

+ 4
- 3
LLama/Native/SafeLLamaGrammarHandle.cs View File

@@ -4,6 +4,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using LLama.Exceptions;
using LLama.Grammars;

namespace LLama.Native
{
@@ -38,11 +39,11 @@ namespace LLama.Native
/// <param name="start_rule_index">The index (in the outer list) of the start rule</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public static SafeLLamaGrammarHandle Create(IReadOnlyList<IReadOnlyList<LLamaGrammarElement>> rules, ulong start_rule_index)
public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ulong start_rule_index)
{
unsafe
{
var totalElements = rules.Sum(a => a.Count);
var totalElements = rules.Sum(a => a.Elements.Count);
var nrules = (ulong)rules.Count;

// Borrow an array large enough to hold every single element
@@ -61,7 +62,7 @@ namespace LLama.Native
pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);

// Copy all of the rule elements into the flat array
foreach (var element in rule)
foreach (var element in rule.Elements)
allElementsPtr[elementIndex++] = element;
}



Loading…
Cancel
Save