- Covered text transforms - Removed unnecessary non-async transforms tags/v0.6.0
| @@ -1,4 +1,5 @@ | |||
| using LLama.Exceptions; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| using LLama.Native; | |||
| using LLama.Grammars; | |||
| @@ -211,6 +212,61 @@ namespace LLama.Unittest | |||
| CheckGrammar(grammarBytes, "root", expected, expectedRules); | |||
| } | |||
// Negated character class: the grammar `root ::= [^a]` is passed to CheckGrammar
// (a helper defined outside this excerpt) together with the symbol table and the
// rule elements the test expects for it.
| [Fact] | |||
| public void ParseGrammarNotSequence() | |||
| { | |||
| var grammarBytes = @"root ::= [^a]"; | |||
// Expected symbol table: a single rule named "root" with id 0.
| var expected = new List<KeyValuePair<string, uint>> | |||
| { | |||
| new KeyValuePair<string, uint>("root", 0), | |||
| }; | |||
// Expected elements: CHAR_NOT carrying 97 (the code point of 'a'), then END.
| var expectedRules = new List<LLamaGrammarElement> | |||
| { | |||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_NOT, 97), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||
| }; | |||
| CheckGrammar(grammarBytes, "root", expected, expectedRules); | |||
| } | |||
// Multibyte input: the grammar `root ::= [罗]*` contains a character outside ASCII.
// The grammar, the expected symbol table and the expected rule elements are passed
// to CheckGrammar (a helper defined outside this excerpt).
| [Fact] | |||
| public void ParseGrammarWithMultibyteCharacter() | |||
| { | |||
| var grammarBytes = @"root ::= [罗]*"; | |||
// Expected symbol table: "root" (id 0) plus a second rule "root_1" (id 1);
// presumably the parser generates "root_1" for the `*` repetition - confirm in the parser.
| var expected = new List<KeyValuePair<string, uint>> | |||
| { | |||
| new KeyValuePair<string, uint>("root", 0), | |||
| new KeyValuePair<string, uint>("root_1", 1), | |||
| }; | |||
// Expected elements, in order. 32599 is the code point U+7F57 of '罗', i.e. the test
// expects the character to be stored as one decoded code point rather than as raw bytes.
| var expectedRules = new List<LLamaGrammarElement> | |||
| { | |||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32599), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||
| }; | |||
| CheckGrammar(grammarBytes, "root", expected, expectedRules); | |||
| } | |||
// Invalid grammar: the rule uses `:=` where the other grammars in this file use `::=`.
// Parsing it is expected to throw GrammarExpectedNext.
| [Fact] | |||
| public void InvalidGrammarMissingRuleDefinition() | |||
| { | |||
| var parsedGrammar = new GBNFGrammarParser(); | |||
| var grammarBytes = @"root := [^a]"; | |||
| Assert.Throws<GrammarExpectedNext>(() => | |||
| { | |||
| parsedGrammar.Parse(grammarBytes, "root"); | |||
| }); | |||
| } | |||
| [Fact] | |||
| public void InvalidGrammarNoClosingBracket() | |||
| @@ -269,6 +325,37 @@ namespace LLama.Unittest | |||
| }); | |||
| } | |||
// Invalid grammar: the root rule contains the string literal "\z" (written as ""\z""
// because the grammar is a C# verbatim string, where "" stands for one quote).
// Parsing it is expected to throw GrammarUnknownEscapeCharacter.
| [Fact] | |||
| public void InvalidGrammarBadEscapeCharacter() | |||
| { | |||
| var parsedGrammar = new GBNFGrammarParser(); | |||
| var grammarBytes = @" | |||
| root ::= (expr ""="" ws term ""\z"")+ ## <--- `\z` is not a valid escape character | |||
| expr ::= term ([-+*/] term)* | |||
| term ::= ident | num | ""("" ws expr "")"" ws | |||
| ident ::= [a-z] [a-z0-9_]* ws | |||
| num ::= [0-9]+ ws | |||
| ws ::= [ \t\n]* | |||
| "; | |||
// The parse itself must throw; no parse result is inspected.
| Assert.Throws<GrammarUnknownEscapeCharacter>(() => | |||
| { | |||
| parsedGrammar.Parse(grammarBytes, "root"); | |||
| }); | |||
| } | |||
// Invalid grammar: the text stops right after a backslash inside a string literal,
// so the escape sequence has no character following it.
// Parsing it is expected to throw GrammarUnexpectedEndOfInput.
| [Fact] | |||
| public void InvalidGrammarUnexpectedEndOfInput() | |||
| { | |||
| var parsedGrammar = new GBNFGrammarParser(); | |||
// In this verbatim string "" is one quote, so the grammar text ends with: term "\
| var grammarBytes = @"root ::= (expr ""="" ws term ""\"; | |||
| Assert.Throws<GrammarUnexpectedEndOfInput>(() => | |||
| { | |||
| parsedGrammar.Parse(grammarBytes, "root"); | |||
| }); | |||
| } | |||
| [Fact] | |||
| public void InvalidRuleNoElements() | |||
| @@ -43,7 +43,7 @@ namespace LLama.Unittest | |||
| Assert.Equal(result1, result2); | |||
| } | |||
| [Fact] | |||
| [Fact(Skip = "Very very slow in CI")] | |||
| public async Task OutOfContext() | |||
| { | |||
| var executor = new StatelessExecutor(_weights, _params); | |||
| @@ -0,0 +1,27 @@ | |||
| namespace LLama.Unittest | |||
| { | |||
// Unit tests for two of the simple transforms nested in LLamaTransforms.
| public sealed class TextTransformTests | |||
| { | |||
// NaiveTextInputTransform.Transform is expected to return "hello" for inputs with
// leading and/or trailing whitespace, and to leave the space inside "hello world" intact.
| [Fact] | |||
| public void NaiveTextInputTransformTrimsText() | |||
| { | |||
| var transform = new LLamaTransforms.NaiveTextInputTransform(); | |||
| Assert.Equal("hello", transform.Transform("hello")); | |||
| Assert.Equal("hello", transform.Transform(" hello")); | |||
| Assert.Equal("hello", transform.Transform("hello ")); | |||
| Assert.Equal("hello", transform.Transform(" hello ")); | |||
| Assert.Equal("hello world", transform.Transform(" hello world ")); | |||
| } | |||
// EmptyTextOutputStreamTransform.TransformAsync is expected to yield exactly the
// input tokens, in the same order, when the async stream is collected to an array.
| [Fact] | |||
| public async Task EmptyTextOutputStreamTransformDoesNothing() | |||
| { | |||
| var input = new[] { "Hello", "world" }; | |||
| var transform = new LLamaTransforms.EmptyTextOutputStreamTransform(); | |||
| Assert.Equal(input, await transform.TransformAsync(input.ToAsyncEnumerable()).ToArrayAsync()); | |||
| } | |||
| } | |||
| } | |||
| @@ -155,35 +155,27 @@ namespace LLama.Grammars | |||
| { | |||
| if (src[0] == '\\') | |||
| { | |||
| if (src.Length < 2) | |||
| throw new GrammarUnexpectedEndOfInput(); | |||
| var chr = src[1]; | |||
| src = src.Slice(2); | |||
| switch (chr) | |||
| return (char)chr switch | |||
| { | |||
| case (byte)'x': | |||
| return ParseHex(ref src, 2); | |||
| case (byte)'u': | |||
| return ParseHex(ref src, 4); | |||
| case (byte)'U': | |||
| return ParseHex(ref src, 8); | |||
| case (byte)'t': | |||
| return '\t'; | |||
| case (byte)'r': | |||
| return '\r'; | |||
| case (byte)'n': | |||
| return '\n'; | |||
| case (byte)'\\': | |||
| case (byte)'"': | |||
| case (byte)'[': | |||
| case (byte)']': | |||
| return chr; | |||
| default: | |||
| throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())); | |||
| } | |||
| 'x' => ParseHex(ref src, 2), | |||
| 'u' => ParseHex(ref src, 4), | |||
| 'U' => ParseHex(ref src, 8), | |||
| 't' => '\t', | |||
| 'r' => '\r', | |||
| 'n' => '\n', | |||
| '\\' or '"' or '[' or ']' => chr, | |||
| _ => throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())), | |||
| }; | |||
| } | |||
| else if (!src.IsEmpty) | |||
| { | |||
| if (!src.IsEmpty) | |||
| return DecodeUTF8(ref src); | |||
| } | |||
| throw new GrammarUnexpectedEndOfInput(); | |||
| } | |||
| @@ -18,16 +18,17 @@ namespace LLama | |||
| /// </summary> | |||
| public class DefaultHistoryTransform : IHistoryTransform | |||
| { | |||
| private readonly string defaultUserName = "User"; | |||
| private readonly string defaultAssistantName = "Assistant"; | |||
| private readonly string defaultSystemName = "System"; | |||
| private readonly string defaultUnknownName = "??"; | |||
| private const string defaultUserName = "User"; | |||
| private const string defaultAssistantName = "Assistant"; | |||
| private const string defaultSystemName = "System"; | |||
| private const string defaultUnknownName = "??"; | |||
| private readonly string _userName; | |||
| private readonly string _assistantName; | |||
| private readonly string _systemName; | |||
| private readonly string _unknownName; | |||
| private readonly bool _isInstructMode; | |||
| string _userName; | |||
| string _assistantName; | |||
| string _systemName; | |||
| string _unknownName; | |||
| bool _isInstructMode; | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| @@ -107,46 +108,37 @@ namespace LLama | |||
| /// <summary> | |||
| /// A text input transform that only trims the text: <see cref="Transform"/> returns the input with leading and trailing whitespace removed. | |||
| /// </summary> | |||
| public class NaiveTextInputTransform : ITextTransform | |||
| public class NaiveTextInputTransform | |||
| : ITextTransform | |||
| { | |||
| /// <summary> | |||
| /// Creates the transform. It takes no arguments and performs no initialisation. | |||
| /// </summary> | |||
| public NaiveTextInputTransform() | |||
| { | |||
| } | |||
| /// <inheritdoc /> | |||
| public string Transform(string text) | |||
| { | |||
| return text.Trim(); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// A no-op text output stream transform: the token sequence it is given is returned unchanged. | |||
| /// </summary> | |||
| public class EmptyTextOutputStreamTransform : ITextStreamTransform | |||
| public class EmptyTextOutputStreamTransform | |||
| : ITextStreamTransform | |||
| { | |||
| /// <inheritdoc /> | |||
| public IEnumerable<string> Transform(IEnumerable<string> tokens) | |||
| { | |||
| return tokens; | |||
| } | |||
| /// <inheritdoc /> | |||
| public IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) | |||
| { | |||
| return tokens; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// A text output transform that removes the keywords from the response. | |||
| /// </summary> | |||
| public class KeywordTextOutputStreamTransform : ITextStreamTransform | |||
| { | |||
| HashSet<string> _keywords; | |||
| int _maxKeywordLength; | |||
| bool _removeAllMatchedTokens; | |||
| private readonly HashSet<string> _keywords; | |||
| private readonly int _maxKeywordLength; | |||
| private readonly bool _removeAllMatchedTokens; | |||
| /// <summary> | |||
| /// | |||
| @@ -164,59 +156,7 @@ namespace LLama | |||
| _maxKeywordLength = _keywords.Select(x => x.Length).Max() + redundancyLength; | |||
| _removeAllMatchedTokens = removeAllMatchedTokens; | |||
| } | |||
// Synchronous keyword filter. Tokens are buffered in a queue; after each token the
// buffered text is joined and searched for any configured keyword.
//  - On a match the buffer is emptied and, unless _removeAllMatchedTokens is set, the
//    joined text is yielded with every occurrence of the matched keyword removed.
//  - When the joined text reaches _maxKeywordLength without a match, the buffered
//    tokens are yielded unchanged.
//  - Tokens still buffered when the input ends are yielded unchanged.
| /// <inheritdoc /> | |||
| public IEnumerable<string> Transform(IEnumerable<string> tokens) | |||
| { | |||
| var window = new Queue<string>(); | |||
| foreach (var s in tokens) | |||
| { | |||
| window.Enqueue(s); | |||
// Joined text of everything currently buffered; computed once per token and reused below.
| var current = string.Join("", window); | |||
| if (_keywords.Any(x => current.Contains(x))) | |||
| { | |||
// Only the first keyword found (in the set's enumeration order) is removed from the text.
| var matchedKeyword = _keywords.First(x => current.Contains(x)); | |||
| int total = window.Count; | |||
| for (int i = 0; i < total; i++) | |||
| { | |||
| window.Dequeue(); | |||
| } | |||
| if (!_removeAllMatchedTokens) | |||
| { | |||
| yield return current.Replace(matchedKeyword, ""); | |||
| } | |||
| } | |||
// NOTE(review): `current` is not recomputed after the match above, so when it matched
// and is also at least _maxKeywordLength long, the keyword branch below runs again on
// the same text and yields the same replaced string a second time - confirm whether
// that duplicate output is intended.
| if (current.Length >= _maxKeywordLength) | |||
| { | |||
| if (_keywords.Any(x => current.Contains(x))) | |||
| { | |||
| var matchedKeyword = _keywords.First(x => current.Contains(x)); | |||
| int total = window.Count; | |||
| for (int i = 0; i < total; i++) | |||
| { | |||
| window.Dequeue(); | |||
| } | |||
| if (!_removeAllMatchedTokens) | |||
| { | |||
| yield return current.Replace(matchedKeyword, ""); | |||
| } | |||
| } | |||
| else | |||
| { | |||
// No keyword in the buffered text: release the buffered tokens one by one, unmodified.
| int total = window.Count; | |||
| for (int i = 0; i < total; i++) | |||
| { | |||
| yield return window.Dequeue(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
// End of input: flush whatever is still buffered.
| int totalCount = window.Count; | |||
| for (int i = 0; i < totalCount; i++) | |||
| { | |||
| yield return window.Dequeue(); | |||
| } | |||
| } | |||
| /// <inheritdoc /> | |||
| public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) | |||
| { | |||