| @@ -15,12 +15,11 @@ namespace LLama.Grammar | |||
| { | |||
| // NOTE: assumes valid utf8 (but checks for overrun) | |||
| // copied from llama.cpp | |||
| public (uint, ReadOnlyMemory<byte>) DecodeUTF8(ReadOnlyMemory<byte> src) | |||
| public uint DecodeUTF8(ref ReadOnlySpan<byte> src) | |||
| { | |||
| ReadOnlySpan<byte> span = src.Span; | |||
| int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; | |||
| byte firstByte = span[0]; | |||
| byte firstByte = src[0]; | |||
| byte highbits = (byte)(firstByte >> 4); | |||
| int len = lookup[highbits]; | |||
| byte mask = (byte)((1 << (8 - len)) - 1); | |||
| @@ -31,18 +30,18 @@ namespace LLama.Grammar | |||
| for (; pos < end && pos < src.Length; pos++) | |||
| { | |||
| value = (uint)((value << 6) + (span[pos] & 0x3F)); | |||
| value = (uint)((value << 6) + (src[pos] & 0x3F)); | |||
| } | |||
| ReadOnlyMemory<byte> nextSpan = src.Slice(pos); | |||
| src = src.Slice(pos); | |||
| return (value, nextSpan); | |||
| return value; | |||
| } | |||
| public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len) | |||
| { | |||
| uint nextId = (uint)state.SymbolIds.Count; | |||
| string key = src.Slice(0, len).ToString(); | |||
| string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray()); | |||
| if (state.SymbolIds.TryGetValue(key, out uint existingId)) | |||
| { | |||
| @@ -78,18 +77,16 @@ namespace LLama.Grammar | |||
| return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); | |||
| } | |||
| public (uint, ReadOnlyMemory<byte>) ParseHex(ReadOnlyMemory<byte> src, int size) | |||
| public uint ParseHex(ref ReadOnlySpan<byte> src, int size) | |||
| { | |||
| int pos = 0; | |||
| int end = size; | |||
| uint value = 0; | |||
| ReadOnlySpan<byte> srcSpan = src.Span; | |||
| for (; pos < end && pos < src.Length; pos++) | |||
| { | |||
| value <<= 4; | |||
| byte c = srcSpan[pos]; | |||
| byte c = src[pos]; | |||
| if ('a' <= c && c <= 'f') | |||
| { | |||
| value += (uint)(c - 'a' + 10); | |||
| @@ -110,10 +107,10 @@ namespace LLama.Grammar | |||
| if (pos != end) | |||
| { | |||
| throw new InvalidOperationException($"Expecting {size} hex chars at {src}"); | |||
| throw new InvalidOperationException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}"); | |||
| } | |||
| return (value, src.Slice(pos)); | |||
| src = src.Slice(pos); | |||
| return value; | |||
| } | |||
| public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk) | |||
| @@ -147,57 +144,55 @@ namespace LLama.Grammar | |||
| } | |||
| if (pos == 0) | |||
| { | |||
| throw new InvalidOperationException($"Expecting name at {src.ToString()}"); | |||
| throw new InvalidOperationException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}"); | |||
| } | |||
| return src.Slice(pos); | |||
| } | |||
| public (uint, ReadOnlyMemory<byte>) ParseChar(ReadOnlyMemory<byte> src) | |||
| public uint ParseChar(ref ReadOnlySpan<byte> src) | |||
| { | |||
| ReadOnlySpan<byte> span = src.Span; | |||
| if (span[0] == '\\') | |||
| if (src[0] == '\\') | |||
| { | |||
| switch ((char)span[1]) | |||
| src = src.Slice(2); | |||
| switch ((char)src[1]) | |||
| { | |||
| case 'x': | |||
| return ParseHex(src.Slice(2), 2); | |||
| return ParseHex(ref src, 2); | |||
| case 'u': | |||
| return ParseHex(src.Slice(2), 4); | |||
| return ParseHex(ref src, 4); | |||
| case 'U': | |||
| return ParseHex(src.Slice(2), 8); | |||
| return ParseHex(ref src, 8); | |||
| case 't': | |||
| return ('\t', src.Slice(2)); | |||
| return '\t'; | |||
| case 'r': | |||
| return ('\r', src.Slice(2)); | |||
| return '\r'; | |||
| case 'n': | |||
| return ('\n', src.Slice(2)); | |||
| return '\n'; | |||
| case '\\': | |||
| case '"': | |||
| case '[': | |||
| case ']': | |||
| return (span[1], src.Slice(2)); | |||
| return src[1]; | |||
| default: | |||
| throw new Exception("Unknown escape at " + src.ToString()); | |||
| throw new Exception("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray())); | |||
| } | |||
| } | |||
| else if (!span.IsEmpty) | |||
| else if (!src.IsEmpty) | |||
| { | |||
| return DecodeUTF8(src); | |||
| return DecodeUTF8(ref src); | |||
| } | |||
| throw new Exception("Unexpected end of input"); | |||
| } | |||
| public ReadOnlySpan<byte> ParseSequence( | |||
| ref ParseState state, | |||
| ReadOnlyMemory<byte> src, | |||
| ParseState state, | |||
| ReadOnlySpan<byte> pos, | |||
| string ruleName, | |||
| List<LLamaGrammarElement> outElements, | |||
| bool isNested) | |||
| { | |||
| int lastSymStart = outElements.Count; | |||
| var pos = src.Span; | |||
| while (!pos.IsEmpty) | |||
| { | |||
| @@ -208,9 +203,8 @@ namespace LLama.Grammar | |||
| while (pos[0] != '"') | |||
| { | |||
| var charPair = ParseChar(src); | |||
| pos = charPair.Item2.Span; | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair.Item1 }); | |||
| var charPair = ParseChar(ref pos); | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); | |||
| } | |||
| pos = ParseSpace(pos.Slice(1), isNested); | |||
| } | |||
| @@ -229,17 +223,16 @@ namespace LLama.Grammar | |||
| while (pos[0] != ']') | |||
| { | |||
| var charPair = ParseChar(src); | |||
| pos = charPair.Item2.Span; | |||
| var charPair = ParseChar(ref pos); | |||
| var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; | |||
| outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair.Item1 }); | |||
| outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair }); | |||
| if (pos[0] == '-' && pos[1] != ']') | |||
| { | |||
| var endCharPair = ParseChar(src.Slice(1)); | |||
| pos = endCharPair.Item2.Span; | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair.Item1 }); | |||
| pos = pos.Slice(1); | |||
| var endCharPair = ParseChar(ref pos); | |||
| outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair }); | |||
| } | |||
| } | |||
| pos = ParseSpace(pos.Slice(1), isNested); | |||
| @@ -321,9 +314,27 @@ namespace LLama.Grammar | |||
| return pos; | |||
| } | |||
| public ReadOnlySpan<byte> ParseAlternates(ParseState state, ReadOnlySpan<byte> pos, string ruleName, uint subRuleId, bool v) | |||
| public ReadOnlySpan<byte> ParseAlternates( | |||
| ParseState state, | |||
| ReadOnlySpan<byte> src, | |||
| string ruleName, | |||
| uint ruleId, | |||
| bool isNested) | |||
| { | |||
| throw new NotImplementedException(); | |||
| var rule = new List<LLamaGrammarElement>(); | |||
| ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested); | |||
| while (pos[0] == '|') | |||
| { | |||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); | |||
| pos = ParseSpace(pos.Slice(1), true); | |||
| pos = ParseSequence(state, pos, ruleName, rule, isNested); | |||
| } | |||
| rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); | |||
| AddRule(state, ruleId, rule); | |||
| return pos; | |||
| } | |||
| } | |||
| } | |||