- Covered text transforms - Removed unnecessary non-async transformstags/v0.6.0
| @@ -1,4 +1,5 @@ | |||||
| using LLama.Exceptions; | |||||
| using System.Text; | |||||
| using LLama.Exceptions; | |||||
| using LLama.Native; | using LLama.Native; | ||||
| using LLama.Grammars; | using LLama.Grammars; | ||||
| @@ -211,6 +212,61 @@ namespace LLama.Unittest | |||||
| CheckGrammar(grammarBytes, "root", expected, expectedRules); | CheckGrammar(grammarBytes, "root", expected, expectedRules); | ||||
| } | } | ||||
| [Fact] | |||||
| public void ParseGrammarNotSequence() | |||||
| { | |||||
| var grammarBytes = @"root ::= [^a]"; | |||||
| var expected = new List<KeyValuePair<string, uint>> | |||||
| { | |||||
| new KeyValuePair<string, uint>("root", 0), | |||||
| }; | |||||
| var expectedRules = new List<LLamaGrammarElement> | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR_NOT, 97), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| }; | |||||
| CheckGrammar(grammarBytes, "root", expected, expectedRules); | |||||
| } | |||||
| [Fact] | |||||
| public void ParseGrammarWithMultibyteCharacter() | |||||
| { | |||||
| var grammarBytes = @"root ::= [罗]*"; | |||||
| var expected = new List<KeyValuePair<string, uint>> | |||||
| { | |||||
| new KeyValuePair<string, uint>("root", 0), | |||||
| new KeyValuePair<string, uint>("root_1", 1), | |||||
| }; | |||||
| var expectedRules = new List<LLamaGrammarElement> | |||||
| { | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32599), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), | |||||
| new LLamaGrammarElement(LLamaGrammarElementType.END, 0), | |||||
| }; | |||||
| CheckGrammar(grammarBytes, "root", expected, expectedRules); | |||||
| } | |||||
| [Fact] | |||||
| public void InvalidGrammarMissingRuleDefinition() | |||||
| { | |||||
| var parsedGrammar = new GBNFGrammarParser(); | |||||
| var grammarBytes = @"root := [^a]"; | |||||
| Assert.Throws<GrammarExpectedNext>(() => | |||||
| { | |||||
| parsedGrammar.Parse(grammarBytes, "root"); | |||||
| }); | |||||
| } | |||||
| [Fact] | [Fact] | ||||
| public void InvalidGrammarNoClosingBracket() | public void InvalidGrammarNoClosingBracket() | ||||
| @@ -269,6 +325,37 @@ namespace LLama.Unittest | |||||
| }); | }); | ||||
| } | } | ||||
| [Fact] | |||||
| public void InvalidGrammarBadEscapeCharacter() | |||||
| { | |||||
| var parsedGrammar = new GBNFGrammarParser(); | |||||
| var grammarBytes = @" | |||||
| root ::= (expr ""="" ws term ""\z"")+ ## <--- `\z` is not a valid escape character | |||||
| expr ::= term ([-+*/] term)* | |||||
| term ::= ident | num | ""("" ws expr "")"" ws | |||||
| ident ::= [a-z] [a-z0-9_]* ws | |||||
| num ::= [0-9]+ ws | |||||
| ws ::= [ \t\n]* | |||||
| "; | |||||
| Assert.Throws<GrammarUnknownEscapeCharacter>(() => | |||||
| { | |||||
| parsedGrammar.Parse(grammarBytes, "root"); | |||||
| }); | |||||
| } | |||||
| [Fact] | |||||
| public void InvalidGrammarUnexpectedEndOfInput() | |||||
| { | |||||
| var parsedGrammar = new GBNFGrammarParser(); | |||||
| var grammarBytes = @"root ::= (expr ""="" ws term ""\"; | |||||
| Assert.Throws<GrammarUnexpectedEndOfInput>(() => | |||||
| { | |||||
| parsedGrammar.Parse(grammarBytes, "root"); | |||||
| }); | |||||
| } | |||||
| [Fact] | [Fact] | ||||
| public void InvalidRuleNoElements() | public void InvalidRuleNoElements() | ||||
| @@ -43,7 +43,7 @@ namespace LLama.Unittest | |||||
| Assert.Equal(result1, result2); | Assert.Equal(result1, result2); | ||||
| } | } | ||||
| [Fact] | |||||
| [Fact(Skip = "Very very slow in CI")] | |||||
| public async Task OutOfContext() | public async Task OutOfContext() | ||||
| { | { | ||||
| var executor = new StatelessExecutor(_weights, _params); | var executor = new StatelessExecutor(_weights, _params); | ||||
| @@ -0,0 +1,27 @@ | |||||
| namespace LLama.Unittest | |||||
| { | |||||
| public sealed class TextTransformTests | |||||
| { | |||||
| [Fact] | |||||
| public void NaiveTextInputTransformTrimsText() | |||||
| { | |||||
| var transform = new LLamaTransforms.NaiveTextInputTransform(); | |||||
| Assert.Equal("hello", transform.Transform("hello")); | |||||
| Assert.Equal("hello", transform.Transform(" hello")); | |||||
| Assert.Equal("hello", transform.Transform("hello ")); | |||||
| Assert.Equal("hello", transform.Transform(" hello ")); | |||||
| Assert.Equal("hello world", transform.Transform(" hello world ")); | |||||
| } | |||||
| [Fact] | |||||
| public async Task EmptyTextOutputStreamTransformDoesNothing() | |||||
| { | |||||
| var input = new[] { "Hello", "world" }; | |||||
| var transform = new LLamaTransforms.EmptyTextOutputStreamTransform(); | |||||
| Assert.Equal(input, await transform.TransformAsync(input.ToAsyncEnumerable()).ToArrayAsync()); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -155,35 +155,27 @@ namespace LLama.Grammars | |||||
| { | { | ||||
| if (src[0] == '\\') | if (src[0] == '\\') | ||||
| { | { | ||||
| if (src.Length < 2) | |||||
| throw new GrammarUnexpectedEndOfInput(); | |||||
| var chr = src[1]; | var chr = src[1]; | ||||
| src = src.Slice(2); | src = src.Slice(2); | ||||
| switch (chr) | |||||
| return (char)chr switch | |||||
| { | { | ||||
| case (byte)'x': | |||||
| return ParseHex(ref src, 2); | |||||
| case (byte)'u': | |||||
| return ParseHex(ref src, 4); | |||||
| case (byte)'U': | |||||
| return ParseHex(ref src, 8); | |||||
| case (byte)'t': | |||||
| return '\t'; | |||||
| case (byte)'r': | |||||
| return '\r'; | |||||
| case (byte)'n': | |||||
| return '\n'; | |||||
| case (byte)'\\': | |||||
| case (byte)'"': | |||||
| case (byte)'[': | |||||
| case (byte)']': | |||||
| return chr; | |||||
| default: | |||||
| throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())); | |||||
| } | |||||
| 'x' => ParseHex(ref src, 2), | |||||
| 'u' => ParseHex(ref src, 4), | |||||
| 'U' => ParseHex(ref src, 8), | |||||
| 't' => '\t', | |||||
| 'r' => '\r', | |||||
| 'n' => '\n', | |||||
| '\\' or '"' or '[' or ']' => chr, | |||||
| _ => throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())), | |||||
| }; | |||||
| } | } | ||||
| else if (!src.IsEmpty) | |||||
| { | |||||
| if (!src.IsEmpty) | |||||
| return DecodeUTF8(ref src); | return DecodeUTF8(ref src); | ||||
| } | |||||
| throw new GrammarUnexpectedEndOfInput(); | throw new GrammarUnexpectedEndOfInput(); | ||||
| } | } | ||||
| @@ -18,16 +18,17 @@ namespace LLama | |||||
| /// </summary> | /// </summary> | ||||
| public class DefaultHistoryTransform : IHistoryTransform | public class DefaultHistoryTransform : IHistoryTransform | ||||
| { | { | ||||
| private readonly string defaultUserName = "User"; | |||||
| private readonly string defaultAssistantName = "Assistant"; | |||||
| private readonly string defaultSystemName = "System"; | |||||
| private readonly string defaultUnknownName = "??"; | |||||
| private const string defaultUserName = "User"; | |||||
| private const string defaultAssistantName = "Assistant"; | |||||
| private const string defaultSystemName = "System"; | |||||
| private const string defaultUnknownName = "??"; | |||||
| private readonly string _userName; | |||||
| private readonly string _assistantName; | |||||
| private readonly string _systemName; | |||||
| private readonly string _unknownName; | |||||
| private readonly bool _isInstructMode; | |||||
| string _userName; | |||||
| string _assistantName; | |||||
| string _systemName; | |||||
| string _unknownName; | |||||
| bool _isInstructMode; | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -107,46 +108,37 @@ namespace LLama | |||||
| /// <summary> | /// <summary> | ||||
| /// A text input transform that only trims the text. | /// A text input transform that only trims the text. | ||||
| /// </summary> | /// </summary> | ||||
| public class NaiveTextInputTransform : ITextTransform | |||||
| public class NaiveTextInputTransform | |||||
| : ITextTransform | |||||
| { | { | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| public NaiveTextInputTransform() | |||||
| { | |||||
| } | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public string Transform(string text) | public string Transform(string text) | ||||
| { | { | ||||
| return text.Trim(); | return text.Trim(); | ||||
| } | } | ||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// A no-op text input transform. | /// A no-op text input transform. | ||||
| /// </summary> | /// </summary> | ||||
| public class EmptyTextOutputStreamTransform : ITextStreamTransform | |||||
| public class EmptyTextOutputStreamTransform | |||||
| : ITextStreamTransform | |||||
| { | { | ||||
| /// <inheritdoc /> | |||||
| public IEnumerable<string> Transform(IEnumerable<string> tokens) | |||||
| { | |||||
| return tokens; | |||||
| } | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) | public IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) | ||||
| { | { | ||||
| return tokens; | return tokens; | ||||
| } | } | ||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// A text output transform that removes the keywords from the response. | /// A text output transform that removes the keywords from the response. | ||||
| /// </summary> | /// </summary> | ||||
| public class KeywordTextOutputStreamTransform : ITextStreamTransform | public class KeywordTextOutputStreamTransform : ITextStreamTransform | ||||
| { | { | ||||
| HashSet<string> _keywords; | |||||
| int _maxKeywordLength; | |||||
| bool _removeAllMatchedTokens; | |||||
| private readonly HashSet<string> _keywords; | |||||
| private readonly int _maxKeywordLength; | |||||
| private readonly bool _removeAllMatchedTokens; | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -164,59 +156,7 @@ namespace LLama | |||||
| _maxKeywordLength = _keywords.Select(x => x.Length).Max() + redundancyLength; | _maxKeywordLength = _keywords.Select(x => x.Length).Max() + redundancyLength; | ||||
| _removeAllMatchedTokens = removeAllMatchedTokens; | _removeAllMatchedTokens = removeAllMatchedTokens; | ||||
| } | } | ||||
| /// <inheritdoc /> | |||||
| public IEnumerable<string> Transform(IEnumerable<string> tokens) | |||||
| { | |||||
| var window = new Queue<string>(); | |||||
| foreach (var s in tokens) | |||||
| { | |||||
| window.Enqueue(s); | |||||
| var current = string.Join("", window); | |||||
| if (_keywords.Any(x => current.Contains(x))) | |||||
| { | |||||
| var matchedKeyword = _keywords.First(x => current.Contains(x)); | |||||
| int total = window.Count; | |||||
| for (int i = 0; i < total; i++) | |||||
| { | |||||
| window.Dequeue(); | |||||
| } | |||||
| if (!_removeAllMatchedTokens) | |||||
| { | |||||
| yield return current.Replace(matchedKeyword, ""); | |||||
| } | |||||
| } | |||||
| if (current.Length >= _maxKeywordLength) | |||||
| { | |||||
| if (_keywords.Any(x => current.Contains(x))) | |||||
| { | |||||
| var matchedKeyword = _keywords.First(x => current.Contains(x)); | |||||
| int total = window.Count; | |||||
| for (int i = 0; i < total; i++) | |||||
| { | |||||
| window.Dequeue(); | |||||
| } | |||||
| if (!_removeAllMatchedTokens) | |||||
| { | |||||
| yield return current.Replace(matchedKeyword, ""); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| int total = window.Count; | |||||
| for (int i = 0; i < total; i++) | |||||
| { | |||||
| yield return window.Dequeue(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| int totalCount = window.Count; | |||||
| for (int i = 0; i < totalCount; i++) | |||||
| { | |||||
| yield return window.Dequeue(); | |||||
| } | |||||
| } | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) | public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) | ||||
| { | { | ||||