@@ -3,6 +3,7 @@ | |||
# You can set these variables from the command line. | |||
SPHINXOPTS = | |||
SPHINXAPIDOC = sphinx-apidoc | |||
SPHINXBUILD = sphinx-build | |||
SPHINXPROJ = fastNLP | |||
SOURCEDIR = source | |||
@@ -12,6 +13,9 @@ BUILDDIR = build | |||
help: | |||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | |||
apidoc: | |||
@$(SPHINXAPIDOC) -f -o source ../fastNLP | |||
.PHONY: help Makefile | |||
# Catch-all target: route all unknown targets to Sphinx using the new | |||
@@ -23,9 +23,9 @@ copyright = '2018, xpqiu' | |||
author = 'xpqiu' | |||
# The short X.Y version | |||
version = '0.2' | |||
version = '0.4' | |||
# The full version, including alpha/beta/rc tags | |||
release = '0.2' | |||
release = '0.4' | |||
# -- General configuration --------------------------------------------------- | |||
@@ -67,7 +67,7 @@ language = None | |||
# List of patterns, relative to source directory, that match files and | |||
# directories to ignore when looking for source files. | |||
# This pattern also affects html_static_path and html_extra_path . | |||
exclude_patterns = [] | |||
exclude_patterns = ['modules.rst'] | |||
# The name of the Pygments (syntax highlighting) style to use. | |||
pygments_style = 'sphinx' | |||
@@ -1,36 +1,62 @@ | |||
fastNLP.api | |||
============ | |||
fastNLP.api package | |||
=================== | |||
fastNLP.api.api | |||
---------------- | |||
Submodules | |||
---------- | |||
fastNLP.api.api module | |||
---------------------- | |||
.. automodule:: fastNLP.api.api | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.converter | |||
---------------------- | |||
fastNLP.api.converter module | |||
---------------------------- | |||
.. automodule:: fastNLP.api.converter | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.model\_zoo | |||
----------------------- | |||
fastNLP.api.examples module | |||
--------------------------- | |||
.. automodule:: fastNLP.api.model_zoo | |||
.. automodule:: fastNLP.api.examples | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.pipeline | |||
--------------------- | |||
fastNLP.api.pipeline module | |||
--------------------------- | |||
.. automodule:: fastNLP.api.pipeline | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.processor | |||
---------------------- | |||
fastNLP.api.processor module | |||
---------------------------- | |||
.. automodule:: fastNLP.api.processor | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.utils module | |||
------------------------ | |||
.. automodule:: fastNLP.api.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.api | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,84 +1,126 @@ | |||
fastNLP.core | |||
============= | |||
fastNLP.core package | |||
==================== | |||
fastNLP.core.batch | |||
------------------- | |||
Submodules | |||
---------- | |||
fastNLP.core.batch module | |||
------------------------- | |||
.. automodule:: fastNLP.core.batch | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.callback module | |||
---------------------------- | |||
fastNLP.core.dataset | |||
--------------------- | |||
.. automodule:: fastNLP.core.callback | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.dataset module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.dataset | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.fieldarray | |||
------------------------ | |||
fastNLP.core.fieldarray module | |||
------------------------------ | |||
.. automodule:: fastNLP.core.fieldarray | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.instance | |||
---------------------- | |||
fastNLP.core.instance module | |||
---------------------------- | |||
.. automodule:: fastNLP.core.instance | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.losses | |||
-------------------- | |||
fastNLP.core.losses module | |||
-------------------------- | |||
.. automodule:: fastNLP.core.losses | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.metrics | |||
--------------------- | |||
fastNLP.core.metrics module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.metrics | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.optimizer | |||
----------------------- | |||
fastNLP.core.optimizer module | |||
----------------------------- | |||
.. automodule:: fastNLP.core.optimizer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.predictor | |||
----------------------- | |||
fastNLP.core.predictor module | |||
----------------------------- | |||
.. automodule:: fastNLP.core.predictor | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.sampler | |||
--------------------- | |||
fastNLP.core.sampler module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.sampler | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.tester | |||
-------------------- | |||
fastNLP.core.tester module | |||
-------------------------- | |||
.. automodule:: fastNLP.core.tester | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.trainer | |||
--------------------- | |||
fastNLP.core.trainer module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.trainer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.utils | |||
------------------- | |||
fastNLP.core.utils module | |||
------------------------- | |||
.. automodule:: fastNLP.core.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.vocabulary | |||
------------------------ | |||
fastNLP.core.vocabulary module | |||
------------------------------ | |||
.. automodule:: fastNLP.core.vocabulary | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.core | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,42 +1,54 @@ | |||
fastNLP.io | |||
=========== | |||
fastNLP.io package | |||
================== | |||
fastNLP.io.base\_loader | |||
------------------------ | |||
Submodules | |||
---------- | |||
fastNLP.io.base\_loader module | |||
------------------------------ | |||
.. automodule:: fastNLP.io.base_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.config\_io | |||
---------------------- | |||
fastNLP.io.config\_io module | |||
---------------------------- | |||
.. automodule:: fastNLP.io.config_io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.dataset\_loader | |||
--------------------------- | |||
fastNLP.io.dataset\_loader module | |||
--------------------------------- | |||
.. automodule:: fastNLP.io.dataset_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.embed\_loader | |||
------------------------- | |||
fastNLP.io.embed\_loader module | |||
------------------------------- | |||
.. automodule:: fastNLP.io.embed_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.logger | |||
------------------ | |||
.. automodule:: fastNLP.io.logger | |||
:members: | |||
fastNLP.io.model\_io | |||
--------------------- | |||
fastNLP.io.model\_io module | |||
--------------------------- | |||
.. automodule:: fastNLP.io.model_io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,42 +1,110 @@ | |||
fastNLP.models | |||
=============== | |||
fastNLP.models package | |||
====================== | |||
fastNLP.models.base\_model | |||
--------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.models.base\_model module | |||
--------------------------------- | |||
.. automodule:: fastNLP.models.base_model | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.bert module | |||
-------------------------- | |||
fastNLP.models.biaffine\_parser | |||
-------------------------------- | |||
.. automodule:: fastNLP.models.bert | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.biaffine\_parser module | |||
-------------------------------------- | |||
.. automodule:: fastNLP.models.biaffine_parser | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.char\_language\_model | |||
------------------------------------- | |||
fastNLP.models.char\_language\_model module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.models.char_language_model | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.cnn\_text\_classification | |||
----------------------------------------- | |||
fastNLP.models.cnn\_text\_classification module | |||
----------------------------------------------- | |||
.. automodule:: fastNLP.models.cnn_text_classification | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.enas\_controller module | |||
-------------------------------------- | |||
.. automodule:: fastNLP.models.enas_controller | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.enas\_model module | |||
--------------------------------- | |||
.. automodule:: fastNLP.models.enas_model | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.sequence\_modeling | |||
---------------------------------- | |||
fastNLP.models.enas\_trainer module | |||
----------------------------------- | |||
.. automodule:: fastNLP.models.enas_trainer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.enas\_utils module | |||
--------------------------------- | |||
.. automodule:: fastNLP.models.enas_utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.sequence\_modeling module | |||
---------------------------------------- | |||
.. automodule:: fastNLP.models.sequence_modeling | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.snli | |||
-------------------- | |||
fastNLP.models.snli module | |||
-------------------------- | |||
.. automodule:: fastNLP.models.snli | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.star\_transformer module | |||
--------------------------------------- | |||
.. automodule:: fastNLP.models.star_transformer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.models | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,36 +1,54 @@ | |||
fastNLP.modules.aggregator | |||
=========================== | |||
fastNLP.modules.aggregator package | |||
================================== | |||
fastNLP.modules.aggregator.attention | |||
------------------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.modules.aggregator.attention module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.attention | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.avg\_pool | |||
------------------------------------- | |||
fastNLP.modules.aggregator.avg\_pool module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.avg_pool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.kmax\_pool | |||
-------------------------------------- | |||
fastNLP.modules.aggregator.kmax\_pool module | |||
-------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.kmax_pool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.max\_pool | |||
------------------------------------- | |||
fastNLP.modules.aggregator.max\_pool module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.max_pool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.self\_attention | |||
------------------------------------------- | |||
fastNLP.modules.aggregator.self\_attention module | |||
------------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.self_attention | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules.aggregator | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,18 +1,30 @@ | |||
fastNLP.modules.decoder | |||
======================== | |||
fastNLP.modules.decoder package | |||
=============================== | |||
fastNLP.modules.decoder.CRF | |||
---------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.modules.decoder.CRF module | |||
---------------------------------- | |||
.. automodule:: fastNLP.modules.decoder.CRF | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.decoder.MLP | |||
---------------------------- | |||
fastNLP.modules.decoder.MLP module | |||
---------------------------------- | |||
.. automodule:: fastNLP.modules.decoder.MLP | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules.decoder | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,60 +1,94 @@ | |||
fastNLP.modules.encoder | |||
======================== | |||
fastNLP.modules.encoder package | |||
=============================== | |||
fastNLP.modules.encoder.char\_embedding | |||
---------------------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.modules.encoder.char\_embedding module | |||
---------------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.char_embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.conv | |||
----------------------------- | |||
fastNLP.modules.encoder.conv module | |||
----------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.conv | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.conv\_maxpool | |||
-------------------------------------- | |||
fastNLP.modules.encoder.conv\_maxpool module | |||
-------------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.conv_maxpool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.embedding | |||
---------------------------------- | |||
fastNLP.modules.encoder.embedding module | |||
---------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.linear | |||
------------------------------- | |||
fastNLP.modules.encoder.linear module | |||
------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.linear | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.lstm | |||
----------------------------- | |||
fastNLP.modules.encoder.lstm module | |||
----------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.lstm | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.masked\_rnn | |||
------------------------------------ | |||
fastNLP.modules.encoder.masked\_rnn module | |||
------------------------------------------ | |||
.. automodule:: fastNLP.modules.encoder.masked_rnn | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.transformer | |||
------------------------------------ | |||
fastNLP.modules.encoder.star\_transformer module | |||
------------------------------------------------ | |||
.. automodule:: fastNLP.modules.encoder.star_transformer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.transformer module | |||
------------------------------------------ | |||
.. automodule:: fastNLP.modules.encoder.transformer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.variational\_rnn | |||
----------------------------------------- | |||
fastNLP.modules.encoder.variational\_rnn module | |||
----------------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.variational_rnn | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules.encoder | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,5 +1,8 @@ | |||
fastNLP.modules | |||
================ | |||
fastNLP.modules package | |||
======================= | |||
Subpackages | |||
----------- | |||
.. toctree:: | |||
@@ -7,24 +10,38 @@ fastNLP.modules | |||
fastNLP.modules.decoder | |||
fastNLP.modules.encoder | |||
fastNLP.modules.dropout | |||
------------------------ | |||
Submodules | |||
---------- | |||
fastNLP.modules.dropout module | |||
------------------------------ | |||
.. automodule:: fastNLP.modules.dropout | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.other\_modules | |||
------------------------------- | |||
fastNLP.modules.other\_modules module | |||
------------------------------------- | |||
.. automodule:: fastNLP.modules.other_modules | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.utils | |||
---------------------- | |||
fastNLP.modules.utils module | |||
---------------------------- | |||
.. automodule:: fastNLP.modules.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,13 +1,22 @@ | |||
fastNLP | |||
======== | |||
fastNLP package | |||
=============== | |||
Subpackages | |||
----------- | |||
.. toctree:: | |||
fastNLP.api | |||
fastNLP.automl | |||
fastNLP.core | |||
fastNLP.io | |||
fastNLP.models | |||
fastNLP.modules | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1 +1,4 @@ | |||
""" | |||
这是 API 部分的注释 | |||
""" | |||
from .api import CWS, POS, Parser |
@@ -1,3 +1,7 @@ | |||
""" | |||
API.API 的文档 | |||
""" | |||
import warnings | |||
import torch | |||
@@ -184,17 +188,17 @@ class CWS(API): | |||
""" | |||
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 | |||
分词文件应该为: | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
以空行分割两个句子,有内容的每行有7列。 | |||
:param filepath: str, 文件路径。 | |||
@@ -62,13 +62,14 @@ class ENASTrainer(fastNLP.Trainer): | |||
""" | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
:return results: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
""" | |||
results = {} | |||
@@ -272,7 +272,7 @@ class DataSet(object): | |||
:param func: a function that takes an instance as input. | |||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||
:param **kwargs: Accept parameters will be | |||
:param kwargs: Accepted keyword arguments are: | |||
(1) is_input: boolean, will be ignored if new_field_name is None. If True, the new field will be set as input. | |||
(2) is_target: boolean, will be ignored if new_field_name is None. If True, the new field will be set as target. | |||
:return results: if new_field_name is not passed, returned values of the function over all instances. | |||
@@ -48,12 +48,16 @@ class PadderBase: | |||
class AutoPadder(PadderBase): | |||
""" | |||
根据contents的数据自动判定是否需要做padding。 | |||
(1) 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding | |||
(2) 如果元素类型为(np.int64, np.float64), | |||
(2.1) 如果该field的内容只有一个,比如为sequence_length, 则不进行padding | |||
(2.2) 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||
如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad | |||
1 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding | |||
2 如果元素类型为(np.int64, np.float64), | |||
2.1 如果该field的内容只有一个,比如为sequence_length, 则不进行padding | |||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||
如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad | |||
""" | |||
def __init__(self, pad_val=0): | |||
""" | |||
@@ -1,13 +1,12 @@ | |||
class Instance(object): | |||
"""An Instance is an example of data. | |||
Example:: | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
ins["field_1"] | |||
>>[1, 1, 1] | |||
ins.add_field("field_3", [3, 3, 3]) | |||
:param fields: a dict of (str: list). | |||
Example:: | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
ins["field_1"] | |||
>>[1, 1, 1] | |||
ins.add_field("field_3", [3, 3, 3]) | |||
""" | |||
def __init__(self, **fields): | |||
@@ -272,7 +272,7 @@ def squash(predict, truth, **kwargs): | |||
:param predict: Tensor, model output | |||
:param truth: Tensor, truth from dataset | |||
:param **kwargs: extra arguments | |||
:param kwargs: extra arguments | |||
:return predict , truth: predict & truth after processing | |||
""" | |||
return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | |||
@@ -316,7 +316,7 @@ def mask(predict, truth, **kwargs): | |||
:param predict: Tensor, [batch_size , max_len , tag_size] | |||
:param truth: Tensor, [batch_size , max_len] | |||
:param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. | |||
:param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. | |||
:return predict , truth: predict & truth after processing | |||
""" | |||
@@ -17,66 +17,72 @@ class MetricBase(object): | |||
"""Base class for all metrics. | |||
所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 | |||
evaluate(xxx)中传入的是一个batch的数据。 | |||
get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 | |||
以分类问题中,Accuracy计算为例 | |||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy | |||
class Model(nn.Module): | |||
def __init__(xxx): | |||
# do something | |||
def forward(self, xxx): | |||
# do something | |||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy:: | |||
class Model(nn.Module): | |||
def __init__(xxx): | |||
# do something | |||
def forward(self, xxx): | |||
# do something | |||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||
假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target | |||
对应的AccMetric可以按如下的定义 | |||
# version1, 只使用这一次 | |||
class AccMetric(MetricBase): | |||
def __init__(self): | |||
super().__init__() | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
对应的AccMetric可以按如下的定义, version1, 只使用这一次:: | |||
class AccMetric(MetricBase): | |||
def __init__(self): | |||
super().__init__() | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
# version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred | |||
class AccMetric(MetricBase): | |||
def __init__(self, label=None, pred=None): | |||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||
# 应的的值 | |||
super().__init__() | |||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||
# 如果没有注册该则效果与version1就是一样的 | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred:: | |||
class AccMetric(MetricBase): | |||
def __init__(self, label=None, pred=None): | |||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||
# 应的值 | |||
super().__init__() | |||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||
# 如果没有注册该则效果与version1就是一样的 | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. | |||
@@ -84,12 +90,12 @@ class MetricBase(object): | |||
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. | |||
``MetricBase`` will do the following type checks: | |||
1. whether self.evaluate has varargs, which is not supported. | |||
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | |||
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | |||
1. whether self.evaluate has varargs, which is not supported. | |||
2. whether params needed by self.evaluate are not included in ``pred_dict``, ``target_dict``. | |||
3. whether params needed by self.evaluate are duplicated in ``pred_dict``, ``target_dict``. | |||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | |||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||
target_dict which are not used in self.evaluate. (but if ``**kwargs`` is present in self.evaluate, no filtering | |||
will be conducted.) | |||
""" | |||
@@ -388,23 +394,26 @@ class SpanFPreRecMetric(MetricBase): | |||
""" | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | |||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||
最后得到的metric结果为 | |||
{ | |||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||
'pre': xxx, | |||
'rec':xxx | |||
} | |||
若only_gross=False, 即还会返回各个label的metric统计值 | |||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||
最后得到的metric结果为:: | |||
{ | |||
'f': xxx, | |||
'pre': xxx, | |||
'rec':xxx, | |||
'f-label': xxx, | |||
'pre-label': xxx, | |||
'rec-label':xxx, | |||
... | |||
} | |||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||
'pre': xxx, | |||
'rec':xxx | |||
} | |||
若only_gross=False, 即还会返回各个label的metric统计值:: | |||
{ | |||
'f': xxx, | |||
'pre': xxx, | |||
'rec':xxx, | |||
'f-label': xxx, | |||
'pre-label': xxx, | |||
'rec-label':xxx, | |||
... | |||
} | |||
""" | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, | |||
@@ -573,13 +582,21 @@ class BMESF1PreRecMetric(MetricBase): | |||
""" | |||
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
+-------+---------+----------+----------+---------+---------+ | |||
| | next_B | next_M | next_E | next_S | end | | |||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||
| start | 合法 | next_M=B | next_E=S | 合法 | - | | |||
+=======+=========+==========+==========+=========+=========+ | |||
| start | 合法 | next_M=B | next_E=S | 合法 | -- | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
+-------+---------+----------+----------+---------+---------+ | |||
举例: | |||
prediction为BSEMS,会被认为是SSSSS. | |||
@@ -66,28 +66,28 @@ class Trainer(object): | |||
不足,通过设置batch_size=32, update_every=4达到目的 | |||
""" | |||
super(Trainer, self).__init__() | |||
if not isinstance(train_data, DataSet): | |||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||
# check metrics and dev_data | |||
if (not metrics) and dev_data is not None: | |||
raise ValueError("No metric for dev_data evaluation.") | |||
if metrics and (dev_data is None): | |||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||
# check update every | |||
assert update_every>=1, "update_every must be no less than 1." | |||
assert update_every >= 1, "update_every must be no less than 1." | |||
self.update_every = int(update_every) | |||
# check save_path | |||
if not (save_path is None or isinstance(save_path, str)): | |||
raise ValueError("save_path can only be None or `str`.") | |||
# prepare evaluate | |||
metrics = _prepare_metrics(metrics) | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
@@ -97,19 +97,19 @@ class Trainer(object): | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
elif len(metrics) > 0: | |||
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | |||
# prepare loss | |||
losser = _prepare_losser(loss) | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, BaseSampler): | |||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||
if check_code_level > -1: | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
self.train_data = train_data | |||
self.dev_data = dev_data # If None, No validation. | |||
self.model = model | |||
@@ -120,7 +120,7 @@ class Trainer(object): | |||
self.use_cuda = bool(use_cuda) | |||
self.save_path = save_path | |||
self.print_every = int(print_every) | |||
self.validate_every = int(validate_every) if validate_every!=0 else -1 | |||
self.validate_every = int(validate_every) if validate_every != 0 else -1 | |||
self.best_metric_indicator = None | |||
self.best_dev_epoch = None | |||
self.best_dev_step = None | |||
@@ -129,19 +129,19 @@ class Trainer(object): | |||
self.prefetch = prefetch | |||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
else: | |||
if optimizer is None: | |||
optimizer = Adam(lr=0.01, weight_decay=0) | |||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | |||
self.use_tqdm = use_tqdm | |||
self.pbar = None | |||
self.print_every = abs(self.print_every) | |||
if self.dev_data is not None: | |||
self.tester = Tester(model=self.model, | |||
data=self.dev_data, | |||
@@ -149,14 +149,13 @@ class Trainer(object): | |||
batch_size=self.batch_size, | |||
use_cuda=self.use_cuda, | |||
verbose=0) | |||
self.step = 0 | |||
self.start_time = None # start timestamp | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
def train(self, load_best_model=True): | |||
""" | |||
@@ -185,14 +184,15 @@ class Trainer(object): | |||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
""" | |||
results = {} | |||
@@ -205,21 +205,22 @@ class Trainer(object): | |||
self.model = self.model.cuda() | |||
self._model_device = self.model.parameters().__next__().device | |||
self._mode(self.model, is_test=False) | |||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
start_time = time.time() | |||
print("training epochs started " + self.start_time, flush=True) | |||
try: | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
self.callback_manager.on_train_end() | |||
except (CallbackException, KeyboardInterrupt) as e: | |||
self.callback_manager.on_exception(e) | |||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | |||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf),) | |||
print( | |||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf), ) | |||
results['best_eval'] = self.best_dev_perf | |||
results['best_epoch'] = self.best_dev_epoch | |||
results['best_step'] = self.best_dev_step | |||
@@ -233,9 +234,9 @@ class Trainer(object): | |||
finally: | |||
pass | |||
results['seconds'] = round(time.time() - start_time, 2) | |||
return results | |||
def _train(self): | |||
if not self.use_tqdm: | |||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||
@@ -244,13 +245,13 @@ class Trainer(object): | |||
self.step = 0 | |||
self.epoch = 0 | |||
start = time.time() | |||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
self.pbar = pbar if isinstance(pbar, tqdm) else None | |||
avg_loss = 0 | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
prefetch=self.prefetch) | |||
for epoch in range(1, self.n_epochs+1): | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.epoch = epoch | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
# early stopping | |||
@@ -262,22 +263,22 @@ class Trainer(object): | |||
# negative sampling; replace unknown; re-weight batch_y | |||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||
prediction = self._data_forward(self.model, batch_x) | |||
# edit prediction | |||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||
loss = self._compute_loss(prediction, batch_y).mean() | |||
avg_loss += loss.item() | |||
loss = loss/self.update_every | |||
loss = loss / self.update_every | |||
# Is loss NaN or inf? requires_grad = False | |||
self.callback_manager.on_backward_begin(loss) | |||
self._grad_backward(loss) | |||
self.callback_manager.on_backward_end() | |||
self._update() | |||
self.callback_manager.on_step_end() | |||
if (self.step+1) % self.print_every == 0: | |||
if (self.step + 1) % self.print_every == 0: | |||
avg_loss = avg_loss / self.print_every | |||
if self.use_tqdm: | |||
print_output = "loss:{0:<6.5f}".format(avg_loss) | |||
@@ -290,34 +291,34 @@ class Trainer(object): | |||
pbar.set_postfix_str(print_output) | |||
avg_loss = 0 | |||
self.callback_manager.on_batch_end() | |||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||
and self.dev_data is not None: | |||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
self.tester._format_eval_results(eval_res) | |||
pbar.write(eval_str + '\n') | |||
# ================= mini-batch end ==================== # | |||
# lr decay; early stopping | |||
self.callback_manager.on_epoch_end() | |||
# =============== epochs end =================== # | |||
pbar.close() | |||
self.pbar = None | |||
# ============ tqdm end ============== # | |||
def _do_validation(self, epoch, step): | |||
self.callback_manager.on_valid_begin() | |||
res = self.tester.test() | |||
is_better_eval = False | |||
if self._better_eval_result(res): | |||
if self.save_path is not None: | |||
self._save_model(self.model, | |||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||
else: | |||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} | |||
self.best_dev_perf = res | |||
@@ -327,7 +328,7 @@ class Trainer(object): | |||
# get validation results; adjust optimizer | |||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | |||
return res | |||
def _mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
@@ -339,21 +340,21 @@ class Trainer(object): | |||
model.eval() | |||
else: | |||
model.train() | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
""" | |||
if (self.step+1)%self.update_every==0: | |||
if (self.step + 1) % self.update_every == 0: | |||
self.optimizer.step() | |||
def _data_forward(self, network, x): | |||
x = _build_args(network.forward, **x) | |||
y = network(**x) | |||
if not isinstance(y, dict): | |||
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||
return y | |||
def _grad_backward(self, loss): | |||
"""Compute gradient with link rules. | |||
@@ -361,10 +362,10 @@ class Trainer(object): | |||
For PyTorch, just do "loss.backward()" | |||
""" | |||
if self.step%self.update_every==0: | |||
if self.step % self.update_every == 0: | |||
self.model.zero_grad() | |||
loss.backward() | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
@@ -373,7 +374,7 @@ class Trainer(object): | |||
:return: a scalar | |||
""" | |||
return self.losser(predict, truth) | |||
def _save_model(self, model, model_name, only_param=False): | |||
""" 存储不含有显卡信息的state_dict或model | |||
:param model: | |||
@@ -394,7 +395,7 @@ class Trainer(object): | |||
model.cpu() | |||
torch.save(model, model_path) | |||
model.to(self._model_device) | |||
def _load_model(self, model, model_name, only_param=False): | |||
# 返回bool值指示是否成功reload模型 | |||
if self.save_path is not None: | |||
@@ -409,7 +410,7 @@ class Trainer(object): | |||
else: | |||
return False | |||
return True | |||
def _better_eval_result(self, metrics): | |||
"""Check if the current epoch yields better validation results. | |||
@@ -437,6 +438,7 @@ class Trainer(object): | |||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||
DEFAULT_CHECK_NUM_BATCH = 2 | |||
def _get_value_info(_dict): | |||
# given a dict value, return information about this dict's value. Return list of str | |||
strs = [] | |||
@@ -453,27 +455,28 @@ def _get_value_info(_dict): | |||
strs.append(_str) | |||
return strs | |||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, | |||
check_level=0): | |||
# check get_loss 方法 | |||
model_devcie = model.parameters().__next__().device | |||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | |||
# forward check | |||
if batch_count==0: | |||
if batch_count == 0: | |||
info_str = "" | |||
input_fields = _get_value_info(batch_x) | |||
target_fields = _get_value_info(batch_y) | |||
if len(input_fields)>0: | |||
if len(input_fields) > 0: | |||
info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) | |||
info_str += "\n".join(input_fields) | |||
info_str += '\n' | |||
else: | |||
raise RuntimeError("There is no input field.") | |||
if len(target_fields)>0: | |||
if len(target_fields) > 0: | |||
info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) | |||
info_str += "\n".join(target_fields) | |||
info_str += '\n' | |||
@@ -481,14 +484,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
info_str += 'There is no target field.' | |||
print(info_str) | |||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
batch_x=batch_x, check_level=check_level) | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
pred_dict = model(**refined_batch_x) | |||
func_signature = get_func_signature(model.forward) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | |||
# loss check | |||
try: | |||
loss = losser(pred_dict, batch_y) | |||
@@ -512,7 +515,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
model.zero_grad() | |||
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: | |||
break | |||
if dev_data is not None: | |||
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||
batch_size=batch_size, verbose=-1) | |||
@@ -526,7 +529,7 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||
# metric_list: 多个用来做评价的指标,来自Trainer的初始化 | |||
if isinstance(metrics, tuple): | |||
loss, metrics = metrics | |||
if isinstance(metrics, dict): | |||
if len(metrics) == 1: | |||
# only single metric, just use it | |||
@@ -537,7 +540,7 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||
if metrics_name not in metrics: | |||
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | |||
metric_dict = metrics[metrics_name] | |||
if len(metric_dict) == 1: | |||
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | |||
elif len(metric_dict) > 1 and metric_key is None: | |||
@@ -197,17 +197,22 @@ def get_func_signature(func): | |||
Given a function or method, return its signature. | |||
For example: | |||
(1) function | |||
1 function:: | |||
def func(a, b='a', *args): | |||
xxxx | |||
get_func_signature(func) # 'func(a, b='a', *args)' | |||
(2) method | |||
2 method:: | |||
class Demo: | |||
def __init__(self): | |||
xxx | |||
def forward(self, a, b='a', **args) | |||
demo = Demo() | |||
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' | |||
:param func: a function or a method | |||
:return: str or None | |||
""" | |||
@@ -20,16 +20,23 @@ class Highway(nn.Module): | |||
class CharLM(nn.Module): | |||
"""CNN + highway network + LSTM | |||
# Input: | |||
# Input:: | |||
4D tensor with shape [batch_size, in_channel, height, width] | |||
# Output: | |||
# Output:: | |||
2D Tensor with shape [batch_size, vocab_size] | |||
# Arguments: | |||
# Arguments:: | |||
char_emb_dim: the size of each character's attention | |||
word_emb_dim: the size of each word's attention | |||
vocab_size: num of unique words | |||
num_char: num of characters | |||
use_gpu: True or False | |||
""" | |||
def __init__(self, char_emb_dim, word_emb_dim, | |||
@@ -65,13 +65,14 @@ class ENASTrainer(fastNLP.Trainer): | |||
""" | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
:return results: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
""" | |||
results = {} | |||