@@ -3,6 +3,7 @@
 # You can set these variables from the command line.
 SPHINXOPTS   =
+SPHINXAPIDOC = sphinx-apidoc
 SPHINXBUILD  = sphinx-build
 SPHINXPROJ   = fastNLP
 SOURCEDIR    = source
@@ -12,6 +13,9 @@ BUILDDIR      = build
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
+apidoc:
+	@$(SPHINXAPIDOC) -f -o source ../fastNLP
+
 .PHONY: help Makefile
 
 # Catch-all target: route all unknown targets to Sphinx using the new
@@ -23,9 +23,9 @@ copyright = '2018, xpqiu'
 author = 'xpqiu'
 
 # The short X.Y version
-version = '0.2'
+version = '0.4'
 # The full version, including alpha/beta/rc tags
-release = '0.2'
+release = '0.4'
 
 # -- General configuration ---------------------------------------------------
@@ -67,7 +67,7 @@ language = None
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path .
-exclude_patterns = []
+exclude_patterns = ['modules.rst']
 
 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = 'sphinx'
@@ -1,36 +1,62 @@
-fastNLP.api
-============
+fastNLP.api package
+===================
 
-fastNLP.api.api
-----------------
+Submodules
+----------
+
+fastNLP.api.api module
+----------------------
 
 .. automodule:: fastNLP.api.api
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.api.converter
----------------------
+fastNLP.api.converter module
+----------------------------
 
 .. automodule:: fastNLP.api.converter
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.api.model\_zoo
------------------------
+fastNLP.api.examples module
+---------------------------
 
-.. automodule:: fastNLP.api.model_zoo
+.. automodule:: fastNLP.api.examples
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.api.pipeline
---------------------
+fastNLP.api.pipeline module
+---------------------------
 
 .. automodule:: fastNLP.api.pipeline
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.api.processor
----------------------
+fastNLP.api.processor module
+----------------------------
 
 .. automodule:: fastNLP.api.processor
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+fastNLP.api.utils module
+------------------------
+
+.. automodule:: fastNLP.api.utils
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------
+
 .. automodule:: fastNLP.api
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,84 +1,126 @@
-fastNLP.core
-=============
+fastNLP.core package
+====================
 
-fastNLP.core.batch
-------------------
+Submodules
+----------
+
+fastNLP.core.batch module
+-------------------------
 
 .. automodule:: fastNLP.core.batch
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.core.callback module
+----------------------------
 
-fastNLP.core.dataset
---------------------
+.. automodule:: fastNLP.core.callback
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.core.dataset module
+---------------------------
 
 .. automodule:: fastNLP.core.dataset
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.fieldarray
------------------------
+fastNLP.core.fieldarray module
+------------------------------
 
 .. automodule:: fastNLP.core.fieldarray
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.instance
----------------------
+fastNLP.core.instance module
+----------------------------
 
 .. automodule:: fastNLP.core.instance
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.losses
--------------------
+fastNLP.core.losses module
+--------------------------
 
 .. automodule:: fastNLP.core.losses
    :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.metrics
---------------------
+fastNLP.core.metrics module
+---------------------------
 
 .. automodule:: fastNLP.core.metrics
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.optimizer
-----------------------
+fastNLP.core.optimizer module
+-----------------------------
 
 .. automodule:: fastNLP.core.optimizer
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.predictor
-----------------------
+fastNLP.core.predictor module
+-----------------------------
 
 .. automodule:: fastNLP.core.predictor
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.sampler
---------------------
+fastNLP.core.sampler module
+---------------------------
 
 .. automodule:: fastNLP.core.sampler
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.tester
--------------------
+fastNLP.core.tester module
+--------------------------
 
 .. automodule:: fastNLP.core.tester
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.trainer
---------------------
+fastNLP.core.trainer module
+---------------------------
 
 .. automodule:: fastNLP.core.trainer
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.utils
-------------------
+fastNLP.core.utils module
+-------------------------
 
 .. automodule:: fastNLP.core.utils
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.core.vocabulary
------------------------
+fastNLP.core.vocabulary module
+------------------------------
 
 .. automodule:: fastNLP.core.vocabulary
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+Module contents
+---------------
+
 .. automodule:: fastNLP.core
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,42 +1,54 @@
-fastNLP.io
-===========
+fastNLP.io package
+==================
 
-fastNLP.io.base\_loader
------------------------
+Submodules
+----------
+
+fastNLP.io.base\_loader module
+------------------------------
 
 .. automodule:: fastNLP.io.base_loader
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.io.config\_io
----------------------
+fastNLP.io.config\_io module
+----------------------------
 
 .. automodule:: fastNLP.io.config_io
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.io.dataset\_loader
---------------------------
+fastNLP.io.dataset\_loader module
+---------------------------------
 
 .. automodule:: fastNLP.io.dataset_loader
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.io.embed\_loader
-------------------------
+fastNLP.io.embed\_loader module
+-------------------------------
 
 .. automodule:: fastNLP.io.embed_loader
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.io.logger
------------------
-
-.. automodule:: fastNLP.io.logger
-    :members:
-
-fastNLP.io.model\_io
---------------------
+fastNLP.io.model\_io module
+---------------------------
 
 .. automodule:: fastNLP.io.model_io
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+Module contents
+---------------
+
 .. automodule:: fastNLP.io
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,42 +1,110 @@
-fastNLP.models
-===============
+fastNLP.models package
+======================
 
-fastNLP.models.base\_model
----------------------------
+Submodules
+----------
+
+fastNLP.models.base\_model module
+---------------------------------
 
 .. automodule:: fastNLP.models.base_model
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.bert module
+--------------------------
 
-fastNLP.models.biaffine\_parser
--------------------------------
+.. automodule:: fastNLP.models.bert
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.biaffine\_parser module
+--------------------------------------
 
 .. automodule:: fastNLP.models.biaffine_parser
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.models.char\_language\_model
-------------------------------------
+fastNLP.models.char\_language\_model module
+-------------------------------------------
 
 .. automodule:: fastNLP.models.char_language_model
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.models.cnn\_text\_classification
-----------------------------------------
+fastNLP.models.cnn\_text\_classification module
+-----------------------------------------------
 
 .. automodule:: fastNLP.models.cnn_text_classification
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.enas\_controller module
+--------------------------------------
+
+.. automodule:: fastNLP.models.enas_controller
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.enas\_model module
+---------------------------------
+
+.. automodule:: fastNLP.models.enas_model
+    :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.models.sequence\_modeling
----------------------------------
+fastNLP.models.enas\_trainer module
+-----------------------------------
+
+.. automodule:: fastNLP.models.enas_trainer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.enas\_utils module
+---------------------------------
+
+.. automodule:: fastNLP.models.enas_utils
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.sequence\_modeling module
+----------------------------------------
 
 .. automodule:: fastNLP.models.sequence_modeling
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.models.snli
--------------------
+fastNLP.models.snli module
+--------------------------
 
 .. automodule:: fastNLP.models.snli
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+fastNLP.models.star\_transformer module
+---------------------------------------
+
+.. automodule:: fastNLP.models.star_transformer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------
+
 .. automodule:: fastNLP.models
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,36 +1,54 @@
-fastNLP.modules.aggregator
-===========================
+fastNLP.modules.aggregator package
+==================================
 
-fastNLP.modules.aggregator.attention
-------------------------------------
+Submodules
+----------
+
+fastNLP.modules.aggregator.attention module
+-------------------------------------------
 
 .. automodule:: fastNLP.modules.aggregator.attention
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.aggregator.avg\_pool
-------------------------------------
+fastNLP.modules.aggregator.avg\_pool module
+-------------------------------------------
 
 .. automodule:: fastNLP.modules.aggregator.avg_pool
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.aggregator.kmax\_pool
--------------------------------------
+fastNLP.modules.aggregator.kmax\_pool module
+--------------------------------------------
 
 .. automodule:: fastNLP.modules.aggregator.kmax_pool
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.aggregator.max\_pool
-------------------------------------
+fastNLP.modules.aggregator.max\_pool module
+-------------------------------------------
 
 .. automodule:: fastNLP.modules.aggregator.max_pool
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.aggregator.self\_attention
-------------------------------------------
+fastNLP.modules.aggregator.self\_attention module
+-------------------------------------------------
 
 .. automodule:: fastNLP.modules.aggregator.self_attention
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+Module contents
+---------------
+
 .. automodule:: fastNLP.modules.aggregator
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,18 +1,30 @@
-fastNLP.modules.decoder
-========================
+fastNLP.modules.decoder package
+===============================
 
-fastNLP.modules.decoder.CRF
----------------------------
+Submodules
+----------
+
+fastNLP.modules.decoder.CRF module
+----------------------------------
 
 .. automodule:: fastNLP.modules.decoder.CRF
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.decoder.MLP
----------------------------
+fastNLP.modules.decoder.MLP module
+----------------------------------
 
 .. automodule:: fastNLP.modules.decoder.MLP
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+Module contents
+---------------
+
 .. automodule:: fastNLP.modules.decoder
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,60 +1,94 @@
-fastNLP.modules.encoder
-========================
+fastNLP.modules.encoder package
+===============================
 
-fastNLP.modules.encoder.char\_embedding
----------------------------------------
+Submodules
+----------
+
+fastNLP.modules.encoder.char\_embedding module
+----------------------------------------------
 
 .. automodule:: fastNLP.modules.encoder.char_embedding
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.conv
-----------------------------
+fastNLP.modules.encoder.conv module
+-----------------------------------
 
 .. automodule:: fastNLP.modules.encoder.conv
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.conv\_maxpool
--------------------------------------
+fastNLP.modules.encoder.conv\_maxpool module
+--------------------------------------------
 
 .. automodule:: fastNLP.modules.encoder.conv_maxpool
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.embedding
----------------------------------
+fastNLP.modules.encoder.embedding module
+----------------------------------------
 
 .. automodule:: fastNLP.modules.encoder.embedding
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.linear
-------------------------------
+fastNLP.modules.encoder.linear module
+-------------------------------------
 
 .. automodule:: fastNLP.modules.encoder.linear
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.lstm
-----------------------------
+fastNLP.modules.encoder.lstm module
+-----------------------------------
 
 .. automodule:: fastNLP.modules.encoder.lstm
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.masked\_rnn
------------------------------------
+fastNLP.modules.encoder.masked\_rnn module
+------------------------------------------
 
 .. automodule:: fastNLP.modules.encoder.masked_rnn
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.transformer
------------------------------------
+fastNLP.modules.encoder.star\_transformer module
+------------------------------------------------
+
+.. automodule:: fastNLP.modules.encoder.star_transformer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.modules.encoder.transformer module
+------------------------------------------
 
 .. automodule:: fastNLP.modules.encoder.transformer
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.encoder.variational\_rnn
-----------------------------------------
+fastNLP.modules.encoder.variational\_rnn module
+-----------------------------------------------
 
 .. automodule:: fastNLP.modules.encoder.variational_rnn
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+Module contents
+---------------
+
 .. automodule:: fastNLP.modules.encoder
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,5 +1,8 @@
-fastNLP.modules
-================
+fastNLP.modules package
+=======================
+
+Subpackages
+-----------
 
 .. toctree::
@@ -7,24 +10,38 @@ fastNLP.modules
     fastNLP.modules.decoder
     fastNLP.modules.encoder
 
-fastNLP.modules.dropout
------------------------
+Submodules
+----------
+
+fastNLP.modules.dropout module
+------------------------------
 
 .. automodule:: fastNLP.modules.dropout
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.other\_modules
-------------------------------
+fastNLP.modules.other\_modules module
+-------------------------------------
 
 .. automodule:: fastNLP.modules.other_modules
     :members:
+    :undoc-members:
+    :show-inheritance:
 
-fastNLP.modules.utils
----------------------
+fastNLP.modules.utils module
+----------------------------
 
 .. automodule:: fastNLP.modules.utils
     :members:
+    :undoc-members:
+    :show-inheritance:
 
+Module contents
+---------------
+
 .. automodule:: fastNLP.modules
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1,13 +1,22 @@
-fastNLP
-========
+fastNLP package
+===============
+
+Subpackages
+-----------
 
 .. toctree::
 
     fastNLP.api
+    fastNLP.automl
     fastNLP.core
     fastNLP.io
     fastNLP.models
    fastNLP.modules
 
+Module contents
+---------------
+
 .. automodule:: fastNLP
     :members:
+    :undoc-members:
+    :show-inheritance:
@@ -1 +1,4 @@
+"""
+This is the docstring for the API section.
+"""
 from .api import CWS, POS, Parser
@@ -1,3 +1,7 @@
+"""
+Documentation for API.API
+"""
+
 import warnings
 
 import torch
@@ -184,17 +188,17 @@ class CWS(API):
         """
         Takes the path of a word-segmentation file and returns the segmentation f1, precision and recall on that dataset.
         The segmentation file should be formatted as:
-        1 编者按 编者按 NN O 11 nmod:topic
-        2 : : PU O 11 punct
-        3 7月 7月 NT DATE 4 compound:nn
-        4 12日 12日 NT DATE 11 nmod:tmod
-        5 , , PU O 11 punct
-        1 这 这 DT O 3 det
-        2 款 款 M O 1 mark:clf
-        3 飞行 飞行 NN O 8 nsubj
-        4 从 从 P O 5 case
-        5 外型 外型 NN O 8 nmod:prep
+            1 编者按 编者按 NN O 11 nmod:topic
+            2 : : PU O 11 punct
+            3 7月 7月 NT DATE 4 compound:nn
+            4 12日 12日 NT DATE 11 nmod:tmod
+            5 , , PU O 11 punct
+            1 这 这 DT O 3 det
+            2 款 款 M O 1 mark:clf
+            3 飞行 飞行 NN O 8 nsubj
+            4 从 从 P O 5 case
+            5 外型 外型 NN O 8 nmod:prep
 
         Sentences are separated by a blank line; each non-empty line has 7 columns.
 
         :param filepath: str, path to the file.
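
For reference, a minimal sketch of reading the 7-column, blank-line-separated format described in the docstring above (the helper name and whitespace splitting are assumptions, not part of the diff)::

    def read_conll_like(filepath):
        sentences, current = [], []
        with open(filepath, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:                # a blank line ends a sentence
                    if current:
                        sentences.append(current)
                        current = []
                    continue
                cols = line.split()         # 7 whitespace-separated columns
                assert len(cols) == 7, "expected 7 columns, got %d" % len(cols)
                current.append(cols)
        if current:
            sentences.append(current)
        return sentences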
@@ -62,13 +62,14 @@ class ENASTrainer(fastNLP.Trainer):
         """
         :param bool load_best_model: only takes effect when dev_data was provided at initialization; if True, the trainer reloads
             the model parameters that performed best on dev before returning.
-        :return results: returns a dict with the following entries::
-            seconds: float, the training time
-            the following three entries exist only when dev_data was provided.
-            best_eval: Dict of Dict, the evaluation results
-            best_epoch: int, the epoch at which the best value was obtained
-            best_step: int, the step (batch update) at which the best value was obtained
+        :return results: returns a dict,
+            with the following entries::
+
+                seconds: float, the training time
+                the following three entries exist only when dev_data was provided.
+                best_eval: Dict of Dict, the evaluation results
+                best_epoch: int, the epoch at which the best value was obtained
+                best_step: int, the step (batch update) at which the best value was obtained
         """
 
         results = {}
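
A short usage sketch of the results dict documented above; the trainer construction is elided, and the entry names follow the docstring::

    results = trainer.train(load_best_model=True)
    print("trained for {:.1f}s".format(results['seconds']))
    if 'best_eval' in results:  # present only when dev_data was provided
        print(results['best_eval'], results['best_epoch'], results['best_step'])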
@@ -281,6 +281,7 @@ class DataSet(object):
             (2) is_target: bool, if True, the field `new_field_name` is set as target
             (3) ignore_type: bool, if True, ignore_type of the field `new_field_name` is set to true, ignoring its type
         :return: List[], whose elements are the return values of func, so the list length equals the length of the DataSet
+
         """
         assert len(self)!=0, "Null DataSet cannot use apply()."
         if field_name not in self:
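
A hedged usage sketch for DataSet.apply() as documented above; the 'words' field and the sample instances are illustrative, not from the diff::

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    ds = DataSet()
    ds.append(Instance(words=['this', 'is', 'a', 'sentence']))
    ds.append(Instance(words=['another', 'one']))
    # func receives each Instance; the returns are collected into a list and,
    # since new_field_name is given, written back as a new field
    seq_lens = ds.apply(lambda ins: len(ins['words']), new_field_name='seq_len')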
@@ -48,12 +48,16 @@ class PadderBase:
 class AutoPadder(PadderBase):
     """
     Automatically decides from the contents whether padding is needed.
-    (1) If the element type (the dtype of the innermost List elements of the field, visible via FieldArray.dtype; e.g. the element
-        type of ['This', 'is', ...] is np.str and that of [[1,2], ...] is np.int64) is not (np.int64, np.float64), no padding is done.
-    (2) If the element type is (np.int64, np.float64),
-        (2.1) if the field holds a single value per instance, e.g. sequence_length, no padding is done
-        (2.2) if the field holds a List, the Lists within a Batch are padded to the same length. If the inner Lists also need
-              padding, use another padder. A field of [1, 2, 3] in an instance can be padded; [[1,2], [3,4, ...]] cannot.
+
+    1 If the element type (the dtype of the innermost List elements of the field, visible via FieldArray.dtype; e.g. the element
+      type of ['This', 'is', ...] is np.str and that of [[1,2], ...] is np.int64) is not (np.int64, np.float64), no padding is done.
+
+    2 If the element type is (np.int64, np.float64),
+
+      2.1 if the field holds a single value per instance, e.g. sequence_length, no padding is done
+
+      2.2 if the field holds a List, the Lists within a Batch are padded to the same length. If the inner Lists also need
+          padding, use another padder. A field of [1, 2, 3] in an instance can be padded; [[1,2], [3,4, ...]] cannot.
     """
 
     def __init__(self, pad_val=0):
         """
@@ -1,13 +1,12 @@
 class Instance(object):
     """An Instance is an example of data.
-    Example::
-        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
-        ins["field_1"]
-        >>[1, 1, 1]
-        ins.add_field("field_3", [3, 3, 3])
-    :param fields: a dict of (str: list).
+    Example::
+
+        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
+        ins["field_1"]
+        >>[1, 1, 1]
+        ins.add_field("field_3", [3, 3, 3])
 
     """
 
     def __init__(self, **fields):
@@ -272,7 +272,7 @@ def squash(predict, truth, **kwargs):
     :param predict: Tensor, model output
     :param truth: Tensor, truth from dataset
-    :param **kwargs: extra arguments
+    :param kwargs: extra arguments
     :return predict , truth: predict & truth after processing
     """
     return predict.view(-1, predict.size()[-1]), truth.view(-1, )
@@ -316,7 +316,7 @@ def mask(predict, truth, **kwargs):
     :param predict: Tensor, [batch_size , max_len , tag_size]
     :param truth: Tensor, [batch_size , max_len]
-    :param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected.
+    :param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected.
     :return predict , truth: predict & truth after processing
     """
@@ -17,66 +17,72 @@ class MetricBase(object):
     """Base class for all metrics.
 
     Every Metric passed to Trainer or Tester must inherit from this class and override the evaluate() and get_metric() methods.
     evaluate(xxx) receives one batch of data.
     get_metric(xxx) is called once all data has been processed, to obtain the final metric value.
     Take the Accuracy computation in a classification problem as an example.
-    Suppose the dict returned by the model's forward contains the key 'pred', and that key is needed for Accuracy
-    class Model(nn.Module):
-        def __init__(xxx):
-            # do something
-        def forward(self, xxx):
-            # do something
-            return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes
+    Suppose the dict returned by the model's forward contains the key 'pred', and that key is needed for Accuracy::
+
+        class Model(nn.Module):
+            def __init__(xxx):
+                # do something
+            def forward(self, xxx):
+                # do something
+                return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes
+
     Suppose the field 'label' in the dataset holds the values to predict, and that field has been set as target.
-    The corresponding AccMetric can be defined as follows
-    # version1, for one-off use
-    class AccMetric(MetricBase):
-        def __init__(self):
-            super().__init__()
-            # define your own statistics as needed
-            self.corr_num = 0
-            self.total = 0
-        def evaluate(self, label, pred): # these names must match the dataset's target field and the model's output key, otherwise the values cannot be found
-            # called once per batch during dev/test; implement how the metric accumulates over batches
-            self.total += label.size(0)
-            self.corr_num += label.eq(pred).sum().item()
-        def get_metric(self, reset=True): # define how the metric is computed here
-            acc = self.corr_num/self.total
-            if reset: # whether to reset the counters for recomputation
-                self.corr_num = 0
-                self.total = 0
-            return {'acc': acc} # must return a dict; its key is the metric's name, which is shown in the Trainer's progress bar
-    # version2: if the Metric needs to be reused, e.g. next time the dataset's target field is called y rather than label, or the model's output key is not pred
-    class AccMetric(MetricBase):
-        def __init__(self, label=None, pred=None):
-            # suppose in another scenario the target field is called y and the model's key is pred_y; then simply initialize with
-            # acc_metric = AccMetric(label='y', pred='pred_y').
-            # when initialized as acc_metric = AccMetric(), i.e. label=None, pred=None, fastNLP uses 'label' and 'pred' directly as keys to fetch
-            # the corresponding values
-            super().__init__()
-            self._init_param_map(label=label, pred=pred) # registers label and pred; only the parameter names used by evaluate() need registering
-            # without this registration the behavior is identical to version1
-            # define your own statistics as needed
-            self.corr_num = 0
-            self.total = 0
-        def evaluate(self, label, pred): # the parameter names here must match those registered via self._init_param_map().
-            # called once per batch during dev/test; implement how the metric accumulates over batches
-            self.total += label.size(0)
-            self.corr_num += label.eq(pred).sum().item()
-        def get_metric(self, reset=True): # define how the metric is computed here
-            acc = self.corr_num/self.total
-            if reset: # whether to reset the counters for recomputation
-                self.corr_num = 0
-                self.total = 0
-            return {'acc': acc} # must return a dict; its key is the metric's name, which is shown in the Trainer's progress bar
+    The corresponding AccMetric can be defined as follows; version1, for one-off use::
+
+        class AccMetric(MetricBase):
+            def __init__(self):
+                super().__init__()
+                # define your own statistics as needed
+                self.corr_num = 0
+                self.total = 0
+            def evaluate(self, label, pred): # these names must match the dataset's target field and the model's output key, otherwise the values cannot be found
+                # called once per batch during dev/test; implement how the metric accumulates over batches
+                self.total += label.size(0)
+                self.corr_num += label.eq(pred).sum().item()
+            def get_metric(self, reset=True): # define how the metric is computed here
+                acc = self.corr_num/self.total
+                if reset: # whether to reset the counters for recomputation
+                    self.corr_num = 0
+                    self.total = 0
+                return {'acc': acc} # must return a dict; its key is the metric's name, which is shown in the Trainer's progress bar
+
+    version2: if the Metric needs to be reused, e.g. next time the dataset's target field is called y rather than label, or the model's output key is not pred::
+
+        class AccMetric(MetricBase):
+            def __init__(self, label=None, pred=None):
+                # suppose in another scenario the target field is called y and the model's key is pred_y; then simply initialize with
+                # acc_metric = AccMetric(label='y', pred='pred_y').
+                # when initialized as acc_metric = AccMetric(), i.e. label=None, pred=None, fastNLP uses 'label' and 'pred' directly as keys to fetch
+                # the corresponding values
+                super().__init__()
+                self._init_param_map(label=label, pred=pred) # registers label and pred; only the parameter names used by evaluate() need registering
+                # without this registration the behavior is identical to version1
+                # define your own statistics as needed
+                self.corr_num = 0
+                self.total = 0
+            def evaluate(self, label, pred): # the parameter names here must match those registered via self._init_param_map().
+                # called once per batch during dev/test; implement how the metric accumulates over batches
+                self.total += label.size(0)
+                self.corr_num += label.eq(pred).sum().item()
+            def get_metric(self, reset=True): # define how the metric is computed here
+                acc = self.corr_num/self.total
+                if reset: # whether to reset the counters for recomputation
+                    self.corr_num = 0
+                    self.total = 0
+                return {'acc': acc} # must return a dict; its key is the metric's name, which is shown in the Trainer's progress bar
 
     ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``.
@@ -84,12 +90,12 @@ class MetricBase(object):
     ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``.
     ``MetricBase`` will do the following type checks:
 
-    1. whether self.evaluate has varargs, which is not supported.
-    2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``.
-    3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``.
+        1. whether self.evaluate has varargs, which is not supported.
+        2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``.
+        3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``.
 
     Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
-    target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
+    target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering
     will be conducted.)
 
     """
@@ -388,23 +394,26 @@ class SpanFPreRecMetric(MetricBase):
     """
     Computes F, pre and rec span-wise in sequence labeling tasks.
     For example, Chinese part-of-speech tagging labels each character; the POS sequence of the sentence '中国在亚洲' might be (using BMES)
-    ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']. This metric computes F1 for cases like this.
-    The final metric result is
-    {
-        'f': xxx, # 'f' is used so that f_beta can be computed later
-        'pre': xxx,
-        'rec':xxx
-    }
-    If only_gross=False, per-label metric statistics are also returned
-    {
-        'f': xxx,
-        'pre': xxx,
-        'rec':xxx,
-        'f-label': xxx,
-        'pre-label': xxx,
-        'rec-label':xxx,
-        ...
-    }
+    ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']. This metric computes F1 for cases like this.
+    The final metric result is::
+
+        {
+            'f': xxx, # 'f' is used so that f_beta can be computed later
+            'pre': xxx,
+            'rec':xxx
+        }
+
+    If only_gross=False, per-label metric statistics are also returned::
+
+        {
+            'f': xxx,
+            'pre': xxx,
+            'rec':xxx,
+            'f-label': xxx,
+            'pre-label': xxx,
+            'rec-label':xxx,
+            ...
+        }
 
     """
 
     def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None,
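
A hedged usage sketch based on the constructor signature shown above; the vocabulary construction and field names are illustrative assumptions::

    from fastNLP.core.vocabulary import Vocabulary

    tag_vocab = Vocabulary(unknown=None, padding=None)
    tag_vocab.add_word_lst(['B-NN', 'E-NN', 'S-DET'])
    metric = SpanFPreRecMetric(tag_vocab=tag_vocab, pred='pred', target='target',
                               seq_lens='seq_lens', encoding_type='bmes')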
@@ -573,13 +582,21 @@ class BMESF1PreRecMetric(MetricBase):
     """
     Computes f1, precision and recall under the BMES tagging scheme. Since illegal tags such as "BS" may occur, the table below is used for conversion; cur_B means the current tag is B,
     and next_B means the next tag is B. cur_B=S means the tag currently predicted as B is relabeled S; next_M=B means the next tag predicted as M is relabeled B.
 
+    +-------+---------+----------+----------+---------+---------+
     |       | next_B  | next_M   | next_E   | next_S  | end     |
-    |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:|
-    | start | legal   | next_M=B | next_E=S | legal   | -       |
+    +=======+=========+==========+==========+=========+=========+
+    | start | legal   | next_M=B | next_E=S | legal   | --      |
+    +-------+---------+----------+----------+---------+---------+
     | cur_B | cur_B=S | legal    | legal    | cur_B=S | cur_B=S |
+    +-------+---------+----------+----------+---------+---------+
     | cur_M | cur_M=E | legal    | legal    | cur_M=E | cur_M=E |
+    +-------+---------+----------+----------+---------+---------+
     | cur_E | legal   | next_M=B | next_E=S | legal   | legal   |
+    +-------+---------+----------+----------+---------+---------+
     | cur_S | legal   | next_M=B | next_E=S | legal   | legal   |
+    +-------+---------+----------+----------+---------+---------+
 
     For example:
     a prediction of BSEMS is treated as SSSSS.
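
A rough sketch of my reading of the conversion table above (not the library's code); with input "BSEMS" it yields "SSSSS", matching the docstring's example::

    def repair_bmes(tags):
        tags = list(tags)
        prev = 'start'
        for i, cur in enumerate(tags):
            # next_* rules: after start, E or S, a following M becomes B
            # and a following E becomes S
            if prev in ('start', 'E', 'S'):
                if cur == 'M':
                    cur = 'B'
                elif cur == 'E':
                    cur = 'S'
            # cur_* rules: B and M are only legal when followed by M or E
            nxt = tags[i + 1] if i + 1 < len(tags) else 'end'
            if cur == 'B' and nxt not in ('M', 'E'):
                cur = 'S'
            elif cur == 'M' and nxt not in ('M', 'E'):
                cur = 'E'
            tags[i] = cur
            prev = cur
        return ''.join(tags)

    assert repair_bmes('BSEMS') == 'SSSSS'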
@@ -66,28 +66,28 @@ class Trainer(object):
             is insufficient, achieve the same effect by setting batch_size=32 and update_every=4
         """
         super(Trainer, self).__init__()
 
         if not isinstance(train_data, DataSet):
             raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
         if not isinstance(model, nn.Module):
             raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
 
         # check metrics and dev_data
         if (not metrics) and dev_data is not None:
             raise ValueError("No metric for dev_data evaluation.")
         if metrics and (dev_data is None):
             raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")
 
         # check update every
-        assert update_every>=1, "update_every must be no less than 1."
+        assert update_every >= 1, "update_every must be no less than 1."
         self.update_every = int(update_every)
 
         # check save_path
         if not (save_path is None or isinstance(save_path, str)):
             raise ValueError("save_path can only be None or `str`.")
 
         # prepare evaluate
         metrics = _prepare_metrics(metrics)
 
         # parse metric_key
         # increase_better is True. It means the exp result gets better if the indicator increases.
         # It is true by default.
@@ -97,19 +97,19 @@ class Trainer(object):
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
         elif len(metrics) > 0:
             self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
 
         # prepare loss
         losser = _prepare_losser(loss)
 
         # sampler check
         if sampler is not None and not isinstance(sampler, BaseSampler):
             raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
 
         if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         metric_key=metric_key, check_level=check_code_level,
                         batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
 
         self.train_data = train_data
         self.dev_data = dev_data  # If None, No validation.
         self.model = model
@@ -120,7 +120,7 @@ class Trainer(object):
         self.use_cuda = bool(use_cuda)
         self.save_path = save_path
         self.print_every = int(print_every)
-        self.validate_every = int(validate_every) if validate_every!=0 else -1
+        self.validate_every = int(validate_every) if validate_every != 0 else -1
         self.best_metric_indicator = None
         self.best_dev_epoch = None
         self.best_dev_step = None
@@ -129,19 +129,19 @@ class Trainer(object):
         self.prefetch = prefetch
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
         self.n_steps = (len(self.train_data) // self.batch_size + int(
-            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
+                len(self.train_data) % self.batch_size != 0)) * self.n_epochs
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
             if optimizer is None:
                 optimizer = Adam(lr=0.01, weight_decay=0)
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
 
         self.use_tqdm = use_tqdm
         self.pbar = None
         self.print_every = abs(self.print_every)
 
         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,
@@ -149,14 +149,13 @@ class Trainer(object): | |||||
batch_size=self.batch_size, | batch_size=self.batch_size, | ||||
use_cuda=self.use_cuda, | use_cuda=self.use_cuda, | ||||
verbose=0) | verbose=0) | ||||
self.step = 0 | self.step = 0 | ||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, | self.callback_manager = CallbackManager(env={"trainer": self}, | ||||
callbacks=callbacks) | callbacks=callbacks) | ||||
def train(self, load_best_model=True): | def train(self, load_best_model=True): | ||||
""" | """ | ||||
@@ -185,14 +184,15 @@ class Trainer(object):
         Runs evaluation according to metrics, and decides from save_path whether to save the model.
 
         :param bool load_best_model: only takes effect when dev_data was provided at initialization; if True, the trainer reloads
-            the model parameters that performed best on dev before returning.
-        :return results: returns a dict with the following entries::
-            seconds: float, the training time
-            the following three entries exist only when dev_data was provided.
-            best_eval: Dict of Dict, the evaluation results
-            best_epoch: int, the epoch at which the best value was obtained
-            best_step: int, the step (batch update) at which the best value was obtained
+                the model parameters that performed best on dev before returning.
+        :return results: returns a dict,
+            with the following entries::
+
+                seconds: float, the training time
+                the following three entries exist only when dev_data was provided.
+                best_eval: Dict of Dict, the evaluation results
+                best_epoch: int, the epoch at which the best value was obtained
+                best_step: int, the step (batch update) at which the best value was obtained
         """
 
         results = {}
@@ -205,21 +205,22 @@ class Trainer(object):
             self.model = self.model.cuda()
         self._model_device = self.model.parameters().__next__().device
 
         self._mode(self.model, is_test=False)
 
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
         start_time = time.time()
         print("training epochs started " + self.start_time, flush=True)
 
         try:
             self.callback_manager.on_train_begin()
             self._train()
             self.callback_manager.on_train_end()
         except (CallbackException, KeyboardInterrupt) as e:
             self.callback_manager.on_exception(e)
 
         if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
-            print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
-                  self.tester._format_eval_results(self.best_dev_perf),)
+            print(
+                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
+                self.tester._format_eval_results(self.best_dev_perf), )
             results['best_eval'] = self.best_dev_perf
             results['best_epoch'] = self.best_dev_epoch
             results['best_step'] = self.best_dev_step
@@ -233,9 +234,9 @@ class Trainer(object):
         finally:
             pass
         results['seconds'] = round(time.time() - start_time, 2)
 
         return results
 
     def _train(self):
         if not self.use_tqdm:
             from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
@@ -244,13 +245,13 @@ class Trainer(object):
         self.step = 0
         self.epoch = 0
         start = time.time()
 
         with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             self.pbar = pbar if isinstance(pbar, tqdm) else None
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
-            for epoch in range(1, self.n_epochs+1):
+            for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -262,22 +263,22 @@ class Trainer(object):
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                     prediction = self._data_forward(self.model, batch_x)
 
                     # edit prediction
                     self.callback_manager.on_loss_begin(batch_y, prediction)
                     loss = self._compute_loss(prediction, batch_y).mean()
                     avg_loss += loss.item()
-                    loss = loss/self.update_every
+                    loss = loss / self.update_every
 
                     # Is loss NaN or inf? requires_grad = False
                     self.callback_manager.on_backward_begin(loss)
                     self._grad_backward(loss)
                     self.callback_manager.on_backward_end()
 
                     self._update()
                     self.callback_manager.on_step_end()
 
-                    if (self.step+1) % self.print_every == 0:
+                    if (self.step + 1) % self.print_every == 0:
                         avg_loss = avg_loss / self.print_every
                         if self.use_tqdm:
                             print_output = "loss:{0:<6.5f}".format(avg_loss)
@@ -290,34 +291,34 @@ class Trainer(object):
                             pbar.set_postfix_str(print_output)
                         avg_loss = 0
                     self.callback_manager.on_batch_end()
 
                     if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                         (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                             and self.dev_data is not None:
                         eval_res = self._do_validation(epoch=epoch, step=self.step)
                         eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                     self.n_steps) + \
-                                   self.tester._format_eval_results(eval_res)
+                            self.tester._format_eval_results(eval_res)
                         pbar.write(eval_str + '\n')
 
                 # ================= mini-batch end ==================== #
 
             # lr decay; early stopping
             self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
         pbar.close()
         self.pbar = None
     # ============ tqdm end ============== #
 
     def _do_validation(self, epoch, step):
         self.callback_manager.on_valid_begin()
         res = self.tester.test()
 
         is_better_eval = False
         if self._better_eval_result(res):
             if self.save_path is not None:
                 self._save_model(self.model,
-                                 "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
+                        "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
             else:
                 self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()}
             self.best_dev_perf = res
@@ -327,7 +328,7 @@ class Trainer(object):
         # get validation results; adjust optimizer
         self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
         return res
 
     def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -339,21 +340,21 @@ class Trainer(object):
             model.eval()
         else:
             model.train()
 
     def _update(self):
         """Perform weight update on a model.
 
         """
-        if (self.step+1)%self.update_every==0:
+        if (self.step + 1) % self.update_every == 0:
             self.optimizer.step()
 
     def _data_forward(self, network, x):
         x = _build_args(network.forward, **x)
         y = network(**x)
         if not isinstance(y, dict):
             raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y
 
     def _grad_backward(self, loss):
         """Compute gradient with link rules.
@@ -361,10 +362,10 @@ class Trainer(object): | |||||
For PyTorch, just do "loss.backward()" | For PyTorch, just do "loss.backward()" | ||||
""" | """ | ||||
if self.step%self.update_every==0: | |||||
if self.step % self.update_every == 0: | |||||
self.model.zero_grad() | self.model.zero_grad() | ||||
loss.backward() | loss.backward() | ||||
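
Together, `_grad_backward` and `_update` implement gradient accumulation: gradients are zeroed
only at the start of every `update_every`-batch window and the optimizer steps only at its end,
so several small batches act like one larger batch. A self-contained sketch of the same pattern,
assuming a standard PyTorch `model`, `optimizer`, `criterion`, and an iterable `batches`::

    update_every = 4
    for step, (x, y) in enumerate(batches):
        if step % update_every == 0:        # first batch of the window
            model.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()                     # gradients accumulate across batches
        if (step + 1) % update_every == 0:  # last batch of the window
            optimizer.step()
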
     def _compute_loss(self, predict, truth):
         """Compute loss given prediction and ground truth.

@@ -373,7 +374,7 @@ class Trainer(object):
         :return: a scalar
         """
         return self.losser(predict, truth)
     def _save_model(self, model, model_name, only_param=False):
         """Save a state_dict or model that carries no GPU/device information.

         :param model:

@@ -394,7 +395,7 @@ class Trainer(object):
             model.cpu()
             torch.save(model, model_path)
             model.to(self._model_device)

     def _load_model(self, model, model_name, only_param=False):
         # returns a bool indicating whether the model was reloaded successfully
         if self.save_path is not None:

@@ -409,7 +410,7 @@ class Trainer(object):
         else:
             return False
         return True
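
The save path above moves the model to CPU before `torch.save` and back afterwards, so the
checkpoint carries no CUDA device references and loads on CPU-only machines. A minimal sketch of
that pattern, assuming a hypothetical `model` and `path`::

    import torch

    def save_cpu_checkpoint(model, path, only_param=False):
        device = next(model.parameters()).device
        model.cpu()                                # strip device info from the tensors
        if only_param:
            torch.save(model.state_dict(), path)   # weights only
        else:
            torch.save(model, path)                # the pickled module
        model.to(device)                           # restore the original device
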
     def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.

@@ -437,6 +438,7 @@ class Trainer(object):
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2

 def _get_value_info(_dict):
     # given a dict, return a list of strings describing each of its values
     strs = []

@@ -453,27 +455,28 @@ def _get_value_info(_dict):
         strs.append(_str)
     return strs
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None,
                 check_level=0):
     # check the get_loss method
     model_device = model.parameters().__next__().device
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
         _move_dict_value_to_device(batch_x, batch_y, device=model_device)
         # forward check
-        if batch_count==0:
+        if batch_count == 0:
             info_str = ""
             input_fields = _get_value_info(batch_x)
             target_fields = _get_value_info(batch_y)
-            if len(input_fields)>0:
+            if len(input_fields) > 0:
                 info_str += "input fields after batch(if batch size is {}):\n".format(batch_size)
                 info_str += "\n".join(input_fields)
                 info_str += '\n'
             else:
                 raise RuntimeError("There is no input field.")
-            if len(target_fields)>0:
+            if len(target_fields) > 0:
                 info_str += "target fields after batch(if batch size is {}):\n".format(batch_size)
                 info_str += "\n".join(target_fields)
                 info_str += '\n'

@@ -481,14 +484,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
                 info_str += 'There is no target field.'
             print(info_str)
             _check_forward_error(forward_func=model.forward, dataset=dataset,
-                                batch_x=batch_x, check_level=check_level)
+                                 batch_x=batch_x, check_level=check_level)

         refined_batch_x = _build_args(model.forward, **batch_x)
         pred_dict = model(**refined_batch_x)
         func_signature = get_func_signature(model.forward)
         if not isinstance(pred_dict, dict):
             raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")

         # loss check
         try:
             loss = losser(pred_dict, batch_y)

@@ -512,7 +515,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         model.zero_grad()
         if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break

     if dev_data is not None:
         tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)

@@ -526,7 +529,7 @@ def _check_eval_results(metrics, metric_key, metric_list):
     # metric_list: the metrics used for evaluation, passed in at the Trainer's initialization
     if isinstance(metrics, tuple):
         loss, metrics = metrics

     if isinstance(metrics, dict):
         if len(metrics) == 1:
             # only single metric, just use it

@@ -537,7 +540,7 @@ def _check_eval_results(metrics, metric_key, metric_list):
         if metrics_name not in metrics:
             raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
         metric_dict = metrics[metrics_name]

         if len(metric_dict) == 1:
             indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
         elif len(metric_dict) > 1 and metric_key is None:
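
`_check_eval_results` reduces the nested evaluation dict down to one scalar that model selection
can compare. An illustration with hypothetical metric names and values, shaped like
``{metric_name: {indicator_name: value}}``::

    res = {"AccuracyMetric": {"acc": 0.91, "f1": 0.88}}
    metric_key = "f1"                       # the indicator the Trainer compares on

    metric_dict = res["AccuracyMetric"]
    if len(metric_dict) == 1:               # only one indicator: just use it
        (indicator, indicator_val), = metric_dict.items()
    else:                                   # otherwise metric_key must pick one
        indicator, indicator_val = metric_key, metric_dict[metric_key]
    print(indicator, indicator_val)         # f1 0.88
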
@@ -197,17 +197,22 @@ def get_func_signature(func):
     Given a function or method, return its signature.

     For example:

-    (1) function
+    1 function::

         def func(a, b='a', *args):
             xxxx
         get_func_signature(func) # 'func(a, b='a', *args)'

-    (2) method
+    2 method::

         class Demo:
             def __init__(self):
                 xxx
             def forward(self, a, b='a', **args):
         demo = Demo()
         get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'

     :param func: a function or a method
     :return: str or None
     """
@@ -26,10 +26,10 @@ class ConfigLoader(BaseLoader):
         :param str file_path: the path of config file
         :param dict sections: the dict of ``{section_name(string): ConfigSection object}``

-        Example::
-
-            test_args = ConfigSection()
-            ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
+        Example::
+
+            test_args = ConfigSection()
+            ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
         """
         assert isinstance(sections, dict)
@@ -9,7 +9,7 @@ from fastNLP.io.base_loader import DataLoaderRegister

 def convert_seq_dataset(data):
     """Create a DataSet instance that contains no labels.

-    :param data: list of list of strings, [num_examples, *].
+    :param data: list of list of strings, [num_examples, \*].

         Example::

            [

@@ -28,7 +28,7 @@ def convert_seq_dataset(data):
 def convert_seq2tag_dataset(data):
     """Convert list of data into DataSet.

-    :param data: list of list of strings, [num_examples, *].
+    :param data: list of list of strings, [num_examples, \*].

         Example::

            [

@@ -48,7 +48,7 @@ def convert_seq2tag_dataset(data):
 def convert_seq2seq_dataset(data):
     """Convert list of data into DataSet.

-    :param data: list of list of strings, [num_examples, *].
+    :param data: list of list of strings, [num_examples, \*].

         Example::

            [
@@ -177,18 +177,18 @@ DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')

 class DummyPOSReader(DataSetLoader):
     """A simple reader for a dummy POS tagging dataset.

-    In these datasets, each line is divided by "\t". The first column is the word and the second
+    In these datasets, each line is divided by "\\t". The first column is the word and the second
     column is the label. Different sentences are divided by an empty line.

-    E.g::
+    E.g::

-    Tom label1
-    and label2
-    Jerry label1
-    . label3
-    (separated by an empty line)
-    Hello label4
-    world label5
-    ! label3
+        Tom label1
+        and label2
+        Jerry label1
+        . label3
+        (separated by an empty line)
+        Hello label4
+        world label5
+        ! label3

     In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label.
     """

@@ -200,11 +200,13 @@ class DummyPOSReader(DataSetLoader):
         """
         :return data: three-level list

         Example::

             [
                 [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                 [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                 ...
             ]
         """
         with open(data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
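
A sketch of how such a file can be parsed into the three-level list described above, assuming the
tab-separated format from the class docstring (a hypothetical helper, not the loader's actual code)::

    def read_dummy_pos(path):
        data, words, labels = [], [], []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:                  # an empty line ends a sentence
                    if words:
                        data.append([words, labels])
                        words, labels = [], []
                    continue
                word, label = line.split("\t")
                words.append(word)
                labels.append(label)
        if words:                             # file may not end with an empty line
            data.append([words, labels])
        return data
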
@@ -550,6 +552,7 @@ class SNLIDataSetReader(DataSetLoader):
         :param data: A 3D tensor.

         Example::

             [
                 [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
                 [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
@@ -647,7 +650,7 @@ class NaiveCWSReader(DataSetLoader):
     For example::

         这是 fastNLP , 一个 非常 good 的 包 .

     Or, each part may additionally carry a pos tag, for example::

@@ -661,12 +664,15 @@ class NaiveCWSReader(DataSetLoader):
     def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
         """
-        The supported formats are (by default "\t" or spaces are taken as the separator)
+        The supported formats are (by default "\\t" or spaces are taken as the separator)::

             这是 fastNLP , 一个 非常 good 的 包 .

-        and
+        and::

             也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY

         If splitter is not None, the second format is assumed; we split each token such as "也/D"
         by the splitter and keep the first part, e.g. "也/D".split('/')[0].

         :param filepath:
         :param in_word_splitter:
         :param cut_long_sent:
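
The splitter behaviour described above amounts to keeping only the word part of each
"word/tag" token. A short sketch using the sample line from the docstring::

    def strip_pos_tags(tokens, splitter='/'):
        return [tok.split(splitter)[0] for tok in tokens]

    line = "也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY"
    print(strip_pos_tags(line.split()))  # ['也', '在', '團員', '之中', ',']
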
@@ -737,11 +743,12 @@ class ZhConllPOSReader(object):
     def load(self, path):
         """
-        The returned DataSet contains the following fields
+        The returned DataSet contains the following fields::

             words: list of str,
             tag: list of str, with BMES tags added; e.g. an original sequence ['VP', 'NN', 'NN', ...]
                  becomes ["S-VP", "B-NN", "M-NN", ...]

-        The input is assumed to be in CoNLL format, with sentences separated by empty lines and 7 columns per line, i.e.
-        ::
+        The input is assumed to be in CoNLL format, with sentences separated by empty lines and 7 columns per line, i.e.::

             1	编者按	编者按	NN	O	11	nmod:topic
             2	:	:	PU	O	11	punct
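
The BMES scheme marks each element of a span as Begin/Middle/End, with S for single-element
spans. A sketch that reproduces the tagging described above by grouping runs of identical tags
(a hypothetical helper, not the reader's actual code)::

    from itertools import groupby

    def add_bmes(tags):
        out = []
        for tag, run in groupby(tags):
            run = list(run)
            if len(run) == 1:
                out.append("S-" + tag)
            else:
                out.append("B-" + tag)
                out.extend("M-" + tag for _ in run[1:-1])
                out.append("E-" + tag)
        return out

    print(add_bmes(['VP', 'NN', 'NN', 'NN']))  # ['S-VP', 'B-NN', 'M-NN', 'E-NN']
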
@@ -132,7 +132,7 @@ class EmbedLoader(BaseLoader):
     def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
         """
         Load the pretrained embedding in {embed_file} based on the words in vocab. Words in vocab but not in the pretrained
-        embedding are initialized from a normal distribution which has the mean and std of the found words vectors.
+        embedding are initialized from a normal distribution which has the mean and std of the found word vectors.
         The embedding type is determined automatically; glove and word2vec (whose first line only has two elements) are supported.

         :param embed_filepath: str, where to read the pretrained embedding
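
The fallback initialization mentioned above can be sketched as follows: rows found in the
pretrained file keep their vectors, and the remaining rows are drawn from a normal distribution
fitted to the found rows (a hypothetical helper with toy data, not EmbedLoader's actual code)::

    import numpy as np

    def init_missing_rows(matrix, found_mask):
        found = matrix[found_mask]
        mean, std = found.mean(), found.std()
        missing = ~found_mask
        matrix[missing] = np.random.normal(mean, std, size=(missing.sum(), matrix.shape[1]))
        return matrix

    emb = np.zeros((4, 3), dtype=np.float32)
    emb[0] = [0.1, 0.2, 0.3]   # pretend rows 0 and 2 were found in the file
    emb[2] = [0.3, 0.1, 0.2]
    print(init_missing_rows(emb, np.array([True, False, True, False])))
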
@@ -31,16 +31,18 @@ class ModelLoader(BaseLoader):

 class ModelSaver(object):
     """Save a model

-    Example::
+    :param str save_path: the path to the saving directory.
+
+    Example::

-        saver = ModelSaver("./save/model_ckpt_100.pkl")
-        saver.save_pytorch(model)
+        saver = ModelSaver("./save/model_ckpt_100.pkl")
+        saver.save_pytorch(model)
     """

     def __init__(self, save_path):
-        """
-        :param save_path: the path to the saving directory.
-        """
         self.save_path = save_path

     def save_pytorch(self, model, param_only=True):
@@ -20,16 +20,23 @@ class Highway(nn.Module):

 class CharLM(nn.Module):
     """CNN + highway network + LSTM

-    # Input:
+    # Input::

         4D tensor with shape [batch_size, in_channel, height, width]

-    # Output:
+    # Output::

         2D Tensor with shape [batch_size, vocab_size]

-    # Arguments:
+    # Arguments::

         char_emb_dim: the size of each character's embedding
         word_emb_dim: the size of each word's embedding
         vocab_size: num of unique words
         num_char: num of characters
         use_gpu: True or False
     """

     def __init__(self, char_emb_dim, word_emb_dim,
@@ -65,13 +65,14 @@ class ENASTrainer(fastNLP.Trainer):
         """
         :param bool load_best_model: effective only when dev_data was provided at initialization;
             if True, the trainer reloads the best-performing model parameters on dev before returning.
-        :return results: a dict of data with the following contents::
-
-            seconds: float, the training time
-            the following three entries are present only when dev_data was provided:
-            best_eval: Dict of Dict, the evaluation results
-            best_epoch: int, the epoch in which the best value was reached
-            best_step: int, the step (batch) at which the best value was reached
+        :return results: a dict,
+            with the following contents::
+
+                seconds: float, the training time
+                the following three entries are present only when dev_data was provided:
+                best_eval: Dict of Dict, the evaluation results
+                best_epoch: int, the epoch in which the best value was reached
+                best_step: int, the step (batch) at which the best value was reached
         """
         results = {}
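
A hedged sketch of consuming the returned dict, assuming a constructed `trainer`; the key names
follow the docstring above::

    results = trainer.train(load_best_model=True)
    print("training took {:.1f}s".format(results["seconds"]))
    if "best_eval" in results:  # present only when dev_data was given
        print("best dev result:", results["best_eval"],
              "at epoch", results["best_epoch"], "step", results["best_step"])
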
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder as Decoder
@@ -40,7 +39,7 @@ class ESIM(BaseModel):
             batch_first=self.batch_first, bidirectional=True
         )

-        self.bi_attention = Aggregator.Bi_Attention()
+        self.bi_attention = Aggregator.BiAttention()
         self.mean_pooling = Aggregator.MeanPoolWithMask()
         self.max_pooling = Aggregator.MaxPoolWithMask()

@@ -53,23 +52,23 @@ class ESIM(BaseModel):
         self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)

-    def forward(self, premise, hypothesis, premise_len, hypothesis_len):
+    def forward(self, words1, words2, seq_len1, seq_len2):
         """ Forward function

-        :param premise: A Tensor representing the premise: [batch size(B), premise seq len(PL)].
-        :param hypothesis: A Tensor representing the hypothesis: [B, hypothesis seq len(HL)].
-        :param premise_len: A Tensor recording which positions are real words and which are padding in the premise: [B, PL].
-        :param hypothesis_len: A Tensor recording which positions are real words and which are padding in the hypothesis: [B, HL].
+        :param words1: A Tensor representing the premise: [batch size(B), premise seq len(PL)].
+        :param words2: A Tensor representing the hypothesis: [B, hypothesis seq len(HL)].
+        :param seq_len1: A Tensor recording which positions are real words and which are padding in the premise: [B].
+        :param seq_len2: A Tensor recording which positions are real words and which are padding in the hypothesis: [B].
         :return: prediction: A Dict with a Tensor of the classification result: [B, n_labels(N)].
         """
-        premise0 = self.embedding_layer(self.embedding(premise))
-        hypothesis0 = self.embedding_layer(self.embedding(hypothesis))
+        premise0 = self.embedding_layer(self.embedding(words1))
+        hypothesis0 = self.embedding_layer(self.embedding(words2))

         _BP, _PSL, _HP = premise0.size()
         _BH, _HSL, _HH = hypothesis0.size()
-        _BPL, _PLL = premise_len.size()
-        _HPL, _HLL = hypothesis_len.size()
+        _BPL, _PLL = seq_len1.size()
+        _HPL, _HLL = seq_len2.size()

         assert _BP == _BH and _BPL == _HPL and _BP == _BPL
         assert _HP == _HH

@@ -84,7 +83,7 @@ class ESIM(BaseModel):
         a = torch.mean(a0.view(B, PL, -1, H), dim=2)  # a: [B, PL, H]
         b = torch.mean(b0.view(B, HL, -1, H), dim=2)  # b: [B, HL, H]

-        ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len)
+        ai, bi = self.bi_attention(a, b, seq_len1, seq_len2)

         ma = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 4 * H]
         mb = torch.cat((b, bi, b - bi, b * bi), dim=2)  # mb: [B, HL, 4 * H]

@@ -98,17 +97,18 @@ class ESIM(BaseModel):
         va = torch.mean(vat.view(B, PL, -1, H), dim=2)  # va: [B, PL, H]
         vb = torch.mean(vbt.view(B, HL, -1, H), dim=2)  # vb: [B, HL, H]

-        va_ave = self.mean_pooling(va, premise_len, dim=1)  # va_ave: [B, H]
-        va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1)  # va_max: [B, H]
-        vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1)  # vb_ave: [B, H]
-        vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1)  # vb_max: [B, H]
+        va_ave = self.mean_pooling(va, seq_len1, dim=1)  # va_ave: [B, H]
+        va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1)  # va_max: [B, H]
+        vb_ave = self.mean_pooling(vb, seq_len2, dim=1)  # vb_ave: [B, H]
+        vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1)  # vb_max: [B, H]

         v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1)  # v: [B, 4 * H]
-        prediction = F.tanh(self.output(v))  # prediction: [B, N]
+        prediction = torch.tanh(self.output(v))  # prediction: [B, N]
         return {'pred': prediction}

-    def predict(self, premise, hypothesis, premise_len, hypothesis_len):
-        return self.forward(premise, hypothesis, premise_len, hypothesis_len)
+    def predict(self, words1, words2, seq_len1, seq_len2):
+        prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
+        return torch.argmax(prediction, dim=-1)
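
After this change, `predict` returns label indices while `forward` still returns the raw score
dict. A usage sketch with hypothetical tensors (an ESIM instance `model` is assumed to be
constructed elsewhere)::

    import torch

    words1 = torch.randint(0, 100, (2, 5))   # batch of 2, premise length 5
    words2 = torch.randint(0, 100, (2, 7))   # batch of 2, hypothesis length 7
    seq_len1 = torch.ones(2, 5)              # 0/1 mask marking real premise tokens
    seq_len2 = torch.ones(2, 7)

    scores = model(words1, words2, seq_len1, seq_len2)['pred']  # [2, n_labels]
    labels = model.predict(words1, words2, seq_len1, seq_len2)  # [2] label indices
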
@@ -5,6 +5,6 @@ from .avg_pool import MeanPoolWithMask
 from .kmax_pool import KMaxPool

 from .attention import Attention
-from .attention import Bi_Attention
+from .attention import BiAttention
 from .self_attention import SelfAttention
@@ -23,9 +23,9 @@ class Attention(torch.nn.Module):
         raise NotImplementedError

-class DotAtte(nn.Module):
+class DotAttention(nn.Module):
     def __init__(self, key_size, value_size, dropout=0.1):
-        super(DotAtte, self).__init__()
+        super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
         self.scale = math.sqrt(key_size)

@@ -48,7 +48,7 @@ class DotAtte(nn.Module):
         return torch.matmul(output, V)
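
`DotAttention.forward` computes the standard scaled dot-product attention,
softmax(QK^T / sqrt(d_k)) V. A self-contained sketch of the same computation (the module above
additionally applies dropout to the attention weights, omitted here)::

    import math
    import torch
    import torch.nn.functional as F

    def scaled_dot_attention(Q, K, V):
        scale = math.sqrt(K.size(-1))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale  # [..., q_len, k_len]
        weights = F.softmax(scores, dim=-1)                    # attention weights
        return torch.matmul(weights, V)                        # weighted sum of values

    Q, K, V = torch.randn(2, 4, 8), torch.randn(2, 6, 8), torch.randn(2, 6, 8)
    print(scaled_dot_attention(Q, K, V).shape)  # torch.Size([2, 4, 8])
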
-class MultiHeadAtte(nn.Module):
+class MultiHeadAttention(nn.Module):
     def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
         """

@@ -58,7 +58,7 @@ class MultiHeadAtte(nn.Module):
         :param num_head: int, the number of heads.
         :param dropout: float.
         """
-        super(MultiHeadAtte, self).__init__()
+        super(MultiHeadAttention, self).__init__()
         self.input_size = input_size
         self.key_size = key_size
         self.value_size = value_size

@@ -68,7 +68,7 @@ class MultiHeadAtte(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        self.attention = DotAtte(key_size=key_size, value_size=value_size)
+        self.attention = DotAttention(key_size=key_size, value_size=value_size)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.drop = TimestepDropout(dropout)
         self.reset_parameters()
@@ -109,16 +109,30 @@ class MultiHeadAtte(nn.Module):
         return output

-class Bi_Attention(nn.Module):
+class BiAttention(nn.Module):
+    """Bi-attention module.
+
+    Calculates the bi-attention matrix :math:`e`, whose entries are dot products of the two
+    inputs' representations:
+
+    .. math::
+
+        e_{ij} = {a}_{i}^{\mathbf{T}} {b}_{j}
+
+    where :math:`a_i` and :math:`b_j` are the hidden representations at position i of the first
+    input and position j of the second input.
+    """

     def __init__(self):
-        super(Bi_Attention, self).__init__()
+        super(BiAttention, self).__init__()
         self.inf = 10e12

     def forward(self, in_x1, in_x2, x1_len, x2_len):
-        # in_x1: [batch_size, x1_seq_len, hidden_size]
-        # in_x2: [batch_size, x2_seq_len, hidden_size]
-        # x1_len: [batch_size, x1_seq_len]
-        # x2_len: [batch_size, x2_seq_len]
+        """
+        :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] representation of the first sentence
+        :param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] representation of the second sentence
+        :param torch.Tensor x1_len: [batch_size, x1_seq_len] 0/1 mask matrix of the first sentence
+        :param torch.Tensor x2_len: [batch_size, x2_seq_len] 0/1 mask matrix of the second sentence
+        :return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] the first sentence's representation after attending to the second
+                 torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] the second sentence's representation after attending to the first
+        """
         assert in_x1.size()[0] == in_x2.size()[0]
         assert in_x1.size()[2] == in_x2.size()[2]
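
A sketch of the masked bi-attention the docstring describes: score every pair of positions, push
padded positions to a large negative value before the softmax, then attend in both directions
(a hypothetical helper mirroring the module's shapes, not its exact code)::

    import torch
    import torch.nn.functional as F

    def bi_attention(in_x1, in_x2, mask1, mask2, inf=1e12):
        e = torch.bmm(in_x1, in_x2.transpose(1, 2))         # [B, len1, len2]
        e1 = e.masked_fill(mask2.unsqueeze(1) == 0, -inf)   # hide padded positions of x2
        e2 = e.masked_fill(mask1.unsqueeze(2) == 0, -inf)   # hide padded positions of x1
        out_x1 = torch.bmm(F.softmax(e1, dim=2), in_x2)     # [B, len1, hidden]
        out_x2 = torch.bmm(F.softmax(e2, dim=1).transpose(1, 2), in_x1)  # [B, len2, hidden]
        return out_x1, out_x2
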
@@ -1,6 +1,6 @@
 from torch import nn

-from ..aggregator.attention import MultiHeadAtte
+from ..aggregator.attention import MultiHeadAttention
 from ..dropout import TimestepDropout

@@ -18,7 +18,7 @@ class TransformerEncoder(nn.Module):
     class SubLayer(nn.Module):
         def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
             super(TransformerEncoder.SubLayer, self).__init__()
-            self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
+            self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout)
             self.norm1 = nn.LayerNorm(model_size)
             self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),