Browse Source

- Improved coverage of `GBNFGrammarParser` up to 96%

- Covered text transforms
 - Removed unnecessary non-async transforms
tags/v0.6.0
Martin Evans 2 years ago
parent
commit
45118520fa
5 changed files with 151 additions and 105 deletions
  1. +88
    -1
      LLama.Unittest/GrammarParserTest.cs
  2. +1
    -1
      LLama.Unittest/StatelessExecutorTest.cs
  3. +27
    -0
      LLama.Unittest/TextTransformTests.cs
  4. +16
    -24
      LLama/Grammars/GBNFGrammarParser.cs
  5. +19
    -79
      LLama/LLamaTransforms.cs

+ 88
- 1
LLama.Unittest/GrammarParserTest.cs View File

@@ -1,4 +1,5 @@
using LLama.Exceptions;
using System.Text;
using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;

@@ -211,6 +212,61 @@ namespace LLama.Unittest
CheckGrammar(grammarBytes, "root", expected, expectedRules);
}

[Fact]
public void ParseGrammarNotSequence()
{
    // A negated character class: anything except 'a'.
    const string grammarText = @"root ::= [^a]";

    // Rule name / index pairs handed to CheckGrammar; only "root" is declared.
    var expectedSymbols = new List<KeyValuePair<string, uint>>
    {
        new("root", 0),
    };

    // 97 is the code point of 'a', the character excluded by `[^a]`.
    var expectedElements = new List<LLamaGrammarElement>
    {
        new(LLamaGrammarElementType.CHAR_NOT, 97),
        new(LLamaGrammarElementType.END, 0),
    };

    CheckGrammar(grammarText, "root", expectedSymbols, expectedElements);
}

[Fact]
public void ParseGrammarWithMultibyteCharacter()
{
    // The character class holds a single character that needs more than one UTF-8 byte.
    const string grammarText = @"root ::= [罗]*";

    // Rule name / index pairs handed to CheckGrammar; "root_1" is expected alongside "root".
    var expectedSymbols = new List<KeyValuePair<string, uint>>
    {
        new("root", 0),
        new("root_1", 1),
    };

    // 32599 is U+7F57, the code point of '罗'.
    var expectedElements = new List<LLamaGrammarElement>
    {
        new(LLamaGrammarElementType.RULE_REF, 1),
        new(LLamaGrammarElementType.END, 0),
        new(LLamaGrammarElementType.CHAR, 32599),
        new(LLamaGrammarElementType.RULE_REF, 1),
        new(LLamaGrammarElementType.ALT, 0),
        new(LLamaGrammarElementType.END, 0),
    };

    CheckGrammar(grammarText, "root", expectedSymbols, expectedElements);
}


[Fact]
public void InvalidGrammarMissingRuleDefinition()
{
    var parser = new GBNFGrammarParser();

    // `:=` is used here where the other grammars in this file use `::=`.
    const string grammarText = @"root := [^a]";

    void ParseInvalidGrammar() => parser.Parse(grammarText, "root");

    Assert.Throws<GrammarExpectedNext>(ParseInvalidGrammar);
}

[Fact]
public void InvalidGrammarNoClosingBracket()
@@ -269,6 +325,37 @@ namespace LLama.Unittest
});
}

[Fact]
public void InvalidGrammarBadEscapeCharacter()
{
    var parser = new GBNFGrammarParser();

    // The root rule contains the string literal "\z", which uses an unsupported escape.
    const string grammarText = @"
root ::= (expr ""="" ws term ""\z"")+ ## <--- `\z` is not a valid escape character
expr ::= term ([-+*/] term)*
term ::= ident | num | ""("" ws expr "")"" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
";

    void ParseInvalidGrammar() => parser.Parse(grammarText, "root");

    Assert.Throws<GrammarUnknownEscapeCharacter>(ParseInvalidGrammar);
}

[Fact]
public void InvalidGrammarUnexpectedEndOfInput()
{
    var parser = new GBNFGrammarParser();

    // The grammar text stops right after a backslash inside an unterminated string literal.
    const string grammarText = @"root ::= (expr ""="" ws term ""\";

    void ParseTruncatedGrammar() => parser.Parse(grammarText, "root");

    Assert.Throws<GrammarUnexpectedEndOfInput>(ParseTruncatedGrammar);
}


[Fact]
public void InvalidRuleNoElements()


+ 1
- 1
LLama.Unittest/StatelessExecutorTest.cs View File

@@ -43,7 +43,7 @@ namespace LLama.Unittest
Assert.Equal(result1, result2);
}

[Fact]
[Fact(Skip = "Very very slow in CI")]
public async Task OutOfContext()
{
var executor = new StatelessExecutor(_weights, _params);


+ 27
- 0
LLama.Unittest/TextTransformTests.cs View File

@@ -0,0 +1,27 @@
namespace LLama.Unittest
{
    /// <summary>
    /// Tests for the text transforms nested in <c>LLamaTransforms</c>.
    /// </summary>
    public sealed class TextTransformTests
    {
        /// <summary>
        /// The naive input transform must return its input with surrounding whitespace removed.
        /// </summary>
        [Fact]
        public void NaiveTextInputTransformTrimsText()
        {
            var trimmer = new LLamaTransforms.NaiveTextInputTransform();

            // Each pair is (raw input, expected transform result), checked in order.
            var cases = new (string Raw, string Expected)[]
            {
                ("hello", "hello"),
                (" hello", "hello"),
                ("hello ", "hello"),
                (" hello ", "hello"),
                (" hello world ", "hello world"),
            };

            foreach (var (raw, expected) in cases)
            {
                Assert.Equal(expected, trimmer.Transform(raw));
            }
        }

        /// <summary>
        /// The empty output stream transform must pass every token through unchanged.
        /// </summary>
        [Fact]
        public async Task EmptyTextOutputStreamTransformDoesNothing()
        {
            var tokens = new[] { "Hello", "world" };
            var passthrough = new LLamaTransforms.EmptyTextOutputStreamTransform();

            var transformed = passthrough.TransformAsync(tokens.ToAsyncEnumerable());
            var actual = await transformed.ToArrayAsync();

            Assert.Equal(tokens, actual);
        }
    }
}

+ 16
- 24
LLama/Grammars/GBNFGrammarParser.cs View File

@@ -155,35 +155,27 @@ namespace LLama.Grammars
{
if (src[0] == '\\')
{
if (src.Length < 2)
throw new GrammarUnexpectedEndOfInput();

var chr = src[1];
src = src.Slice(2);
switch (chr)

return (char)chr switch
{
case (byte)'x':
return ParseHex(ref src, 2);
case (byte)'u':
return ParseHex(ref src, 4);
case (byte)'U':
return ParseHex(ref src, 8);
case (byte)'t':
return '\t';
case (byte)'r':
return '\r';
case (byte)'n':
return '\n';
case (byte)'\\':
case (byte)'"':
case (byte)'[':
case (byte)']':
return chr;
default:
throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray()));
}
'x' => ParseHex(ref src, 2),
'u' => ParseHex(ref src, 4),
'U' => ParseHex(ref src, 8),
't' => '\t',
'r' => '\r',
'n' => '\n',
'\\' or '"' or '[' or ']' => chr,
_ => throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())),
};
}
else if (!src.IsEmpty)
{

if (!src.IsEmpty)
return DecodeUTF8(ref src);
}

throw new GrammarUnexpectedEndOfInput();
}


+ 19
- 79
LLama/LLamaTransforms.cs View File

@@ -18,16 +18,17 @@ namespace LLama
/// </summary>
public class DefaultHistoryTransform : IHistoryTransform
{
private readonly string defaultUserName = "User";
private readonly string defaultAssistantName = "Assistant";
private readonly string defaultSystemName = "System";
private readonly string defaultUnknownName = "??";
private const string defaultUserName = "User";
private const string defaultAssistantName = "Assistant";
private const string defaultSystemName = "System";
private const string defaultUnknownName = "??";

private readonly string _userName;
private readonly string _assistantName;
private readonly string _systemName;
private readonly string _unknownName;
private readonly bool _isInstructMode;

string _userName;
string _assistantName;
string _systemName;
string _unknownName;
bool _isInstructMode;
/// <summary>
///
/// </summary>
@@ -107,46 +108,37 @@ namespace LLama
/// <summary>
/// A text input transform that only trims the text.
/// </summary>
public class NaiveTextInputTransform : ITextTransform
public class NaiveTextInputTransform
: ITextTransform
{
/// <summary>
///
/// </summary>
public NaiveTextInputTransform()
{
}
/// <inheritdoc />
public string Transform(string text)
{
return text.Trim();
}
}

/// <summary>
/// A no-op text input transform.
/// </summary>
public class EmptyTextOutputStreamTransform : ITextStreamTransform
public class EmptyTextOutputStreamTransform
: ITextStreamTransform
{
/// <inheritdoc />
public IEnumerable<string> Transform(IEnumerable<string> tokens)
{
return tokens;
}

/// <inheritdoc />
public IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
{
return tokens;
}
}

/// <summary>
/// A text output transform that removes the keywords from the response.
/// </summary>
public class KeywordTextOutputStreamTransform : ITextStreamTransform
{
HashSet<string> _keywords;
int _maxKeywordLength;
bool _removeAllMatchedTokens;
private readonly HashSet<string> _keywords;
private readonly int _maxKeywordLength;
private readonly bool _removeAllMatchedTokens;

/// <summary>
///
@@ -164,59 +156,7 @@ namespace LLama
_maxKeywordLength = _keywords.Select(x => x.Length).Max() + redundancyLength;
_removeAllMatchedTokens = removeAllMatchedTokens;
}
/// <inheritdoc />
/// <remarks>
/// Tokens are buffered in a sliding window. After each token the window is joined into one
/// string and searched for any of the configured keywords:
/// <list type="bullet">
/// <item>If a keyword is found the window is discarded; unless <c>_removeAllMatchedTokens</c>
/// is set, the joined text with that keyword removed is emitted once.</item>
/// <item>Otherwise, once the joined text reaches <c>_maxKeywordLength</c> characters, the
/// buffered tokens are flushed unchanged.</item>
/// </list>
/// Any tokens still buffered when the input ends are emitted as-is.
/// </remarks>
public IEnumerable<string> Transform(IEnumerable<string> tokens)
{
    var window = new Queue<string>();

    foreach (var s in tokens)
    {
        window.Enqueue(s);
        var current = string.Join("", window);

        // Look the keyword up once. The previous implementation tested the same (stale)
        // `current` string again inside the length check after the window had already been
        // emptied, which emitted the stripped text a second time.
        var matchedKeyword = _keywords.FirstOrDefault(x => current.Contains(x));
        if (matchedKeyword != null)
        {
            window.Clear();
            if (!_removeAllMatchedTokens)
            {
                yield return current.Replace(matchedKeyword, "");
            }
        }
        else if (current.Length >= _maxKeywordLength)
        {
            // No keyword can still be forming in text this long, so release the buffer.
            while (window.Count > 0)
            {
                yield return window.Dequeue();
            }
        }
    }

    // Flush whatever is left once the input is exhausted.
    while (window.Count > 0)
    {
        yield return window.Dequeue();
    }
}
/// <inheritdoc />
public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
{


Loading…
Cancel
Save