| @@ -15,12 +15,11 @@ namespace LLama.Grammar | |||||
| { | { | ||||
| // NOTE: assumes valid utf8 (but checks for overrun) | // NOTE: assumes valid utf8 (but checks for overrun) | ||||
| // copied from llama.cpp | // copied from llama.cpp | ||||
| public (uint, ReadOnlyMemory<byte>) DecodeUTF8(ReadOnlyMemory<byte> src) | |||||
| public uint DecodeUTF8(ref ReadOnlySpan<byte> src) | |||||
| { | { | ||||
| ReadOnlySpan<byte> span = src.Span; | |||||
| int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; | int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; | ||||
| byte firstByte = span[0]; | |||||
| byte firstByte = src[0]; | |||||
| byte highbits = (byte)(firstByte >> 4); | byte highbits = (byte)(firstByte >> 4); | ||||
| int len = lookup[highbits]; | int len = lookup[highbits]; | ||||
| byte mask = (byte)((1 << (8 - len)) - 1); | byte mask = (byte)((1 << (8 - len)) - 1); | ||||
| @@ -31,18 +30,18 @@ namespace LLama.Grammar | |||||
| for (; pos < end && pos < src.Length; pos++) | for (; pos < end && pos < src.Length; pos++) | ||||
| { | { | ||||
| value = (uint)((value << 6) + (span[pos] & 0x3F)); | |||||
| value = (uint)((value << 6) + (src[pos] & 0x3F)); | |||||
| } | } | ||||
| ReadOnlyMemory<byte> nextSpan = src.Slice(pos); | |||||
| src = src.Slice(pos); | |||||
| return (value, nextSpan); | |||||
| return value; | |||||
| } | } | ||||
| public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len) | public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len) | ||||
| { | { | ||||
| uint nextId = (uint)state.SymbolIds.Count; | uint nextId = (uint)state.SymbolIds.Count; | ||||
| string key = src.Slice(0, len).ToString(); | |||||
| string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray()); | |||||
| if (state.SymbolIds.TryGetValue(key, out uint existingId)) | if (state.SymbolIds.TryGetValue(key, out uint existingId)) | ||||
| { | { | ||||
| @@ -78,18 +77,16 @@ namespace LLama.Grammar | |||||
| return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); | return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); | ||||
| } | } | ||||
| public (uint, ReadOnlyMemory<byte>) ParseHex(ReadOnlyMemory<byte> src, int size) | |||||
| public uint ParseHex(ref ReadOnlySpan<byte> src, int size) | |||||
| { | { | ||||
| int pos = 0; | int pos = 0; | ||||
| int end = size; | int end = size; | ||||
| uint value = 0; | uint value = 0; | ||||
| ReadOnlySpan<byte> srcSpan = src.Span; | |||||
| for (; pos < end && pos < src.Length; pos++) | for (; pos < end && pos < src.Length; pos++) | ||||
| { | { | ||||
| value <<= 4; | value <<= 4; | ||||
| byte c = srcSpan[pos]; | |||||
| byte c = src[pos]; | |||||
| if ('a' <= c && c <= 'f') | if ('a' <= c && c <= 'f') | ||||
| { | { | ||||
| value += (uint)(c - 'a' + 10); | value += (uint)(c - 'a' + 10); | ||||
| @@ -110,10 +107,10 @@ namespace LLama.Grammar | |||||
| if (pos != end) | if (pos != end) | ||||
| { | { | ||||
| throw new InvalidOperationException($"Expecting {size} hex chars at {src}"); | |||||
| throw new InvalidOperationException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}"); | |||||
| } | } | ||||
| return (value, src.Slice(pos)); | |||||
| src = src.Slice(pos); | |||||
| return value; | |||||
| } | } | ||||
| public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk) | public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk) | ||||
| @@ -147,57 +144,55 @@ namespace LLama.Grammar | |||||
| } | } | ||||
| if (pos == 0) | if (pos == 0) | ||||
| { | { | ||||
| throw new InvalidOperationException($"Expecting name at {src.ToString()}"); | |||||
| throw new InvalidOperationException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}"); | |||||
| } | } | ||||
| return src.Slice(pos); | return src.Slice(pos); | ||||
| } | } | ||||
| public (uint, ReadOnlyMemory<byte>) ParseChar(ReadOnlyMemory<byte> src) | |||||
| public uint ParseChar(ref ReadOnlySpan<byte> src) | |||||
| { | { | ||||
| ReadOnlySpan<byte> span = src.Span; | |||||
| if (span[0] == '\\') | |||||
| if (src[0] == '\\') | |||||
| { | { | ||||
| switch ((char)span[1]) | |||||
| src = src.Slice(2); | |||||
| switch ((char)src[1]) | |||||
| { | { | ||||
| case 'x': | case 'x': | ||||
| return ParseHex(src.Slice(2), 2); | |||||
| return ParseHex(ref src, 2); | |||||
| case 'u': | case 'u': | ||||
| return ParseHex(src.Slice(2), 4); | |||||
| return ParseHex(ref src, 4); | |||||
| case 'U': | case 'U': | ||||
| return ParseHex(src.Slice(2), 8); | |||||
| return ParseHex(ref src, 8); | |||||
| case 't': | case 't': | ||||
| return ('\t', src.Slice(2)); | |||||
| return '\t'; | |||||
| case 'r': | case 'r': | ||||
| return ('\r', src.Slice(2)); | |||||
| return '\r'; | |||||
| case 'n': | case 'n': | ||||
| return ('\n', src.Slice(2)); | |||||
| return '\n'; | |||||
| case '\\': | case '\\': | ||||
| case '"': | case '"': | ||||
| case '[': | case '[': | ||||
| case ']': | case ']': | ||||
| return (span[1], src.Slice(2)); | |||||
| return src[1]; | |||||
| default: | default: | ||||
| throw new Exception("Unknown escape at " + src.ToString()); | |||||
| throw new Exception("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray())); | |||||
| } | } | ||||
| } | } | ||||
| else if (!span.IsEmpty) | |||||
| else if (!src.IsEmpty) | |||||
| { | { | ||||
| return DecodeUTF8(src); | |||||
| return DecodeUTF8(ref src); | |||||
| } | } | ||||
| throw new Exception("Unexpected end of input"); | throw new Exception("Unexpected end of input"); | ||||
| } | } | ||||
| public ReadOnlySpan<byte> ParseSequence( | public ReadOnlySpan<byte> ParseSequence( | ||||
| ref ParseState state, | |||||
| ReadOnlyMemory<byte> src, | |||||
| ParseState state, | |||||
| ReadOnlySpan<byte> pos, | |||||
| string ruleName, | string ruleName, | ||||
| List<LLamaGrammarElement> outElements, | List<LLamaGrammarElement> outElements, | ||||
| bool isNested) | bool isNested) | ||||
| { | { | ||||
| int lastSymStart = outElements.Count; | int lastSymStart = outElements.Count; | ||||
| var pos = src.Span; | |||||
| while (!pos.IsEmpty) | while (!pos.IsEmpty) | ||||
| { | { | ||||
| @@ -208,9 +203,8 @@ namespace LLama.Grammar | |||||
| while (pos[0] != '"') | while (pos[0] != '"') | ||||
| { | { | ||||
| var charPair = ParseChar(src); | |||||
| pos = charPair.Item2.Span; | |||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair.Item1 }); | |||||
| var charPair = ParseChar(ref pos); | |||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); | |||||
| } | } | ||||
| pos = ParseSpace(pos.Slice(1), isNested); | pos = ParseSpace(pos.Slice(1), isNested); | ||||
| } | } | ||||
| @@ -229,17 +223,16 @@ namespace LLama.Grammar | |||||
| while (pos[0] != ']') | while (pos[0] != ']') | ||||
| { | { | ||||
| var charPair = ParseChar(src); | |||||
| pos = charPair.Item2.Span; | |||||
| var charPair = ParseChar(ref pos); | |||||
| var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; | var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; | ||||
| outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair.Item1 }); | |||||
| outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair }); | |||||
| if (pos[0] == '-' && pos[1] != ']') | if (pos[0] == '-' && pos[1] != ']') | ||||
| { | { | ||||
| var endCharPair = ParseChar(src.Slice(1)); | |||||
| pos = endCharPair.Item2.Span; | |||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair.Item1 }); | |||||
| pos = pos.Slice(1); | |||||
| var endCharPair = ParseChar(ref pos); | |||||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair }); | |||||
| } | } | ||||
| } | } | ||||
| pos = ParseSpace(pos.Slice(1), isNested); | pos = ParseSpace(pos.Slice(1), isNested); | ||||
| @@ -321,9 +314,27 @@ namespace LLama.Grammar | |||||
| return pos; | return pos; | ||||
| } | } | ||||
| public ReadOnlySpan<byte> ParseAlternates(ParseState state, ReadOnlySpan<byte> pos, string ruleName, uint subRuleId, bool v) | |||||
| public ReadOnlySpan<byte> ParseAlternates( | |||||
| ParseState state, | |||||
| ReadOnlySpan<byte> src, | |||||
| string ruleName, | |||||
| uint ruleId, | |||||
| bool isNested) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| var rule = new List<LLamaGrammarElement>(); | |||||
| ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested); | |||||
| while (pos[0] == '|') | |||||
| { | |||||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); | |||||
| pos = ParseSpace(pos.Slice(1), true); | |||||
| pos = ParseSequence(state, pos, ruleName, rule, isNested); | |||||
| } | |||||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); | |||||
| AddRule(state, ruleId, rule); | |||||
| return pos; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||