|
|
@@ -3,8 +3,10 @@ import unittest |
|
|
|
import os |
|
|
|
|
|
|
|
from fastNLP.io import DataBundle |
|
|
|
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe |
|
|
|
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe |
|
|
|
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe, \ |
|
|
|
XNLIPipe, BQCorpusPipe, LCQMCPipe |
|
|
|
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe, \ |
|
|
|
XNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
@@ -35,6 +37,9 @@ class TestRunMatchingPipe(unittest.TestCase): |
|
|
|
'SNLI': ('test/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False), |
|
|
|
'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), |
|
|
|
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), |
|
|
|
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), |
|
|
|
'XNLI': ('test/data_for_tests/io/XNLI', XNLIPipe, XNLIBertPipe, (6, 7, 6), (37, 3), False), |
|
|
|
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), |
|
|
|
} |
|
|
|
for k, v in data_set_dict.items(): |
|
|
|
path, pipe1, pipe2, data_set, vocab, warns = v |
|
|
@@ -48,6 +53,9 @@ class TestRunMatchingPipe(unittest.TestCase): |
|
|
|
|
|
|
|
self.assertTrue(isinstance(data_bundle1, DataBundle)) |
|
|
|
self.assertEqual(len(data_set), data_bundle1.num_dataset) |
|
|
|
print(k) |
|
|
|
print(data_bundle1) |
|
|
|
print(data_bundle2) |
|
|
|
for x, y in zip(data_set, data_bundle1.iter_datasets()): |
|
|
|
name, dataset = y |
|
|
|
self.assertEqual(x, len(dataset)) |
|
|
|