From 1dfbc0aeff22dbc79c7c5035b6e68a174e1e79bc Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Tue, 17 Sep 2019 16:25:51 +0800 Subject: [PATCH] update test code for testing matching loader and pipe --- test/io/loader/test_matching_loader.py | 14 ++++++++------ test/io/pipe/test_matching.py | 12 ++++++++++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py index 8d6e182c..eb4ec2ba 100644 --- a/test/io/loader/test_matching_loader.py +++ b/test/io/loader/test_matching_loader.py @@ -1,14 +1,13 @@ import unittest -from fastNLP.io import DataBundle -from fastNLP.io.loader.matching import RTELoader -from fastNLP.io.loader.matching import QNLILoader -from fastNLP.io.loader.matching import SNLILoader -from fastNLP.io.loader.matching import QuoraLoader -from fastNLP.io.loader.matching import MNLILoader import os +from fastNLP.io import DataBundle +from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \ + BQCorpusLoader, XNLILoader, LCQMCLoader + + @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") class TestMatchingDownload(unittest.TestCase): def test_download(self): @@ -30,6 +29,9 @@ class TestMatchingLoad(unittest.TestCase): 'SNLI': ('test/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False), 'QNLI': ('test/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True), 'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), + 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), + 'XNLI': ('test/data_for_tests/io/XNLI', XNLILoader, (6, 7, 6), False), + 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), } for k, v in data_set_dict.items(): path, loader, instance, warns = v diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py index 8b0076c2..785d44bb 100644 --- a/test/io/pipe/test_matching.py +++ b/test/io/pipe/test_matching.py @@ -3,8 +3,10 @@ import unittest import os from fastNLP.io import DataBundle -from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe -from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe +from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe, \ + XNLIPipe, BQCorpusPipe, LCQMCPipe +from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe, \ + XNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") @@ -35,6 +37,9 @@ class TestRunMatchingPipe(unittest.TestCase): 'SNLI': ('test/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False), 'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), + 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), + 'XNLI': ('test/data_for_tests/io/XNLI', XNLIPipe, XNLIBertPipe, (6, 7, 6), (37, 3), False), + 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), } for k, v in data_set_dict.items(): path, pipe1, pipe2, data_set, vocab, warns = v @@ -48,6 +53,9 @@ class TestRunMatchingPipe(unittest.TestCase): self.assertTrue(isinstance(data_bundle1, DataBundle)) self.assertEqual(len(data_set), data_bundle1.num_dataset) + print(k) + print(data_bundle1) + print(data_bundle2) for x, y in zip(data_set, data_bundle1.iter_datasets()): name, dataset = y self.assertEqual(x, len(dataset))