From 3c919b56fec031b51e2c22d3fd10a76ca5d16261 Mon Sep 17 00:00:00 2001 From: Mihai Date: Wed, 30 Aug 2023 11:23:55 +0300 Subject: [PATCH] Use ReadOnlySpan everywhere instead of ReadOnlyMemory and, instead of returning a tuple, pass the ReadOnlySpan by reference. --- LLama/Grammar/GrammarParser.cs | 99 +++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs index 1a990ffb..8ae8209d 100644 --- a/LLama/Grammar/GrammarParser.cs +++ b/LLama/Grammar/GrammarParser.cs @@ -15,12 +15,11 @@ namespace LLama.Grammar { // NOTE: assumes valid utf8 (but checks for overrun) // copied from llama.cpp - public (uint, ReadOnlyMemory) DecodeUTF8(ReadOnlyMemory src) + public uint DecodeUTF8(ref ReadOnlySpan src) { - ReadOnlySpan span = src.Span; int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - byte firstByte = span[0]; + byte firstByte = src[0]; byte highbits = (byte)(firstByte >> 4); int len = lookup[highbits]; byte mask = (byte)((1 << (8 - len)) - 1); @@ -31,18 +30,18 @@ namespace LLama.Grammar for (; pos < end && pos < src.Length; pos++) { - value = (uint)((value << 6) + (span[pos] & 0x3F)); + value = (uint)((value << 6) + (src[pos] & 0x3F)); } - ReadOnlyMemory nextSpan = src.Slice(pos); + src = src.Slice(pos); - return (value, nextSpan); + return value; } public uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) { uint nextId = (uint)state.SymbolIds.Count; - string key = src.Slice(0, len).ToString(); + string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray()); if (state.SymbolIds.TryGetValue(key, out uint existingId)) { @@ -78,18 +77,16 @@ namespace LLama.Grammar return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); } - public (uint, ReadOnlyMemory) ParseHex(ReadOnlyMemory src, int size) + public uint ParseHex(ref ReadOnlySpan src, int size) { int pos = 0; int end = size; uint value = 0; - ReadOnlySpan srcSpan 
= src.Span; - for (; pos < end && pos < src.Length; pos++) { value <<= 4; - byte c = srcSpan[pos]; + byte c = src[pos]; if ('a' <= c && c <= 'f') { value += (uint)(c - 'a' + 10); @@ -110,10 +107,10 @@ namespace LLama.Grammar if (pos != end) { - throw new InvalidOperationException($"Expecting {size} hex chars at {src}"); + throw new InvalidOperationException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}"); } - - return (value, src.Slice(pos)); + src = src.Slice(pos); + return value; } public ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) @@ -147,57 +144,55 @@ namespace LLama.Grammar } if (pos == 0) { - throw new InvalidOperationException($"Expecting name at {src.ToString()}"); + throw new InvalidOperationException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}"); } return src.Slice(pos); } - public (uint, ReadOnlyMemory) ParseChar(ReadOnlyMemory src) + public uint ParseChar(ref ReadOnlySpan src) { - ReadOnlySpan span = src.Span; - - if (span[0] == '\\') + if (src[0] == '\\') { - switch ((char)span[1]) + src = src.Slice(2); + switch ((char)src[1]) { case 'x': - return ParseHex(src.Slice(2), 2); + return ParseHex(ref src, 2); case 'u': - return ParseHex(src.Slice(2), 4); + return ParseHex(ref src, 4); case 'U': - return ParseHex(src.Slice(2), 8); + return ParseHex(ref src, 8); case 't': - return ('\t', src.Slice(2)); + return '\t'; case 'r': - return ('\r', src.Slice(2)); + return '\r'; case 'n': - return ('\n', src.Slice(2)); + return '\n'; case '\\': case '"': case '[': case ']': - return (span[1], src.Slice(2)); + return src[1]; default: - throw new Exception("Unknown escape at " + src.ToString()); + throw new Exception("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray())); } } - else if (!span.IsEmpty) + else if (!src.IsEmpty) { - return DecodeUTF8(src); + return DecodeUTF8(ref src); } throw new Exception("Unexpected end of input"); } public ReadOnlySpan ParseSequence( - ref ParseState state, - 
ReadOnlyMemory src, + ParseState state, + ReadOnlySpan pos, string ruleName, List outElements, bool isNested) { int lastSymStart = outElements.Count; - var pos = src.Span; while (!pos.IsEmpty) { @@ -208,9 +203,8 @@ namespace LLama.Grammar while (pos[0] != '"') { - var charPair = ParseChar(src); - pos = charPair.Item2.Span; - outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair.Item1 }); + var charPair = ParseChar(ref pos); + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); } pos = ParseSpace(pos.Slice(1), isNested); } @@ -229,17 +223,16 @@ namespace LLama.Grammar while (pos[0] != ']') { - var charPair = ParseChar(src); - pos = charPair.Item2.Span; + var charPair = ParseChar(ref pos); var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; - outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair.Item1 }); + outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair }); if (pos[0] == '-' && pos[1] != ']') { - var endCharPair = ParseChar(src.Slice(1)); - pos = endCharPair.Item2.Span; - outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair.Item1 }); + pos = pos.Slice(1); + var endCharPair = ParseChar(ref pos); + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair }); } } pos = ParseSpace(pos.Slice(1), isNested); @@ -321,9 +314,27 @@ namespace LLama.Grammar return pos; } - public ReadOnlySpan ParseAlternates(ParseState state, ReadOnlySpan pos, string ruleName, uint subRuleId, bool v) + public ReadOnlySpan ParseAlternates( + ParseState state, + ReadOnlySpan src, + string ruleName, + uint ruleId, + bool isNested) { - throw new NotImplementedException(); + var rule = new List(); + ReadOnlySpan pos = ParseSequence(state, src, ruleName, rule, isNested); + + while (pos[0] == '|') + { + rule.Add(new 
LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); + pos = ParseSpace(pos.Slice(1), true); + pos = ParseSequence(state, pos, ruleName, rule, isNested); + } + + rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); + AddRule(state, ruleId, rule); + + return pos; } } }