Browse Source

update test code for testing matching loader and pipe

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
1dfbc0aeff
2 changed files with 18 additions and 8 deletions
  1. +8
    -6
      test/io/loader/test_matching_loader.py
  2. +10
    -2
      test/io/pipe/test_matching.py

+ 8
- 6
test/io/loader/test_matching_loader.py View File

@@ -1,14 +1,13 @@


import unittest import unittest


from fastNLP.io import DataBundle
from fastNLP.io.loader.matching import RTELoader
from fastNLP.io.loader.matching import QNLILoader
from fastNLP.io.loader.matching import SNLILoader
from fastNLP.io.loader.matching import QuoraLoader
from fastNLP.io.loader.matching import MNLILoader
import os import os


from fastNLP.io import DataBundle
from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \
BQCorpusLoader, XNLILoader, LCQMCLoader


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestMatchingDownload(unittest.TestCase): class TestMatchingDownload(unittest.TestCase):
def test_download(self): def test_download(self):
@@ -30,6 +29,9 @@ class TestMatchingLoad(unittest.TestCase):
'SNLI': ('test/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False), 'SNLI': ('test/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False),
'QNLI': ('test/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True), 'QNLI': ('test/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True),
'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), 'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
'XNLI': ('test/data_for_tests/io/XNLI', XNLILoader, (6, 7, 6), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False),
} }
for k, v in data_set_dict.items(): for k, v in data_set_dict.items():
path, loader, instance, warns = v path, loader, instance, warns = v


+ 10
- 2
test/io/pipe/test_matching.py View File

@@ -3,8 +3,10 @@ import unittest
import os import os


from fastNLP.io import DataBundle from fastNLP.io import DataBundle
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe, \
XNLIPipe, BQCorpusPipe, LCQMCPipe
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe, \
XNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe




@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
@@ -35,6 +37,9 @@ class TestRunMatchingPipe(unittest.TestCase):
'SNLI': ('test/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False), 'SNLI': ('test/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), 'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
'XNLI': ('test/data_for_tests/io/XNLI', XNLIPipe, XNLIBertPipe, (6, 7, 6), (37, 3), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False),
} }
for k, v in data_set_dict.items(): for k, v in data_set_dict.items():
path, pipe1, pipe2, data_set, vocab, warns = v path, pipe1, pipe2, data_set, vocab, warns = v
@@ -48,6 +53,9 @@ class TestRunMatchingPipe(unittest.TestCase):


self.assertTrue(isinstance(data_bundle1, DataBundle)) self.assertTrue(isinstance(data_bundle1, DataBundle))
self.assertEqual(len(data_set), data_bundle1.num_dataset) self.assertEqual(len(data_set), data_bundle1.num_dataset)
print(k)
print(data_bundle1)
print(data_bundle2)
for x, y in zip(data_set, data_bundle1.iter_datasets()): for x, y in zip(data_set, data_bundle1.iter_datasets()):
name, dataset = y name, dataset = y
self.assertEqual(x, len(dataset)) self.assertEqual(x, len(dataset))


Loading…
Cancel
Save