Browse Source

[WIP] Translating the GrammarParser

tags/v0.5.1
Mihai 2 years ago
parent
commit
7f31276bdf
2 changed files with 354 additions and 0 deletions
  1. +328
    -0
      LLama/Grammar/GrammarParser.cs
  2. +26
    -0
      LLama/Grammar/ParseState.cs

+ 328
- 0
LLama/Grammar/GrammarParser.cs View File

@@ -0,0 +1,328 @@
using LLama.Native;
using System;
using System.Collections.Generic;

namespace LLama.Grammar
{
/// <summary>
/// Source:
/// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp
///
/// The commit hash in the URL above identifies the llama.cpp revision that this C# code was translated from.
/// </summary>
internal class GrammarParser
{
/// <summary>
/// Decodes a single Unicode code point from the start of <paramref name="src"/>.
/// </summary>
/// <remarks>
/// The method name is kept from the llama.cpp original, which walks raw UTF-8 bytes.
/// A .NET <see cref="char"/> is a UTF-16 code unit, so truncating it to a byte and
/// applying the UTF-8 lead-byte table mis-decodes every non-ASCII character.
/// This implementation therefore decodes UTF-16: a valid surrogate pair yields one
/// supplementary code point, anything else yields the value of the single char.
/// </remarks>
/// <param name="src">Text to decode from; must contain at least one char.</param>
/// <returns>
/// The decoded code point and the remainder of <paramref name="src"/> after the
/// consumed char(s).
/// </returns>
/// <exception cref="ArgumentException"><paramref name="src"/> is empty.</exception>
public Tuple<uint, ReadOnlyMemory<char>> DecodeUTF8(ReadOnlyMemory<char> src)
{
    if (src.IsEmpty)
    {
        throw new ArgumentException("Cannot decode a character from empty input", nameof(src));
    }

    ReadOnlySpan<char> span = src.Span;
    char first = span[0];

    // A high surrogate immediately followed by a low surrogate encodes one code point
    // outside the Basic Multilingual Plane; consume both code units.
    if (char.IsHighSurrogate(first) && span.Length > 1 && char.IsLowSurrogate(span[1]))
    {
        uint codePoint = (uint)char.ConvertToUtf32(first, span[1]);
        return new Tuple<uint, ReadOnlyMemory<char>>(codePoint, src.Slice(2));
    }

    // BMP character (or an unpaired surrogate, which is passed through unchanged
    // rather than rejected, mirroring the "assumes valid input" stance of the original).
    return new Tuple<uint, ReadOnlyMemory<char>>(first, src.Slice(1));
}

/// <summary>
/// Looks up the id of the symbol named by the first <paramref name="len"/> chars of
/// <paramref name="src"/>, registering it with the next free id if it is not known yet.
/// </summary>
/// <param name="state">Parse state whose symbol table is consulted and possibly extended.</param>
/// <param name="src">Text that starts with the symbol name.</param>
/// <param name="len">Number of leading chars of <paramref name="src"/> that form the name.</param>
/// <returns>The existing id for the name, or the newly assigned one.</returns>
public uint GetSymbolId(ParseState state, ReadOnlySpan<char> src, int len)
{
    // A new symbol receives the current table size as its id.
    uint candidateId = (uint)state.SymbolIds.Count;
    string symbolName = src.Slice(0, len).ToString();

    if (!state.SymbolIds.TryGetValue(symbolName, out uint symbolId))
    {
        symbolId = candidateId;
        state.SymbolIds[symbolName] = symbolId;
    }

    return symbolId;
}

/// <summary>
/// Registers a synthesized symbol named "<paramref name="baseName"/>_&lt;id&gt;" and returns its id.
/// </summary>
/// <param name="state">Parse state whose symbol table receives the new entry.</param>
/// <param name="baseName">Prefix used for the generated symbol name.</param>
/// <returns>The id assigned to the generated symbol (the table size before insertion).</returns>
public uint GenerateSymbolId(ParseState state, string baseName)
{
    var symbols = state.SymbolIds;
    uint freshId = (uint)symbols.Count;

    // The id is appended to the base name to form the table key.
    symbols[baseName + "_" + freshId] = freshId;

    return freshId;
}

/// <summary>
/// Stores <paramref name="rule"/> at index <paramref name="ruleId"/> of the rule list,
/// first padding the list with empty rules so that the index exists.
/// </summary>
/// <param name="state">Parse state whose rule list is updated.</param>
/// <param name="ruleId">Index at which the rule is stored.</param>
/// <param name="rule">Elements making up the rule; replaces whatever was at that index.</param>
public void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
{
    var rules = state.Rules;

    // Grow with empty placeholders until ruleId is a valid index.
    while (ruleId >= rules.Count)
    {
        rules.Add(new List<LLamaGrammarElement>());
    }

    rules[(int)ruleId] = rule;
}

/// <summary>
/// Tells whether <paramref name="c"/> may appear in a rule name:
/// an ASCII letter, an ASCII digit or a hyphen.
/// </summary>
/// <param name="c">Character to classify.</param>
/// <returns><c>true</c> for a-z, A-Z, 0-9 and '-'; otherwise <c>false</c>.</returns>
public bool IsWordChar(char c)
{
    if (c == '-')
    {
        return true;
    }

    bool isDigit = c >= '0' && c <= '9';
    bool isLower = c >= 'a' && c <= 'z';
    bool isUpper = c >= 'A' && c <= 'Z';

    return isLower || isUpper || isDigit;
}

/// <summary>
/// Reads exactly <paramref name="size"/> hexadecimal digits from the start of
/// <paramref name="src"/> and returns their numeric value.
/// </summary>
/// <param name="src">Text that starts with the hex digits.</param>
/// <param name="size">Number of hex digits that must be present.</param>
/// <returns>The parsed value and the remainder of <paramref name="src"/> after the digits.</returns>
/// <exception cref="InvalidOperationException">
/// Fewer than <paramref name="size"/> hex digits are available at the start of the input.
/// </exception>
public Tuple<uint, ReadOnlyMemory<char>> ParseHex(ReadOnlyMemory<char> src, int size)
{
    ReadOnlySpan<char> digits = src.Span;
    uint accumulated = 0;
    int consumed = 0;

    while (consumed < size && consumed < src.Length)
    {
        char c = digits[consumed];
        int nibble;

        if (c >= '0' && c <= '9')
        {
            nibble = c - '0';
        }
        else if (c >= 'a' && c <= 'f')
        {
            nibble = c - 'a' + 10;
        }
        else if (c >= 'A' && c <= 'F')
        {
            nibble = c - 'A' + 10;
        }
        else
        {
            // Not a hex digit: stop early, the length check below reports the error.
            break;
        }

        accumulated = (accumulated << 4) + (uint)nibble;
        consumed++;
    }

    if (consumed != size)
    {
        throw new InvalidOperationException($"Expecting {size} hex chars at {src.ToString()}");
    }

    return new Tuple<uint, ReadOnlyMemory<char>>(accumulated, src.Slice(consumed));
}

/// <summary>
/// Skips leading whitespace and '#' comments in <paramref name="src"/>.
/// </summary>
/// <param name="src">Text to skip through.</param>
/// <param name="newlineOk">
/// When <c>true</c>, '\r' and '\n' are skipped as whitespace too; when <c>false</c>,
/// skipping stops at the first line break.
/// </param>
/// <returns>The remainder of <paramref name="src"/> starting at the first non-skipped char.</returns>
public ReadOnlySpan<char> ParseSpace(ReadOnlySpan<char> src, bool newlineOk)
{
    int cursor = 0;

    while (cursor < src.Length)
    {
        char c = src[cursor];

        if (c == '#')
        {
            // A comment runs up to, but not including, the next line break.
            do
            {
                cursor++;
            }
            while (cursor < src.Length && src[cursor] != '\r' && src[cursor] != '\n');
        }
        else if (c == ' ' || c == '\t' || (newlineOk && (c == '\r' || c == '\n')))
        {
            cursor++;
        }
        else
        {
            break;
        }
    }

    return src.Slice(cursor);
}

/// <summary>
/// Skips over a rule name (one or more word chars) at the start of <paramref name="src"/>.
/// </summary>
/// <param name="src">Text expected to start with a name.</param>
/// <returns>The remainder of <paramref name="src"/> immediately after the name.</returns>
/// <exception cref="InvalidOperationException">
/// <paramref name="src"/> does not start with a word char.
/// </exception>
public ReadOnlySpan<char> ParseName(ReadOnlySpan<char> src)
{
    int nameLength;
    for (nameLength = 0; nameLength < src.Length; nameLength++)
    {
        if (!IsWordChar(src[nameLength]))
        {
            break;
        }
    }

    if (nameLength == 0)
    {
        throw new InvalidOperationException($"Expecting name at {src.ToString()}");
    }

    return src.Slice(nameLength);
}

/// <summary>
/// Parses one (possibly escaped) character from the start of <paramref name="src"/>.
/// </summary>
/// <remarks>
/// Supported escapes: <c>\xHH</c>, <c>\uHHHH</c>, <c>\UHHHHHHHH</c>, <c>\t</c>, <c>\r</c>,
/// <c>\n</c>, and the literal escapes <c>\\</c>, <c>\"</c>, <c>\[</c>, <c>\]</c>.
/// Anything not starting with a backslash is handed to <c>DecodeUTF8</c>.
/// </remarks>
/// <param name="src">Text positioned at the character to parse.</param>
/// <returns>The character value and the remainder of <paramref name="src"/> after it.</returns>
/// <exception cref="InvalidOperationException">
/// <paramref name="src"/> is empty, or it starts with a backslash that is not followed
/// by a recognised escape (including a backslash at the very end of the input).
/// </exception>
public Tuple<uint, ReadOnlyMemory<char>> ParseChar(ReadOnlyMemory<char> src)
{
    ReadOnlySpan<char> span = src.Span;

    // Check for end of input before touching span[0]; indexing an empty span would
    // surface as IndexOutOfRangeException instead of a meaningful parse error.
    if (span.IsEmpty)
    {
        throw new InvalidOperationException("Unexpected end of input");
    }

    if (span[0] != '\\')
    {
        return DecodeUTF8(src);
    }

    // A lone trailing backslash has no escape selector to inspect.
    if (span.Length < 2)
    {
        throw new InvalidOperationException("Unknown escape at " + src.ToString());
    }

    ReadOnlyMemory<char> afterEscape = src.Slice(2);

    switch (span[1])
    {
        case 'x':
            return ParseHex(afterEscape, 2);
        case 'u':
            return ParseHex(afterEscape, 4);
        case 'U':
            return ParseHex(afterEscape, 8);
        case 't':
            return new Tuple<uint, ReadOnlyMemory<char>>('\t', afterEscape);
        case 'r':
            return new Tuple<uint, ReadOnlyMemory<char>>('\r', afterEscape);
        case 'n':
            return new Tuple<uint, ReadOnlyMemory<char>>('\n', afterEscape);
        case '\\':
        case '"':
        case '[':
        case ']':
            // These escapes stand for the escaped char itself.
            return new Tuple<uint, ReadOnlyMemory<char>>(span[1], afterEscape);
        default:
            throw new InvalidOperationException("Unknown escape at " + src.ToString());
    }
}

/// <summary>
/// Parses a sequence of grammar items (string literals, char ranges, rule references,
/// parenthesised groups and the repetition operators <c>*</c>, <c>+</c>, <c>?</c>) from
/// <paramref name="src"/> and appends the resulting elements to <paramref name="outElements"/>.
/// </summary>
/// <remarks>
/// Parsing stops at the first char that cannot start or continue an item; that position
/// is returned. Repetition operators are rewritten into synthesized rules:
/// <c>S*</c> becomes <c>S' ::= S S' |</c>, <c>S+</c> becomes <c>S' ::= S S' | S</c>,
/// <c>S?</c> becomes <c>S' ::= S |</c>.
/// </remarks>
/// <param name="state">Parse state that receives new symbols and synthesized rules.</param>
/// <param name="src">Text positioned at the start of the sequence.</param>
/// <param name="ruleName">Name of the enclosing rule; used as prefix for synthesized rule names.</param>
/// <param name="outElements">List the parsed elements are appended to.</param>
/// <param name="isNested">Whether line breaks count as skippable whitespace between items.</param>
/// <returns>The unparsed remainder of <paramref name="src"/>.</returns>
/// <exception cref="InvalidOperationException">
/// A literal, char range or group is not terminated, or a repetition operator has no
/// preceding item.
/// </exception>
public ReadOnlySpan<char> ParseSequence(
    ref ParseState state,
    ReadOnlyMemory<char> src,
    string ruleName,
    List<LLamaGrammarElement> outElements,
    bool isNested)
{
    int lastSymStart = outElements.Count;
    ReadOnlySpan<char> pos = src.Span;

    while (!pos.IsEmpty)
    {
        char current = pos[0];

        if (current == '"') // literal string
        {
            pos = pos.Slice(1);
            lastSymStart = outElements.Count;

            while (true)
            {
                if (pos.IsEmpty)
                {
                    throw new InvalidOperationException("Unexpected end of input");
                }
                if (pos[0] == '"')
                {
                    break;
                }

                // Parse from the CURRENT position, not from the start of src.
                var charPair = ParseChar(MemoryAt(src, pos));
                pos = charPair.Item2.Span;
                outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair.Item1 });
            }

            pos = ParseSpace(pos.Slice(1), isNested);
        }
        else if (current == '[') // char range(s)
        {
            pos = pos.Slice(1);
            var startType = LLamaGrammarElementType.CHAR;

            if (!pos.IsEmpty && pos[0] == '^')
            {
                pos = pos.Slice(1);
                startType = LLamaGrammarElementType.CHAR_NOT;
            }

            lastSymStart = outElements.Count;

            while (true)
            {
                if (pos.IsEmpty)
                {
                    throw new InvalidOperationException("Unexpected end of input");
                }
                if (pos[0] == ']')
                {
                    break;
                }

                var charPair = ParseChar(MemoryAt(src, pos));
                pos = charPair.Item2.Span;

                // The first element of the set carries the set type; later ones are alternates.
                var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;
                outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair.Item1 });

                // "a-z" style range; a '-' directly before ']' is an ordinary char instead.
                if (pos.Length > 1 && pos[0] == '-' && pos[1] != ']')
                {
                    var endCharPair = ParseChar(MemoryAt(src, pos.Slice(1)));
                    pos = endCharPair.Item2.Span;
                    outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair.Item1 });
                }
            }

            pos = ParseSpace(pos.Slice(1), isNested);
        }
        else if (IsWordChar(current)) // rule reference
        {
            ReadOnlySpan<char> nameEnd = ParseName(pos);

            // nameEnd is the remainder after the name, so the name length is the difference.
            uint refRuleId = GetSymbolId(state, pos, pos.Length - nameEnd.Length);

            pos = ParseSpace(nameEnd, isNested);
            lastSymStart = outElements.Count;
            outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId });
        }
        else if (current == '(') // grouping
        {
            // parse nested alternates into synthesized rule
            pos = ParseSpace(pos.Slice(1), true);
            uint subRuleId = GenerateSymbolId(state, ruleName);
            pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
            lastSymStart = outElements.Count;

            // output reference to synthesized rule
            outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });

            if (pos.IsEmpty || pos[0] != ')')
            {
                throw new InvalidOperationException($"Expecting ')' at {pos.ToString()}");
            }

            pos = ParseSpace(pos.Slice(1), isNested);
        }
        else if (current == '*' || current == '+' || current == '?') // repetition operator
        {
            if (lastSymStart == outElements.Count)
            {
                throw new InvalidOperationException($"Expecting preceding item to */+/? at {pos.ToString()}");
            }

            // apply transformation to previous symbol (lastSymStart to end) according to
            // rewrite rules:
            // S* --> S' ::= S S' |
            // S+ --> S' ::= S S' | S
            // S? --> S' ::= S |
            uint subRuleId = GenerateSymbolId(state, ruleName);
            List<LLamaGrammarElement> previousSymbol = outElements.GetRange(lastSymStart, outElements.Count - lastSymStart);

            // add preceding symbol to generated rule
            List<LLamaGrammarElement> subRule = new List<LLamaGrammarElement>(previousSymbol);

            if (current == '*' || current == '+')
            {
                // cause generated rule to recurse
                subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
            }

            // mark start of alternate def
            subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });

            if (current == '+')
            {
                // add preceding symbol as alternate only for '+' (otherwise empty)
                subRule.AddRange(previousSymbol);
            }

            subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });

            AddRule(state, subRuleId, subRule);

            // in original rule, replace previous symbol with reference to generated rule
            outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
            outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });

            pos = ParseSpace(pos.Slice(1), isNested);
        }
        else
        {
            break;
        }
    }

    return pos;
}

/// <summary>
/// Returns the tail of <paramref name="src"/> that has the same length as <paramref name="pos"/>.
/// </summary>
/// <remarks>
/// Used to turn a span cursor back into a memory slice for <c>ParseChar</c>.
/// Assumes <paramref name="pos"/> is a suffix of <paramref name="src"/>; every cursor produced
/// inside <c>ParseSequence</c> by slicing is one. NOTE(review): this also relies on
/// <c>ParseAlternates</c> returning a suffix of its input once it is implemented - confirm.
/// </remarks>
private static ReadOnlyMemory<char> MemoryAt(ReadOnlyMemory<char> src, ReadOnlySpan<char> pos)
{
    return src.Slice(src.Length - pos.Length);
}

/// <summary>
/// Placeholder for parsing the alternates of a rule; not implemented yet.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public ReadOnlySpan<char> ParseAlternates(ParseState state, ReadOnlySpan<char> pos, string ruleName, uint subRuleId, bool v)
    => throw new NotImplementedException();
}
}

+ 26
- 0
LLama/Grammar/ParseState.cs View File

@@ -0,0 +1,26 @@
using LLama.Native;
using System;
using System.Collections.Generic;

namespace LLama.Grammar
{
/// <summary>
/// Holds the data built up while a grammar is parsed: a table mapping symbol names
/// to ids and the list of rules.
///
/// Source:
/// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h
///
/// The commit hash in the URL identifies the llama.cpp revision that this C# code was translated from.
/// </summary>
internal class ParseState
{
    /// <summary>
    /// The rules, each one a list of grammar elements.
    /// </summary>
    public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>();

    /// <summary>
    /// Maps a symbol name to its numeric id.
    /// </summary>
    public Dictionary<string, uint> SymbolIds { get; } = new Dictionary<string, uint>();

    /// <summary>
    /// Lazily enumerates the entries of <see cref="Rules"/> in list order.
    /// </summary>
    /// <returns>A sequence yielding each rule's element list.</returns>
    public IEnumerable<List<LLamaGrammarElement>> CRules()
    {
        foreach (List<LLamaGrammarElement> elements in Rules)
        {
            yield return elements;
        }
    }
}
}

Loading…
Cancel
Save