|
|
@@ -10,7 +10,7 @@ from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, Weibo |
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
|
class TestClassificationPipe(unittest.TestCase): |
|
|
|
def test_process_from_file(self): |
|
|
|
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: |
|
|
|
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: |
|
|
|
with self.subTest(pipe=pipe): |
|
|
|
print(pipe) |
|
|
|
data_bundle = pipe(tokenizer='raw').process_from_file() |
|
|
@@ -33,6 +33,7 @@ class TestCNClassificationPipe(unittest.TestCase): |
|
|
|
print(data_bundle) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
|
class TestRunClassificationPipe(unittest.TestCase): |
|
|
|
def test_process_from_file(self): |
|
|
|
data_set_dict = { |
|
|
@@ -79,15 +80,14 @@ class TestRunClassificationPipe(unittest.TestCase): |
|
|
|
data_bundle = pipe(tokenizer='raw').process_from_file(path) |
|
|
|
else: |
|
|
|
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) |
|
|
|
|
|
|
|
|
|
|
|
self.assertTrue(isinstance(data_bundle, DataBundle)) |
|
|
|
self.assertEqual(len(data_set), data_bundle.num_dataset) |
|
|
|
for name, dataset in data_bundle.iter_datasets(): |
|
|
|
self.assertTrue(name in data_set.keys()) |
|
|
|
self.assertEqual(data_set[name], len(dataset)) |
|
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(vocab), data_bundle.num_vocab) |
|
|
|
for name, vocabs in data_bundle.iter_vocabs(): |
|
|
|
self.assertTrue(name in vocab.keys()) |
|
|
|
self.assertEqual(vocab[name], len(vocabs)) |
|
|
|
|