Browse Source

Use ReadOnlySpan everywhere instead of ReadOnlyMemeory and instead of returning tuple, reference the ReadOnlySpan.

tags/v0.5.1
Mihai 2 years ago
parent
commit
3c919b56fe
1 changed files with 55 additions and 44 deletions
  1. +55
    -44
      LLama/Grammar/GrammarParser.cs

+ 55
- 44
LLama/Grammar/GrammarParser.cs View File

@@ -15,12 +15,11 @@ namespace LLama.Grammar
{ {
// NOTE: assumes valid utf8 (but checks for overrun) // NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp // copied from llama.cpp
public (uint, ReadOnlyMemory<byte>) DecodeUTF8(ReadOnlyMemory<byte> src)
public uint DecodeUTF8(ref ReadOnlySpan<byte> src)
{ {
ReadOnlySpan<byte> span = src.Span;
int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };


byte firstByte = span[0];
byte firstByte = src[0];
byte highbits = (byte)(firstByte >> 4); byte highbits = (byte)(firstByte >> 4);
int len = lookup[highbits]; int len = lookup[highbits];
byte mask = (byte)((1 << (8 - len)) - 1); byte mask = (byte)((1 << (8 - len)) - 1);
@@ -31,18 +30,18 @@ namespace LLama.Grammar


for (; pos < end && pos < src.Length; pos++) for (; pos < end && pos < src.Length; pos++)
{ {
value = (uint)((value << 6) + (span[pos] & 0x3F));
value = (uint)((value << 6) + (src[pos] & 0x3F));
} }


ReadOnlyMemory<byte> nextSpan = src.Slice(pos);
src = src.Slice(pos);


return (value, nextSpan);
return value;
} }


public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len) public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
{ {
uint nextId = (uint)state.SymbolIds.Count; uint nextId = (uint)state.SymbolIds.Count;
string key = src.Slice(0, len).ToString();
string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray());


if (state.SymbolIds.TryGetValue(key, out uint existingId)) if (state.SymbolIds.TryGetValue(key, out uint existingId))
{ {
@@ -78,18 +77,16 @@ namespace LLama.Grammar
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
} }


public (uint, ReadOnlyMemory<byte>) ParseHex(ReadOnlyMemory<byte> src, int size)
public uint ParseHex(ref ReadOnlySpan<byte> src, int size)
{ {
int pos = 0; int pos = 0;
int end = size; int end = size;
uint value = 0; uint value = 0;


ReadOnlySpan<byte> srcSpan = src.Span;

for (; pos < end && pos < src.Length; pos++) for (; pos < end && pos < src.Length; pos++)
{ {
value <<= 4; value <<= 4;
byte c = srcSpan[pos];
byte c = src[pos];
if ('a' <= c && c <= 'f') if ('a' <= c && c <= 'f')
{ {
value += (uint)(c - 'a' + 10); value += (uint)(c - 'a' + 10);
@@ -110,10 +107,10 @@ namespace LLama.Grammar


if (pos != end) if (pos != end)
{ {
throw new InvalidOperationException($"Expecting {size} hex chars at {src}");
throw new InvalidOperationException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}");
} }
return (value, src.Slice(pos));
src = src.Slice(pos);
return value;
} }


public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk) public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
@@ -147,57 +144,55 @@ namespace LLama.Grammar
} }
if (pos == 0) if (pos == 0)
{ {
throw new InvalidOperationException($"Expecting name at {src.ToString()}");
throw new InvalidOperationException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}");
} }
return src.Slice(pos); return src.Slice(pos);
} }


public (uint, ReadOnlyMemory<byte>) ParseChar(ReadOnlyMemory<byte> src)
public uint ParseChar(ref ReadOnlySpan<byte> src)
{ {
ReadOnlySpan<byte> span = src.Span;

if (span[0] == '\\')
if (src[0] == '\\')
{ {
switch ((char)span[1])
src = src.Slice(2);
switch ((char)src[1])
{ {
case 'x': case 'x':
return ParseHex(src.Slice(2), 2);
return ParseHex(ref src, 2);
case 'u': case 'u':
return ParseHex(src.Slice(2), 4);
return ParseHex(ref src, 4);
case 'U': case 'U':
return ParseHex(src.Slice(2), 8);
return ParseHex(ref src, 8);
case 't': case 't':
return ('\t', src.Slice(2));
return '\t';
case 'r': case 'r':
return ('\r', src.Slice(2));
return '\r';
case 'n': case 'n':
return ('\n', src.Slice(2));
return '\n';
case '\\': case '\\':
case '"': case '"':
case '[': case '[':
case ']': case ']':
return (span[1], src.Slice(2));
return src[1];
default: default:
throw new Exception("Unknown escape at " + src.ToString());
throw new Exception("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray()));
} }
} }
else if (!span.IsEmpty)
else if (!src.IsEmpty)
{ {
return DecodeUTF8(src);
return DecodeUTF8(ref src);
} }


throw new Exception("Unexpected end of input"); throw new Exception("Unexpected end of input");
} }


public ReadOnlySpan<byte> ParseSequence( public ReadOnlySpan<byte> ParseSequence(
ref ParseState state,
ReadOnlyMemory<byte> src,
ParseState state,
ReadOnlySpan<byte> pos,
string ruleName, string ruleName,
List<LLamaGrammarElement> outElements, List<LLamaGrammarElement> outElements,
bool isNested) bool isNested)
{ {
int lastSymStart = outElements.Count; int lastSymStart = outElements.Count;
var pos = src.Span;


while (!pos.IsEmpty) while (!pos.IsEmpty)
{ {
@@ -208,9 +203,8 @@ namespace LLama.Grammar


while (pos[0] != '"') while (pos[0] != '"')
{ {
var charPair = ParseChar(src);
pos = charPair.Item2.Span;
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair.Item1 });
var charPair = ParseChar(ref pos);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair });
} }
pos = ParseSpace(pos.Slice(1), isNested); pos = ParseSpace(pos.Slice(1), isNested);
} }
@@ -229,17 +223,16 @@ namespace LLama.Grammar


while (pos[0] != ']') while (pos[0] != ']')
{ {
var charPair = ParseChar(src);
pos = charPair.Item2.Span;
var charPair = ParseChar(ref pos);
var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;


outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair.Item1 });
outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair });


if (pos[0] == '-' && pos[1] != ']') if (pos[0] == '-' && pos[1] != ']')
{ {
var endCharPair = ParseChar(src.Slice(1));
pos = endCharPair.Item2.Span;
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair.Item1 });
pos = pos.Slice(1);
var endCharPair = ParseChar(ref pos);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair });
} }
} }
pos = ParseSpace(pos.Slice(1), isNested); pos = ParseSpace(pos.Slice(1), isNested);
@@ -321,9 +314,27 @@ namespace LLama.Grammar
return pos; return pos;
} }


public ReadOnlySpan<byte> ParseAlternates(ParseState state, ReadOnlySpan<byte> pos, string ruleName, uint subRuleId, bool v)
public ReadOnlySpan<byte> ParseAlternates(
ParseState state,
ReadOnlySpan<byte> src,
string ruleName,
uint ruleId,
bool isNested)
{ {
throw new NotImplementedException();
var rule = new List<LLamaGrammarElement>();
ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested);

while (pos[0] == '|')
{
rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
pos = ParseSpace(pos.Slice(1), true);
pos = ParseSequence(state, pos, ruleName, rule, isNested);
}

rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });
AddRule(state, ruleId, rule);

return pos;
} }
} }
} }

Loading…
Cancel
Save