From 53975c045a6841e38d4a7cfcc23abea6de0fe3f3 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 26 Aug 2019 14:58:36 +0800 Subject: [PATCH] update the doc-tool & fix an importing bug --- docs/count.py | 42 ++++++++++++++++++++++++++++++++++ fastNLP/modules/decoder/crf.py | 2 +- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/docs/count.py b/docs/count.py index e1aad115..72868403 100644 --- a/docs/count.py +++ b/docs/count.py @@ -1,7 +1,28 @@ +import inspect import os import sys +def _colored_string(string: str, color: str or int) -> str: + """在终端中显示一串有颜色的文字 + :param string: 在终端中显示的文字 + :param color: 文字的颜色 + :return: + """ + if isinstance(color, str): + color = { + "black": 30, "Black": 30, "BLACK": 30, + "red": 31, "Red": 31, "RED": 31, + "green": 32, "Green": 32, "GREEN": 32, + "yellow": 33, "Yellow": 33, "YELLOW": 33, + "blue": 34, "Blue": 34, "BLUE": 34, + "purple": 35, "Purple": 35, "PURPLE": 35, + "cyan": 36, "Cyan": 36, "CYAN": 36, + "white": 37, "White": 37, "WHITE": 37 + }[color] + return "\033[%dm%s\033[0m" % (color, string) + + def find_all_modules(): modules = {} children = {} @@ -55,10 +76,31 @@ def create_rst_file(modules, name, children): fout.write(" " + module + "\n") +def check_file(m, name): + for item, obj in inspect.getmembers(m): + if inspect.isclass(obj) and obj.__module__ == name: + print(obj) + if inspect.isfunction(obj) and obj.__module__ == name: + print("FUNC", obj) + + +def check_files(modules): + for name in sorted(modules.keys()): + if name == 'fastNLP.core.utils': + check_file(modules[name], name) + + def main(): + print(_colored_string('Getting modules...', "Blue")) modules, to_doc, children = find_all_modules() + print(_colored_string('Done!', "Green")) + print(_colored_string('Creating rst files...', "Blue")) for name in to_doc: create_rst_file(modules, name, children) + print(_colored_string('Done!', "Green")) + print(_colored_string('Checking all files...', "Blue")) + check_files(modules) + print(_colored_string('Done!', "Green")) if __name__ == "__main__": diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py index b47d0162..f63d46e3 100644 --- a/fastNLP/modules/decoder/crf.py +++ b/fastNLP/modules/decoder/crf.py @@ -9,7 +9,7 @@ import torch from torch import nn from ..utils import initial_parameter -from ...core import Vocabulary +from ...core.vocabulary import Vocabulary def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):