From 967e5e568389db8f98fa27c43c2c065470b307f3 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 22 Apr 2019 01:31:41 +0800 Subject: [PATCH] doc tools --- docs/Makefile | 4 + docs/source/conf.py | 6 +- docs/source/fastNLP.api.rst | 52 +++++-- docs/source/fastNLP.core.rst | 98 ++++++++---- docs/source/fastNLP.io.rst | 48 +++--- docs/source/fastNLP.models.rst | 96 ++++++++++-- docs/source/fastNLP.modules.aggregator.rst | 42 ++++-- docs/source/fastNLP.modules.decoder.rst | 24 ++- docs/source/fastNLP.modules.encoder.rst | 74 ++++++--- docs/source/fastNLP.modules.rst | 33 ++++- docs/source/fastNLP.rst | 13 +- fastNLP/api/__init__.py | 3 + fastNLP/api/api.py | 26 ++-- fastNLP/automl/enas_trainer.py | 15 +- fastNLP/core/dataset.py | 2 +- fastNLP/core/fieldarray.py | 16 +- fastNLP/core/instance.py | 15 +- fastNLP/core/losses.py | 4 +- fastNLP/core/metrics.py | 165 ++++++++++++--------- fastNLP/core/trainer.py | 139 ++++++++--------- fastNLP/core/utils.py | 9 +- fastNLP/models/char_language_model.py | 13 +- fastNLP/models/enas_trainer.py | 15 +- 23 files changed, 599 insertions(+), 313 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index e978dfe6..6a5c7375 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,6 +3,7 @@ # You can set these variables from the command line. SPHINXOPTS = +SPHINXAPIDOC = sphinx-apidoc SPHINXBUILD = sphinx-build SPHINXPROJ = fastNLP SOURCEDIR = source @@ -12,6 +13,9 @@ BUILDDIR = build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +apidoc: + @$(SPHINXAPIDOC) -f -o source ../fastNLP + .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new diff --git a/docs/source/conf.py b/docs/source/conf.py index e449a9f8..96f7f437 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,9 +23,9 @@ copyright = '2018, xpqiu' author = 'xpqiu' # The short X.Y version -version = '0.2' +version = '0.4' # The full version, including alpha/beta/rc tags -release = '0.2' +release = '0.4' # -- General configuration --------------------------------------------------- @@ -67,7 +67,7 @@ language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = [] +exclude_patterns = ['modules.rst'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' diff --git a/docs/source/fastNLP.api.rst b/docs/source/fastNLP.api.rst index eb9192da..ee2413fb 100644 --- a/docs/source/fastNLP.api.rst +++ b/docs/source/fastNLP.api.rst @@ -1,36 +1,62 @@ -fastNLP.api -============ +fastNLP.api package +=================== -fastNLP.api.api ----------------- +Submodules +---------- + +fastNLP.api.api module +---------------------- .. automodule:: fastNLP.api.api :members: + :undoc-members: + :show-inheritance: -fastNLP.api.converter ----------------------- +fastNLP.api.converter module +---------------------------- .. automodule:: fastNLP.api.converter :members: + :undoc-members: + :show-inheritance: -fastNLP.api.model\_zoo ------------------------ +fastNLP.api.examples module +--------------------------- -.. automodule:: fastNLP.api.model_zoo +.. automodule:: fastNLP.api.examples :members: + :undoc-members: + :show-inheritance: -fastNLP.api.pipeline ---------------------- +fastNLP.api.pipeline module +--------------------------- .. automodule:: fastNLP.api.pipeline :members: + :undoc-members: + :show-inheritance: -fastNLP.api.processor ----------------------- +fastNLP.api.processor module +---------------------------- .. automodule:: fastNLP.api.processor :members: + :undoc-members: + :show-inheritance: + +fastNLP.api.utils module +------------------------ + +.. automodule:: fastNLP.api.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.api :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst index b9f6c89f..79d26c76 100644 --- a/docs/source/fastNLP.core.rst +++ b/docs/source/fastNLP.core.rst @@ -1,84 +1,126 @@ -fastNLP.core -============= +fastNLP.core package +==================== -fastNLP.core.batch -------------------- +Submodules +---------- + +fastNLP.core.batch module +------------------------- .. automodule:: fastNLP.core.batch :members: + :undoc-members: + :show-inheritance: + +fastNLP.core.callback module +---------------------------- -fastNLP.core.dataset ---------------------- +.. automodule:: fastNLP.core.callback + :members: + :undoc-members: + :show-inheritance: + +fastNLP.core.dataset module +--------------------------- .. automodule:: fastNLP.core.dataset :members: + :undoc-members: + :show-inheritance: -fastNLP.core.fieldarray ------------------------- +fastNLP.core.fieldarray module +------------------------------ .. automodule:: fastNLP.core.fieldarray :members: + :undoc-members: + :show-inheritance: -fastNLP.core.instance ----------------------- +fastNLP.core.instance module +---------------------------- .. automodule:: fastNLP.core.instance :members: + :undoc-members: + :show-inheritance: -fastNLP.core.losses --------------------- +fastNLP.core.losses module +-------------------------- .. automodule:: fastNLP.core.losses :members: + :undoc-members: + :show-inheritance: -fastNLP.core.metrics ---------------------- +fastNLP.core.metrics module +--------------------------- .. automodule:: fastNLP.core.metrics :members: + :undoc-members: + :show-inheritance: -fastNLP.core.optimizer ------------------------ +fastNLP.core.optimizer module +----------------------------- .. automodule:: fastNLP.core.optimizer :members: + :undoc-members: + :show-inheritance: -fastNLP.core.predictor ------------------------ +fastNLP.core.predictor module +----------------------------- .. automodule:: fastNLP.core.predictor :members: + :undoc-members: + :show-inheritance: -fastNLP.core.sampler ---------------------- +fastNLP.core.sampler module +--------------------------- .. automodule:: fastNLP.core.sampler :members: + :undoc-members: + :show-inheritance: -fastNLP.core.tester --------------------- +fastNLP.core.tester module +-------------------------- .. automodule:: fastNLP.core.tester :members: + :undoc-members: + :show-inheritance: -fastNLP.core.trainer ---------------------- +fastNLP.core.trainer module +--------------------------- .. automodule:: fastNLP.core.trainer :members: + :undoc-members: + :show-inheritance: -fastNLP.core.utils -------------------- +fastNLP.core.utils module +------------------------- .. automodule:: fastNLP.core.utils :members: + :undoc-members: + :show-inheritance: -fastNLP.core.vocabulary ------------------------- +fastNLP.core.vocabulary module +------------------------------ .. automodule:: fastNLP.core.vocabulary :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.core :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index d91e0d1c..bb30c5e7 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -1,42 +1,54 @@ -fastNLP.io -=========== +fastNLP.io package +================== -fastNLP.io.base\_loader ------------------------- +Submodules +---------- + +fastNLP.io.base\_loader module +------------------------------ .. automodule:: fastNLP.io.base_loader :members: + :undoc-members: + :show-inheritance: -fastNLP.io.config\_io ----------------------- +fastNLP.io.config\_io module +---------------------------- .. automodule:: fastNLP.io.config_io :members: + :undoc-members: + :show-inheritance: -fastNLP.io.dataset\_loader ---------------------------- +fastNLP.io.dataset\_loader module +--------------------------------- .. automodule:: fastNLP.io.dataset_loader :members: + :undoc-members: + :show-inheritance: -fastNLP.io.embed\_loader -------------------------- +fastNLP.io.embed\_loader module +------------------------------- .. automodule:: fastNLP.io.embed_loader :members: + :undoc-members: + :show-inheritance: -fastNLP.io.logger ------------------- - -.. automodule:: fastNLP.io.logger - :members: - -fastNLP.io.model\_io ---------------------- +fastNLP.io.model\_io module +--------------------------- .. automodule:: fastNLP.io.model_io :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.io :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst index 7452fdf6..3ebf9608 100644 --- a/docs/source/fastNLP.models.rst +++ b/docs/source/fastNLP.models.rst @@ -1,42 +1,110 @@ -fastNLP.models -=============== +fastNLP.models package +====================== -fastNLP.models.base\_model ---------------------------- +Submodules +---------- + +fastNLP.models.base\_model module +--------------------------------- .. automodule:: fastNLP.models.base_model :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.bert module +-------------------------- -fastNLP.models.biaffine\_parser --------------------------------- +.. automodule:: fastNLP.models.bert + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.biaffine\_parser module +-------------------------------------- .. automodule:: fastNLP.models.biaffine_parser :members: + :undoc-members: + :show-inheritance: -fastNLP.models.char\_language\_model -------------------------------------- +fastNLP.models.char\_language\_model module +------------------------------------------- .. automodule:: fastNLP.models.char_language_model :members: + :undoc-members: + :show-inheritance: -fastNLP.models.cnn\_text\_classification ------------------------------------------ +fastNLP.models.cnn\_text\_classification module +----------------------------------------------- .. automodule:: fastNLP.models.cnn_text_classification :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.enas\_controller module +-------------------------------------- + +.. automodule:: fastNLP.models.enas_controller + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.enas\_model module +--------------------------------- + +.. automodule:: fastNLP.models.enas_model + :members: + :undoc-members: + :show-inheritance: -fastNLP.models.sequence\_modeling ----------------------------------- +fastNLP.models.enas\_trainer module +----------------------------------- + +.. automodule:: fastNLP.models.enas_trainer + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.enas\_utils module +--------------------------------- + +.. automodule:: fastNLP.models.enas_utils + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.sequence\_modeling module +---------------------------------------- .. automodule:: fastNLP.models.sequence_modeling :members: + :undoc-members: + :show-inheritance: -fastNLP.models.snli --------------------- +fastNLP.models.snli module +-------------------------- .. automodule:: fastNLP.models.snli :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.star\_transformer module +--------------------------------------- + +.. automodule:: fastNLP.models.star_transformer + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.models :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.aggregator.rst b/docs/source/fastNLP.modules.aggregator.rst index 073da4a5..63d351e4 100644 --- a/docs/source/fastNLP.modules.aggregator.rst +++ b/docs/source/fastNLP.modules.aggregator.rst @@ -1,36 +1,54 @@ -fastNLP.modules.aggregator -=========================== +fastNLP.modules.aggregator package +================================== -fastNLP.modules.aggregator.attention -------------------------------------- +Submodules +---------- + +fastNLP.modules.aggregator.attention module +------------------------------------------- .. automodule:: fastNLP.modules.aggregator.attention :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.avg\_pool -------------------------------------- +fastNLP.modules.aggregator.avg\_pool module +------------------------------------------- .. automodule:: fastNLP.modules.aggregator.avg_pool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.kmax\_pool --------------------------------------- +fastNLP.modules.aggregator.kmax\_pool module +-------------------------------------------- .. automodule:: fastNLP.modules.aggregator.kmax_pool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.max\_pool -------------------------------------- +fastNLP.modules.aggregator.max\_pool module +------------------------------------------- .. automodule:: fastNLP.modules.aggregator.max_pool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.self\_attention -------------------------------------------- +fastNLP.modules.aggregator.self\_attention module +------------------------------------------------- .. automodule:: fastNLP.modules.aggregator.self_attention :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.modules.aggregator :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.decoder.rst b/docs/source/fastNLP.modules.decoder.rst index 6844543a..25602b2c 100644 --- a/docs/source/fastNLP.modules.decoder.rst +++ b/docs/source/fastNLP.modules.decoder.rst @@ -1,18 +1,30 @@ -fastNLP.modules.decoder -======================== +fastNLP.modules.decoder package +=============================== -fastNLP.modules.decoder.CRF ----------------------------- +Submodules +---------- + +fastNLP.modules.decoder.CRF module +---------------------------------- .. automodule:: fastNLP.modules.decoder.CRF :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.decoder.MLP ----------------------------- +fastNLP.modules.decoder.MLP module +---------------------------------- .. automodule:: fastNLP.modules.decoder.MLP :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.modules.decoder :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst index ea8fc699..ab93a169 100644 --- a/docs/source/fastNLP.modules.encoder.rst +++ b/docs/source/fastNLP.modules.encoder.rst @@ -1,60 +1,94 @@ -fastNLP.modules.encoder -======================== +fastNLP.modules.encoder package +=============================== -fastNLP.modules.encoder.char\_embedding ----------------------------------------- +Submodules +---------- + +fastNLP.modules.encoder.char\_embedding module +---------------------------------------------- .. automodule:: fastNLP.modules.encoder.char_embedding :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.conv ------------------------------ +fastNLP.modules.encoder.conv module +----------------------------------- .. automodule:: fastNLP.modules.encoder.conv :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.conv\_maxpool --------------------------------------- +fastNLP.modules.encoder.conv\_maxpool module +-------------------------------------------- .. automodule:: fastNLP.modules.encoder.conv_maxpool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.embedding ----------------------------------- +fastNLP.modules.encoder.embedding module +---------------------------------------- .. automodule:: fastNLP.modules.encoder.embedding :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.linear -------------------------------- +fastNLP.modules.encoder.linear module +------------------------------------- .. automodule:: fastNLP.modules.encoder.linear :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.lstm ------------------------------ +fastNLP.modules.encoder.lstm module +----------------------------------- .. automodule:: fastNLP.modules.encoder.lstm :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.masked\_rnn ------------------------------------- +fastNLP.modules.encoder.masked\_rnn module +------------------------------------------ .. automodule:: fastNLP.modules.encoder.masked_rnn :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.transformer ------------------------------------- +fastNLP.modules.encoder.star\_transformer module +------------------------------------------------ + +.. automodule:: fastNLP.modules.encoder.star_transformer + :members: + :undoc-members: + :show-inheritance: + +fastNLP.modules.encoder.transformer module +------------------------------------------ .. automodule:: fastNLP.modules.encoder.transformer :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.variational\_rnn ------------------------------------------ +fastNLP.modules.encoder.variational\_rnn module +----------------------------------------------- .. automodule:: fastNLP.modules.encoder.variational_rnn :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.modules.encoder :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst index 965fb27d..57858176 100644 --- a/docs/source/fastNLP.modules.rst +++ b/docs/source/fastNLP.modules.rst @@ -1,5 +1,8 @@ -fastNLP.modules -================ +fastNLP.modules package +======================= + +Subpackages +----------- .. toctree:: @@ -7,24 +10,38 @@ fastNLP.modules fastNLP.modules.decoder fastNLP.modules.encoder -fastNLP.modules.dropout ------------------------- +Submodules +---------- + +fastNLP.modules.dropout module +------------------------------ .. automodule:: fastNLP.modules.dropout :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.other\_modules -------------------------------- +fastNLP.modules.other\_modules module +------------------------------------- .. automodule:: fastNLP.modules.other_modules :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.utils ----------------------- +fastNLP.modules.utils module +---------------------------- .. automodule:: fastNLP.modules.utils :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.modules :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index 61882359..6348c9a6 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -1,13 +1,22 @@ -fastNLP -======== +fastNLP package +=============== + +Subpackages +----------- .. toctree:: fastNLP.api + fastNLP.automl fastNLP.core fastNLP.io fastNLP.models fastNLP.modules +Module contents +--------------- + .. automodule:: fastNLP :members: + :undoc-members: + :show-inheritance: diff --git a/fastNLP/api/__init__.py b/fastNLP/api/__init__.py index a21a4c42..ae31b80b 100644 --- a/fastNLP/api/__init__.py +++ b/fastNLP/api/__init__.py @@ -1 +1,4 @@ +""" + 这是 API 部分的注释 +""" from .api import CWS, POS, Parser diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 53a80131..b001629c 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -1,3 +1,7 @@ +""" +API.API 的文档 + +""" import warnings import torch @@ -184,17 +188,17 @@ class CWS(API): """ 传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 分词文件应该为: - 1 编者按 编者按 NN O 11 nmod:topic - 2 : : PU O 11 punct - 3 7月 7月 NT DATE 4 compound:nn - 4 12日 12日 NT DATE 11 nmod:tmod - 5 , , PU O 11 punct - - 1 这 这 DT O 3 det - 2 款 款 M O 1 mark:clf - 3 飞行 飞行 NN O 8 nsubj - 4 从 从 P O 5 case - 5 外型 外型 NN O 8 nmod:prep + 1 编者按 编者按 NN O 11 nmod:topic + 2 : : PU O 11 punct + 3 7月 7月 NT DATE 4 compound:nn + 4 12日 12日 NT DATE 11 nmod:tmod + 5 , , PU O 11 punct + + 1 这 这 DT O 3 det + 2 款 款 M O 1 mark:clf + 3 飞行 飞行 NN O 8 nsubj + 4 从 从 P O 5 case + 5 外型 外型 NN O 8 nmod:prep 以空行分割两个句子,有内容的每行有7列。 :param filepath: str, 文件路径路径。 diff --git a/fastNLP/automl/enas_trainer.py b/fastNLP/automl/enas_trainer.py index 7c0da752..061d604c 100644 --- a/fastNLP/automl/enas_trainer.py +++ b/fastNLP/automl/enas_trainer.py @@ -62,13 +62,14 @@ class ENASTrainer(fastNLP.Trainer): """ :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 最好的模型参数。 - :return results: 返回一个字典类型的数据, 内含以下内容:: - - seconds: float, 表示训练时长 - 以下三个内容只有在提供了dev_data的情况下会有。 - best_eval: Dict of Dict, 表示evaluation的结果 - best_epoch: int,在第几个epoch取得的最佳值 - best_step: int, 在第几个step(batch)更新取得的最佳值 + :return results: 返回一个字典类型的数据, + 内含以下内容:: + + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果 + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 """ results = {} diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 76a34655..6cbfc20f 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -272,7 +272,7 @@ class DataSet(object): :param func: a function that takes an instance as input. :param str new_field_name: If not None, results of the function will be stored as a new field. - :param **kwargs: Accept parameters will be + :param kwargs: Accept parameters will be (1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. :return results: if new_field_name is not passed, returned values of the function over all instances. diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 10fbbebe..caf2a1cf 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -48,12 +48,16 @@ class PadderBase: class AutoPadder(PadderBase): """ 根据contents的数据自动判定是否需要做padding。 - (1) 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 - 型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding - (2) 如果元素类型为(np.int64, np.float64), - (2.1) 如果该field的内容只有一个,比如为sequence_length, 则不进行padding - (2.2) 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 - 如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad + + 1 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 + 型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding + + 2 如果元素类型为(np.int64, np.float64), + + 2.1 如果该field的内容只有一个,比如为sequence_length, 则不进行padding + + 2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 + 如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad """ def __init__(self, pad_val=0): """ diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 5ac52e3f..fff992cc 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -1,13 +1,12 @@ class Instance(object): """An Instance is an example of data. - Example:: - ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) - ins["field_1"] - >>[1, 1, 1] - ins.add_field("field_3", [3, 3, 3]) - - :param fields: a dict of (str: list). - + Example:: + + ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) + ins["field_1"] + >>[1, 1, 1] + ins.add_field("field_3", [3, 3, 3]) + """ def __init__(self, **fields): diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index b52244e5..6b0b4460 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -272,7 +272,7 @@ def squash(predict, truth, **kwargs): :param predict: Tensor, model output :param truth: Tensor, truth from dataset - :param **kwargs: extra arguments + :param kwargs: extra arguments :return predict , truth: predict & truth after processing """ return predict.view(-1, predict.size()[-1]), truth.view(-1, ) @@ -316,7 +316,7 @@ def mask(predict, truth, **kwargs): :param predict: Tensor, [batch_size , max_len , tag_size] :param truth: Tensor, [batch_size , max_len] - :param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. + :param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. :return predict , truth: predict & truth after processing """ diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 5687cc85..314be0d9 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -17,66 +17,72 @@ class MetricBase(object): """Base class for all metrics. 所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 + evaluate(xxx)中传入的是一个batch的数据。 + get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 + 以分类问题中,Accuracy计算为例 - 假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy - class Model(nn.Module): - def __init__(xxx): - # do something - def forward(self, xxx): - # do something - return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes + 假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy:: + + class Model(nn.Module): + def __init__(xxx): + # do something + def forward(self, xxx): + # do something + return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes + 假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target - 对应的AccMetric可以按如下的定义 - # version1, 只使用这一次 - class AccMetric(MetricBase): - def __init__(self): - super().__init__() - - # 根据你的情况自定义指标 - self.corr_num = 0 - self.total = 0 - - def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value - # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric - self.total += label.size(0) - self.corr_num += label.eq(pred).sum().item() - - def get_metric(self, reset=True): # 在这里定义如何计算metric - acc = self.corr_num/self.total - if reset: # 是否清零以便重新计算 + 对应的AccMetric可以按如下的定义, version1, 只使用这一次:: + + class AccMetric(MetricBase): + def __init__(self): + super().__init__() + + # 根据你的情况自定义指标 self.corr_num = 0 self.total = 0 - return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 - - - # version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred - class AccMetric(MetricBase): - def __init__(self, label=None, pred=None): - # 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, - # acc_metric = AccMetric(label='y', pred='pred_y')即可。 - # 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 - # 应的的值 - super().__init__() - self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 - # 如果没有注册该则效果与version1就是一样的 - - # 根据你的情况自定义指标 - self.corr_num = 0 - self.total = 0 - - def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 - # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric - self.total += label.size(0) - self.corr_num += label.eq(pred).sum().item() - - def get_metric(self, reset=True): # 在这里定义如何计算metric - acc = self.corr_num/self.total - if reset: # 是否清零以便重新计算 + + def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value + # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric + self.total += label.size(0) + self.corr_num += label.eq(pred).sum().item() + + def get_metric(self, reset=True): # 在这里定义如何计算metric + acc = self.corr_num/self.total + if reset: # 是否清零以便重新计算 + self.corr_num = 0 + self.total = 0 + return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 + + + version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred:: + + class AccMetric(MetricBase): + def __init__(self, label=None, pred=None): + # 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, + # acc_metric = AccMetric(label='y', pred='pred_y')即可。 + # 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 + # 应的的值 + super().__init__() + self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 + # 如果没有注册该则效果与version1就是一样的 + + # 根据你的情况自定义指标 self.corr_num = 0 self.total = 0 - return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 + + def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 + # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric + self.total += label.size(0) + self.corr_num += label.eq(pred).sum().item() + + def get_metric(self, reset=True): # 在这里定义如何计算metric + acc = self.corr_num/self.total + if reset: # 是否清零以便重新计算 + self.corr_num = 0 + self.total = 0 + return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. @@ -84,12 +90,12 @@ class MetricBase(object): ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. ``MetricBase`` will do the following type checks: - 1. whether self.evaluate has varargs, which is not supported. - 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. - 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. + 1. whether self.evaluate has varargs, which is not supported. + 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. + 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. Besides, before passing params into self.evaluate, this function will filter out params from output_dict and - target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering + target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering will be conducted.) """ @@ -388,23 +394,26 @@ class SpanFPreRecMetric(MetricBase): """ 在序列标注问题中,以span的方式计算F, pre, rec. 比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) - ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 - 最后得到的metric结果为 - { - 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 - 'pre': xxx, - 'rec':xxx - } - 若only_gross=False, 即还会返回各个label的metric统计值 + ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 + 最后得到的metric结果为:: + { - 'f': xxx, - 'pre': xxx, - 'rec':xxx, - 'f-label': xxx, - 'pre-label': xxx, - 'rec-label':xxx, - ... - } + 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 + 'pre': xxx, + 'rec':xxx + } + + 若only_gross=False, 即还会返回各个label的metric统计值:: + + { + 'f': xxx, + 'pre': xxx, + 'rec':xxx, + 'f-label': xxx, + 'pre-label': xxx, + 'rec-label':xxx, + ... + } """ def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, @@ -573,13 +582,21 @@ class BMESF1PreRecMetric(MetricBase): """ 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B + + +-------+---------+----------+----------+---------+---------+ | | next_B | next_M | next_E | next_S | end | - |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| - | start | 合法 | next_M=B | next_E=S | 合法 | - | + +=======+=========+==========+==========+=========+=========+ + | start | 合法 | next_M=B | next_E=S | 合法 | -- | + +-------+---------+----------+----------+---------+---------+ | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | + +-------+---------+----------+----------+---------+---------+ | cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | + +-------+---------+----------+----------+---------+---------+ | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | + +-------+---------+----------+----------+---------+---------+ | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | + +-------+---------+----------+----------+---------+---------+ + 举例: prediction为BSEMS,会被认为是SSSSS. diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index b45dd148..250cfdb0 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -66,28 +66,28 @@ class Trainer(object): 不足,通过设置batch_size=32, update_every=4达到目的 """ super(Trainer, self).__init__() - + if not isinstance(train_data, DataSet): raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") - + # check metrics and dev_data if (not metrics) and dev_data is not None: raise ValueError("No metric for dev_data evaluation.") if metrics and (dev_data is None): raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") - + # check update every - assert update_every>=1, "update_every must be no less than 1." + assert update_every >= 1, "update_every must be no less than 1." self.update_every = int(update_every) - + # check save_path if not (save_path is None or isinstance(save_path, str)): raise ValueError("save_path can only be None or `str`.") # prepare evaluate metrics = _prepare_metrics(metrics) - + # parse metric_key # increase_better is True. It means the exp result gets better if the indicator increases. # It is true by default. @@ -97,19 +97,19 @@ class Trainer(object): self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key elif len(metrics) > 0: self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') - + # prepare loss losser = _prepare_losser(loss) - + # sampler check if sampler is not None and not isinstance(sampler, BaseSampler): raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) - + if check_code_level > -1: _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, metric_key=metric_key, check_level=check_code_level, batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) - + self.train_data = train_data self.dev_data = dev_data # If None, No validation. self.model = model @@ -120,7 +120,7 @@ class Trainer(object): self.use_cuda = bool(use_cuda) self.save_path = save_path self.print_every = int(print_every) - self.validate_every = int(validate_every) if validate_every!=0 else -1 + self.validate_every = int(validate_every) if validate_every != 0 else -1 self.best_metric_indicator = None self.best_dev_epoch = None self.best_dev_step = None @@ -129,19 +129,19 @@ class Trainer(object): self.prefetch = prefetch self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) self.n_steps = (len(self.train_data) // self.batch_size + int( - len(self.train_data) % self.batch_size != 0)) * self.n_epochs - + len(self.train_data) % self.batch_size != 0)) * self.n_epochs + if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer else: if optimizer is None: optimizer = Adam(lr=0.01, weight_decay=0) self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) - + self.use_tqdm = use_tqdm self.pbar = None self.print_every = abs(self.print_every) - + if self.dev_data is not None: self.tester = Tester(model=self.model, data=self.dev_data, @@ -149,14 +149,13 @@ class Trainer(object): batch_size=self.batch_size, use_cuda=self.use_cuda, verbose=0) - + self.step = 0 self.start_time = None # start timestamp - + self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) - - + def train(self, load_best_model=True): """ @@ -185,14 +184,15 @@ class Trainer(object): 根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 - 最好的模型参数。 - :return results: 返回一个字典类型的数据, 内含以下内容:: + 最好的模型参数。 + :return results: 返回一个字典类型的数据, + 内含以下内容:: - seconds: float, 表示训练时长 - 以下三个内容只有在提供了dev_data的情况下会有。 - best_eval: Dict of Dict, 表示evaluation的结果 - best_epoch: int,在第几个epoch取得的最佳值 - best_step: int, 在第几个step(batch)更新取得的最佳值 + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果 + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 """ results = {} @@ -205,21 +205,22 @@ class Trainer(object): self.model = self.model.cuda() self._model_device = self.model.parameters().__next__().device self._mode(self.model, is_test=False) - + self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) start_time = time.time() print("training epochs started " + self.start_time, flush=True) - + try: self.callback_manager.on_train_begin() self._train() self.callback_manager.on_train_end() except (CallbackException, KeyboardInterrupt) as e: self.callback_manager.on_exception(e) - + if self.dev_data is not None and hasattr(self, 'best_dev_perf'): - print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + - self.tester._format_eval_results(self.best_dev_perf),) + print( + "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + + self.tester._format_eval_results(self.best_dev_perf), ) results['best_eval'] = self.best_dev_perf results['best_epoch'] = self.best_dev_epoch results['best_step'] = self.best_dev_step @@ -233,9 +234,9 @@ class Trainer(object): finally: pass results['seconds'] = round(time.time() - start_time, 2) - + return results - + def _train(self): if not self.use_tqdm: from fastNLP.core.utils import pseudo_tqdm as inner_tqdm @@ -244,13 +245,13 @@ class Trainer(object): self.step = 0 self.epoch = 0 start = time.time() - + with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: self.pbar = pbar if isinstance(pbar, tqdm) else None avg_loss = 0 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, prefetch=self.prefetch) - for epoch in range(1, self.n_epochs+1): + for epoch in range(1, self.n_epochs + 1): self.epoch = epoch pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping @@ -262,22 +263,22 @@ class Trainer(object): # negative sampling; replace unknown; re-weight batch_y self.callback_manager.on_batch_begin(batch_x, batch_y, indices) prediction = self._data_forward(self.model, batch_x) - + # edit prediction self.callback_manager.on_loss_begin(batch_y, prediction) loss = self._compute_loss(prediction, batch_y).mean() avg_loss += loss.item() - loss = loss/self.update_every - + loss = loss / self.update_every + # Is loss NaN or inf? requires_grad = False self.callback_manager.on_backward_begin(loss) self._grad_backward(loss) self.callback_manager.on_backward_end() - + self._update() self.callback_manager.on_step_end() - - if (self.step+1) % self.print_every == 0: + + if (self.step + 1) % self.print_every == 0: avg_loss = avg_loss / self.print_every if self.use_tqdm: print_output = "loss:{0:<6.5f}".format(avg_loss) @@ -290,34 +291,34 @@ class Trainer(object): pbar.set_postfix_str(print_output) avg_loss = 0 self.callback_manager.on_batch_end() - + if ((self.validate_every > 0 and self.step % self.validate_every == 0) or (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ and self.dev_data is not None: eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, self.n_steps) + \ - self.tester._format_eval_results(eval_res) + self.tester._format_eval_results(eval_res) pbar.write(eval_str + '\n') - + # ================= mini-batch end ==================== # - + # lr decay; early stopping self.callback_manager.on_epoch_end() # =============== epochs end =================== # pbar.close() self.pbar = None # ============ tqdm end ============== # - + def _do_validation(self, epoch, step): self.callback_manager.on_valid_begin() res = self.tester.test() - + is_better_eval = False if self._better_eval_result(res): if self.save_path is not None: self._save_model(self.model, - "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) + "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) else: self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} self.best_dev_perf = res @@ -327,7 +328,7 @@ class Trainer(object): # get validation results; adjust optimizer self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) return res - + def _mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. @@ -339,21 +340,21 @@ class Trainer(object): model.eval() else: model.train() - + def _update(self): """Perform weight update on a model. """ - if (self.step+1)%self.update_every==0: + if (self.step + 1) % self.update_every == 0: self.optimizer.step() - + def _data_forward(self, network, x): x = _build_args(network.forward, **x) y = network(**x) if not isinstance(y, dict): raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") return y - + def _grad_backward(self, loss): """Compute gradient with link rules. @@ -361,10 +362,10 @@ class Trainer(object): For PyTorch, just do "loss.backward()" """ - if self.step%self.update_every==0: + if self.step % self.update_every == 0: self.model.zero_grad() loss.backward() - + def _compute_loss(self, predict, truth): """Compute loss given prediction and ground truth. @@ -373,7 +374,7 @@ class Trainer(object): :return: a scalar """ return self.losser(predict, truth) - + def _save_model(self, model, model_name, only_param=False): """ 存储不含有显卡信息的state_dict或model :param model: @@ -394,7 +395,7 @@ class Trainer(object): model.cpu() torch.save(model, model_path) model.to(self._model_device) - + def _load_model(self, model, model_name, only_param=False): # 返回bool值指示是否成功reload模型 if self.save_path is not None: @@ -409,7 +410,7 @@ class Trainer(object): else: return False return True - + def _better_eval_result(self, metrics): """Check if the current epoch yields better validation results. @@ -437,6 +438,7 @@ class Trainer(object): DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 + def _get_value_info(_dict): # given a dict value, return information about this dict's value. Return list of str strs = [] @@ -453,27 +455,28 @@ def _get_value_info(_dict): strs.append(_str) return strs + def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None, check_level=0): # check get_loss 方法 model_devcie = model.parameters().__next__().device - + batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) for batch_count, (batch_x, batch_y) in enumerate(batch): _move_dict_value_to_device(batch_x, batch_y, device=model_devcie) # forward check - if batch_count==0: + if batch_count == 0: info_str = "" input_fields = _get_value_info(batch_x) target_fields = _get_value_info(batch_y) - if len(input_fields)>0: + if len(input_fields) > 0: info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) info_str += "\n".join(input_fields) info_str += '\n' else: raise RuntimeError("There is no input field.") - if len(target_fields)>0: + if len(target_fields) > 0: info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) info_str += "\n".join(target_fields) info_str += '\n' @@ -481,14 +484,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ info_str += 'There is no target field.' print(info_str) _check_forward_error(forward_func=model.forward, dataset=dataset, - batch_x=batch_x, check_level=check_level) - + batch_x=batch_x, check_level=check_level) + refined_batch_x = _build_args(model.forward, **batch_x) pred_dict = model(**refined_batch_x) func_signature = get_func_signature(model.forward) if not isinstance(pred_dict, dict): raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") - + # loss check try: loss = losser(pred_dict, batch_y) @@ -512,7 +515,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ model.zero_grad() if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: break - + if dev_data is not None: tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, batch_size=batch_size, verbose=-1) @@ -526,7 +529,7 @@ def _check_eval_results(metrics, metric_key, metric_list): # metric_list: 多个用来做评价的指标,来自Trainer的初始化 if isinstance(metrics, tuple): loss, metrics = metrics - + if isinstance(metrics, dict): if len(metrics) == 1: # only single metric, just use it @@ -537,7 +540,7 @@ def _check_eval_results(metrics, metric_key, metric_list): if metrics_name not in metrics: raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") metric_dict = metrics[metrics_name] - + if len(metric_dict) == 1: indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] elif len(metric_dict) > 1 and metric_key is None: diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index d9141412..fc15166e 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -197,17 +197,22 @@ def get_func_signature(func): Given a function or method, return its signature. For example: - (1) function + + 1 function:: + def func(a, b='a', *args): xxxx get_func_signature(func) # 'func(a, b='a', *args)' - (2) method + + 2 method:: + class Demo: def __init__(self): xxx def forward(self, a, b='a', **args) demo = Demo() get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' + :param func: a function or a method :return: str or None """ diff --git a/fastNLP/models/char_language_model.py b/fastNLP/models/char_language_model.py index 5fbde3cc..d5e3359d 100644 --- a/fastNLP/models/char_language_model.py +++ b/fastNLP/models/char_language_model.py @@ -20,16 +20,23 @@ class Highway(nn.Module): class CharLM(nn.Module): """CNN + highway network + LSTM - # Input: + + # Input:: + 4D tensor with shape [batch_size, in_channel, height, width] - # Output: + + # Output:: + 2D Tensor with shape [batch_size, vocab_size] - # Arguments: + + # Arguments:: + char_emb_dim: the size of each character's attention word_emb_dim: the size of each word's attention vocab_size: num of unique words num_char: num of characters use_gpu: True or False + """ def __init__(self, char_emb_dim, word_emb_dim, diff --git a/fastNLP/models/enas_trainer.py b/fastNLP/models/enas_trainer.py index 6b51c897..26b7cd49 100644 --- a/fastNLP/models/enas_trainer.py +++ b/fastNLP/models/enas_trainer.py @@ -65,13 +65,14 @@ class ENASTrainer(fastNLP.Trainer): """ :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 最好的模型参数。 - :return results: 返回一个字典类型的数据, 内含以下内容:: - - seconds: float, 表示训练时长 - 以下三个内容只有在提供了dev_data的情况下会有。 - best_eval: Dict of Dict, 表示evaluation的结果 - best_epoch: int,在第几个epoch取得的最佳值 - best_step: int, 在第几个step(batch)更新取得的最佳值 + :return results: 返回一个字典类型的数据, + 内含以下内容:: + + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果 + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 """ results = {}