Browse Source

1. SKIP test_process_from_file

2. doc_utils 增加了 __all__ 的检查
tags/v0.6.0
ChenXin 3 years ago
parent
commit
ca25baf6b9
2 changed files with 11 additions and 7 deletions
  1. +7
    -3
      fastNLP/doc_utils.py
  2. +4
    -4
      test/io/pipe/test_classification.py

+ 7
- 3
fastNLP/doc_utils.py View File

@@ -23,7 +23,9 @@ def doc_process(m):
while 1:
defined_m = sys.modules[module_name]
try:
if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
if not hasattr(defined_m, "__all__"):
print("Warning: Module {} lacks `__all__`".format(module_name))
elif "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \
+ " :class:`" + module_name + "." + name + "`\n" + obj.__doc__
break
@@ -34,7 +36,7 @@ def doc_process(m):
except:
print("Warning: Module {} lacks `__doc__`".format(module_name))
break
# 识别并标注基类,只有基类也在 fastNLP 中定义才显示
if inspect.isclass(obj):
@@ -45,7 +47,9 @@ def doc_process(m):
for i in range(len(parts) - 1):
defined_m = sys.modules[module_name]
try:
if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
if not hasattr(defined_m, "__all__"):
print("Warning: Module {} lacks `__all__`".format(module_name))
elif "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
obj.__doc__ = r"基类 :class:`" + defined_m.__name__ + "." + base.__name__ + "` \n\n" + obj.__doc__
break
module_name += "." + parts[i + 1]


+ 4
- 4
test/io/pipe/test_classification.py View File

@@ -10,7 +10,7 @@ from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, Weibo
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestClassificationPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
with self.subTest(pipe=pipe):
print(pipe)
data_bundle = pipe(tokenizer='raw').process_from_file()
@@ -33,6 +33,7 @@ class TestCNClassificationPipe(unittest.TestCase):
print(data_bundle)


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestRunClassificationPipe(unittest.TestCase):
def test_process_from_file(self):
data_set_dict = {
@@ -79,15 +80,14 @@ class TestRunClassificationPipe(unittest.TestCase):
data_bundle = pipe(tokenizer='raw').process_from_file(path)
else:
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path)
self.assertTrue(isinstance(data_bundle, DataBundle))
self.assertEqual(len(data_set), data_bundle.num_dataset)
for name, dataset in data_bundle.iter_datasets():
self.assertTrue(name in data_set.keys())
self.assertEqual(data_set[name], len(dataset))
self.assertEqual(len(vocab), data_bundle.num_vocab)
for name, vocabs in data_bundle.iter_vocabs():
self.assertTrue(name in vocab.keys())
self.assertEqual(vocab[name], len(vocabs))


Loading…
Cancel
Save