Browse Source

update test code for testing matching loader and pipe

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
1dfbc0aeff
2 changed files with 18 additions and 8 deletions
  1. +8
    -6
      test/io/loader/test_matching_loader.py
  2. +10
    -2
      test/io/pipe/test_matching.py

+ 8
- 6
test/io/loader/test_matching_loader.py View File

@@ -1,14 +1,13 @@

import unittest

from fastNLP.io import DataBundle
from fastNLP.io.loader.matching import RTELoader
from fastNLP.io.loader.matching import QNLILoader
from fastNLP.io.loader.matching import SNLILoader
from fastNLP.io.loader.matching import QuoraLoader
from fastNLP.io.loader.matching import MNLILoader
import os

from fastNLP.io import DataBundle
from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \
BQCorpusLoader, XNLILoader, LCQMCLoader


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestMatchingDownload(unittest.TestCase):
def test_download(self):
@@ -30,6 +29,9 @@ class TestMatchingLoad(unittest.TestCase):
'SNLI': ('test/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False),
'QNLI': ('test/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True),
'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
'XNLI': ('test/data_for_tests/io/XNLI', XNLILoader, (6, 7, 6), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False),
}
for k, v in data_set_dict.items():
path, loader, instance, warns = v


+ 10
- 2
test/io/pipe/test_matching.py View File

@@ -3,8 +3,10 @@ import unittest
import os

from fastNLP.io import DataBundle
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe, \
XNLIPipe, BQCorpusPipe, LCQMCPipe
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe, \
XNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
@@ -35,6 +37,9 @@ class TestRunMatchingPipe(unittest.TestCase):
'SNLI': ('test/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
'XNLI': ('test/data_for_tests/io/XNLI', XNLIPipe, XNLIBertPipe, (6, 7, 6), (37, 3), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False),
}
for k, v in data_set_dict.items():
path, pipe1, pipe2, data_set, vocab, warns = v
@@ -48,6 +53,9 @@ class TestRunMatchingPipe(unittest.TestCase):

self.assertTrue(isinstance(data_bundle1, DataBundle))
self.assertEqual(len(data_set), data_bundle1.num_dataset)
print(k)
print(data_bundle1)
print(data_bundle2)
for x, y in zip(data_set, data_bundle1.iter_datasets()):
name, dataset = y
self.assertEqual(x, len(dataset))


Loading…
Cancel
Save