|
|
@@ -40,15 +40,19 @@ class TestRunClassificationPipe(unittest.TestCase): |
|
|
|
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True), |
|
|
|
'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False), |
|
|
|
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False), |
|
|
|
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False), |
|
|
|
} |
|
|
|
for k, v in data_set_dict.items(): |
|
|
|
path, pipe, data_set, vocab, warns = v |
|
|
|
with self.subTest(pipe=pipe): |
|
|
|
if warns: |
|
|
|
with self.assertWarns(Warning): |
|
|
|
if 'Chn' not in k: |
|
|
|
if warns: |
|
|
|
with self.assertWarns(Warning): |
|
|
|
data_bundle = pipe(tokenizer='raw').process_from_file(path) |
|
|
|
else: |
|
|
|
data_bundle = pipe(tokenizer='raw').process_from_file(path) |
|
|
|
else: |
|
|
|
data_bundle = pipe(tokenizer='raw').process_from_file(path) |
|
|
|
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) |
|
|
|
|
|
|
|
self.assertTrue(isinstance(data_bundle, DataBundle)) |
|
|
|
self.assertEqual(len(data_set), data_bundle.num_dataset) |
|
|
|