@@ -0,0 +1,5 @@ | |||||
include requirements.txt | |||||
include LICENSE | |||||
include README.md | |||||
prune test/ | |||||
prune reproduction/ |
@@ -3,6 +3,7 @@ | |||||
# You can set these variables from the command line. | # You can set these variables from the command line. | ||||
SPHINXOPTS = | SPHINXOPTS = | ||||
SPHINXAPIDOC = sphinx-apidoc | |||||
SPHINXBUILD = sphinx-build | SPHINXBUILD = sphinx-build | ||||
SPHINXPROJ = fastNLP | SPHINXPROJ = fastNLP | ||||
SOURCEDIR = source | SOURCEDIR = source | ||||
@@ -12,6 +13,12 @@ BUILDDIR = build | |||||
help: | help: | ||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | ||||
apidoc: | |||||
@$(SPHINXAPIDOC) -f -o source ../fastNLP | |||||
server: | |||||
cd build/html && python -m http.server | |||||
.PHONY: help Makefile | .PHONY: help Makefile | ||||
# Catch-all target: route all unknown targets to Sphinx using the new | # Catch-all target: route all unknown targets to Sphinx using the new | ||||
@@ -23,9 +23,9 @@ copyright = '2018, xpqiu' | |||||
author = 'xpqiu' | author = 'xpqiu' | ||||
# The short X.Y version | # The short X.Y version | ||||
version = '0.2' | |||||
version = '0.4' | |||||
# The full version, including alpha/beta/rc tags | # The full version, including alpha/beta/rc tags | ||||
release = '0.2' | |||||
release = '0.4' | |||||
# -- General configuration --------------------------------------------------- | # -- General configuration --------------------------------------------------- | ||||
@@ -67,7 +67,7 @@ language = None | |||||
# List of patterns, relative to source directory, that match files and | # List of patterns, relative to source directory, that match files and | ||||
# directories to ignore when looking for source files. | # directories to ignore when looking for source files. | ||||
# This pattern also affects html_static_path and html_extra_path . | # This pattern also affects html_static_path and html_extra_path . | ||||
exclude_patterns = [] | |||||
exclude_patterns = ['modules.rst'] | |||||
# The name of the Pygments (syntax highlighting) style to use. | # The name of the Pygments (syntax highlighting) style to use. | ||||
pygments_style = 'sphinx' | pygments_style = 'sphinx' | ||||
@@ -1,36 +1,62 @@ | |||||
fastNLP.api | |||||
============ | |||||
fastNLP.api package | |||||
=================== | |||||
fastNLP.api.api | |||||
---------------- | |||||
Submodules | |||||
---------- | |||||
fastNLP.api.api module | |||||
---------------------- | |||||
.. automodule:: fastNLP.api.api | .. automodule:: fastNLP.api.api | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.api.converter | |||||
---------------------- | |||||
fastNLP.api.converter module | |||||
---------------------------- | |||||
.. automodule:: fastNLP.api.converter | .. automodule:: fastNLP.api.converter | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.api.model\_zoo | |||||
----------------------- | |||||
fastNLP.api.examples module | |||||
--------------------------- | |||||
.. automodule:: fastNLP.api.model_zoo | |||||
.. automodule:: fastNLP.api.examples | |||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.api.pipeline | |||||
--------------------- | |||||
fastNLP.api.pipeline module | |||||
--------------------------- | |||||
.. automodule:: fastNLP.api.pipeline | .. automodule:: fastNLP.api.pipeline | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.api.processor | |||||
---------------------- | |||||
fastNLP.api.processor module | |||||
---------------------------- | |||||
.. automodule:: fastNLP.api.processor | .. automodule:: fastNLP.api.processor | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.api.utils module | |||||
------------------------ | |||||
.. automodule:: fastNLP.api.utils | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.api | .. automodule:: fastNLP.api | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,84 +1,126 @@ | |||||
fastNLP.core | |||||
============= | |||||
fastNLP.core package | |||||
==================== | |||||
fastNLP.core.batch | |||||
------------------- | |||||
Submodules | |||||
---------- | |||||
fastNLP.core.batch module | |||||
------------------------- | |||||
.. automodule:: fastNLP.core.batch | .. automodule:: fastNLP.core.batch | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.callback module | |||||
---------------------------- | |||||
fastNLP.core.dataset | |||||
--------------------- | |||||
.. automodule:: fastNLP.core.callback | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.dataset module | |||||
--------------------------- | |||||
.. automodule:: fastNLP.core.dataset | .. automodule:: fastNLP.core.dataset | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.fieldarray | |||||
------------------------ | |||||
fastNLP.core.fieldarray module | |||||
------------------------------ | |||||
.. automodule:: fastNLP.core.fieldarray | .. automodule:: fastNLP.core.fieldarray | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.instance | |||||
---------------------- | |||||
fastNLP.core.instance module | |||||
---------------------------- | |||||
.. automodule:: fastNLP.core.instance | .. automodule:: fastNLP.core.instance | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.losses | |||||
-------------------- | |||||
fastNLP.core.losses module | |||||
-------------------------- | |||||
.. automodule:: fastNLP.core.losses | .. automodule:: fastNLP.core.losses | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.metrics | |||||
--------------------- | |||||
fastNLP.core.metrics module | |||||
--------------------------- | |||||
.. automodule:: fastNLP.core.metrics | .. automodule:: fastNLP.core.metrics | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.optimizer | |||||
----------------------- | |||||
fastNLP.core.optimizer module | |||||
----------------------------- | |||||
.. automodule:: fastNLP.core.optimizer | .. automodule:: fastNLP.core.optimizer | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.predictor | |||||
----------------------- | |||||
fastNLP.core.predictor module | |||||
----------------------------- | |||||
.. automodule:: fastNLP.core.predictor | .. automodule:: fastNLP.core.predictor | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.sampler | |||||
--------------------- | |||||
fastNLP.core.sampler module | |||||
--------------------------- | |||||
.. automodule:: fastNLP.core.sampler | .. automodule:: fastNLP.core.sampler | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.tester | |||||
-------------------- | |||||
fastNLP.core.tester module | |||||
-------------------------- | |||||
.. automodule:: fastNLP.core.tester | .. automodule:: fastNLP.core.tester | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.trainer | |||||
--------------------- | |||||
fastNLP.core.trainer module | |||||
--------------------------- | |||||
.. automodule:: fastNLP.core.trainer | .. automodule:: fastNLP.core.trainer | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.utils | |||||
------------------- | |||||
fastNLP.core.utils module | |||||
------------------------- | |||||
.. automodule:: fastNLP.core.utils | .. automodule:: fastNLP.core.utils | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.core.vocabulary | |||||
------------------------ | |||||
fastNLP.core.vocabulary module | |||||
------------------------------ | |||||
.. automodule:: fastNLP.core.vocabulary | .. automodule:: fastNLP.core.vocabulary | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.core | .. automodule:: fastNLP.core | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,42 +1,54 @@ | |||||
fastNLP.io | |||||
=========== | |||||
fastNLP.io package | |||||
================== | |||||
fastNLP.io.base\_loader | |||||
------------------------ | |||||
Submodules | |||||
---------- | |||||
fastNLP.io.base\_loader module | |||||
------------------------------ | |||||
.. automodule:: fastNLP.io.base_loader | .. automodule:: fastNLP.io.base_loader | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.io.config\_io | |||||
---------------------- | |||||
fastNLP.io.config\_io module | |||||
---------------------------- | |||||
.. automodule:: fastNLP.io.config_io | .. automodule:: fastNLP.io.config_io | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.io.dataset\_loader | |||||
--------------------------- | |||||
fastNLP.io.dataset\_loader module | |||||
--------------------------------- | |||||
.. automodule:: fastNLP.io.dataset_loader | .. automodule:: fastNLP.io.dataset_loader | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.io.embed\_loader | |||||
------------------------- | |||||
fastNLP.io.embed\_loader module | |||||
------------------------------- | |||||
.. automodule:: fastNLP.io.embed_loader | .. automodule:: fastNLP.io.embed_loader | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.io.logger | |||||
------------------ | |||||
.. automodule:: fastNLP.io.logger | |||||
:members: | |||||
fastNLP.io.model\_io | |||||
--------------------- | |||||
fastNLP.io.model\_io module | |||||
--------------------------- | |||||
.. automodule:: fastNLP.io.model_io | .. automodule:: fastNLP.io.model_io | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.io | .. automodule:: fastNLP.io | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,42 +1,110 @@ | |||||
fastNLP.models | |||||
=============== | |||||
fastNLP.models package | |||||
====================== | |||||
fastNLP.models.base\_model | |||||
--------------------------- | |||||
Submodules | |||||
---------- | |||||
fastNLP.models.base\_model module | |||||
--------------------------------- | |||||
.. automodule:: fastNLP.models.base_model | .. automodule:: fastNLP.models.base_model | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.bert module | |||||
-------------------------- | |||||
fastNLP.models.biaffine\_parser | |||||
-------------------------------- | |||||
.. automodule:: fastNLP.models.bert | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.biaffine\_parser module | |||||
-------------------------------------- | |||||
.. automodule:: fastNLP.models.biaffine_parser | .. automodule:: fastNLP.models.biaffine_parser | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.char\_language\_model | |||||
------------------------------------- | |||||
fastNLP.models.char\_language\_model module | |||||
------------------------------------------- | |||||
.. automodule:: fastNLP.models.char_language_model | .. automodule:: fastNLP.models.char_language_model | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.cnn\_text\_classification | |||||
----------------------------------------- | |||||
fastNLP.models.cnn\_text\_classification module | |||||
----------------------------------------------- | |||||
.. automodule:: fastNLP.models.cnn_text_classification | .. automodule:: fastNLP.models.cnn_text_classification | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.enas\_controller module | |||||
-------------------------------------- | |||||
.. automodule:: fastNLP.models.enas_controller | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.enas\_model module | |||||
--------------------------------- | |||||
.. automodule:: fastNLP.models.enas_model | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.sequence\_modeling | |||||
---------------------------------- | |||||
fastNLP.models.enas\_trainer module | |||||
----------------------------------- | |||||
.. automodule:: fastNLP.models.enas_trainer | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.enas\_utils module | |||||
--------------------------------- | |||||
.. automodule:: fastNLP.models.enas_utils | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.sequence\_modeling module | |||||
---------------------------------------- | |||||
.. automodule:: fastNLP.models.sequence_modeling | .. automodule:: fastNLP.models.sequence_modeling | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.snli | |||||
-------------------- | |||||
fastNLP.models.snli module | |||||
-------------------------- | |||||
.. automodule:: fastNLP.models.snli | .. automodule:: fastNLP.models.snli | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.models.star\_transformer module | |||||
--------------------------------------- | |||||
.. automodule:: fastNLP.models.star_transformer | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.models | .. automodule:: fastNLP.models | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,36 +1,54 @@ | |||||
fastNLP.modules.aggregator | |||||
=========================== | |||||
fastNLP.modules.aggregator package | |||||
================================== | |||||
fastNLP.modules.aggregator.attention | |||||
------------------------------------- | |||||
Submodules | |||||
---------- | |||||
fastNLP.modules.aggregator.attention module | |||||
------------------------------------------- | |||||
.. automodule:: fastNLP.modules.aggregator.attention | .. automodule:: fastNLP.modules.aggregator.attention | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.aggregator.avg\_pool | |||||
------------------------------------- | |||||
fastNLP.modules.aggregator.avg\_pool module | |||||
------------------------------------------- | |||||
.. automodule:: fastNLP.modules.aggregator.avg_pool | .. automodule:: fastNLP.modules.aggregator.avg_pool | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.aggregator.kmax\_pool | |||||
-------------------------------------- | |||||
fastNLP.modules.aggregator.kmax\_pool module | |||||
-------------------------------------------- | |||||
.. automodule:: fastNLP.modules.aggregator.kmax_pool | .. automodule:: fastNLP.modules.aggregator.kmax_pool | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.aggregator.max\_pool | |||||
------------------------------------- | |||||
fastNLP.modules.aggregator.max\_pool module | |||||
------------------------------------------- | |||||
.. automodule:: fastNLP.modules.aggregator.max_pool | .. automodule:: fastNLP.modules.aggregator.max_pool | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.aggregator.self\_attention | |||||
------------------------------------------- | |||||
fastNLP.modules.aggregator.self\_attention module | |||||
------------------------------------------------- | |||||
.. automodule:: fastNLP.modules.aggregator.self_attention | .. automodule:: fastNLP.modules.aggregator.self_attention | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.modules.aggregator | .. automodule:: fastNLP.modules.aggregator | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,18 +1,30 @@ | |||||
fastNLP.modules.decoder | |||||
======================== | |||||
fastNLP.modules.decoder package | |||||
=============================== | |||||
fastNLP.modules.decoder.CRF | |||||
---------------------------- | |||||
Submodules | |||||
---------- | |||||
fastNLP.modules.decoder.CRF module | |||||
---------------------------------- | |||||
.. automodule:: fastNLP.modules.decoder.CRF | .. automodule:: fastNLP.modules.decoder.CRF | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.decoder.MLP | |||||
---------------------------- | |||||
fastNLP.modules.decoder.MLP module | |||||
---------------------------------- | |||||
.. automodule:: fastNLP.modules.decoder.MLP | .. automodule:: fastNLP.modules.decoder.MLP | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.modules.decoder | .. automodule:: fastNLP.modules.decoder | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,60 +1,94 @@ | |||||
fastNLP.modules.encoder | |||||
======================== | |||||
fastNLP.modules.encoder package | |||||
=============================== | |||||
fastNLP.modules.encoder.char\_embedding | |||||
---------------------------------------- | |||||
Submodules | |||||
---------- | |||||
fastNLP.modules.encoder.char\_embedding module | |||||
---------------------------------------------- | |||||
.. automodule:: fastNLP.modules.encoder.char_embedding | .. automodule:: fastNLP.modules.encoder.char_embedding | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.conv | |||||
----------------------------- | |||||
fastNLP.modules.encoder.conv module | |||||
----------------------------------- | |||||
.. automodule:: fastNLP.modules.encoder.conv | .. automodule:: fastNLP.modules.encoder.conv | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.conv\_maxpool | |||||
-------------------------------------- | |||||
fastNLP.modules.encoder.conv\_maxpool module | |||||
-------------------------------------------- | |||||
.. automodule:: fastNLP.modules.encoder.conv_maxpool | .. automodule:: fastNLP.modules.encoder.conv_maxpool | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.embedding | |||||
---------------------------------- | |||||
fastNLP.modules.encoder.embedding module | |||||
---------------------------------------- | |||||
.. automodule:: fastNLP.modules.encoder.embedding | .. automodule:: fastNLP.modules.encoder.embedding | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.linear | |||||
------------------------------- | |||||
fastNLP.modules.encoder.linear module | |||||
------------------------------------- | |||||
.. automodule:: fastNLP.modules.encoder.linear | .. automodule:: fastNLP.modules.encoder.linear | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.lstm | |||||
----------------------------- | |||||
fastNLP.modules.encoder.lstm module | |||||
----------------------------------- | |||||
.. automodule:: fastNLP.modules.encoder.lstm | .. automodule:: fastNLP.modules.encoder.lstm | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.masked\_rnn | |||||
------------------------------------ | |||||
fastNLP.modules.encoder.masked\_rnn module | |||||
------------------------------------------ | |||||
.. automodule:: fastNLP.modules.encoder.masked_rnn | .. automodule:: fastNLP.modules.encoder.masked_rnn | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.transformer | |||||
------------------------------------ | |||||
fastNLP.modules.encoder.star\_transformer module | |||||
------------------------------------------------ | |||||
.. automodule:: fastNLP.modules.encoder.star_transformer | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.transformer module | |||||
------------------------------------------ | |||||
.. automodule:: fastNLP.modules.encoder.transformer | .. automodule:: fastNLP.modules.encoder.transformer | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.encoder.variational\_rnn | |||||
----------------------------------------- | |||||
fastNLP.modules.encoder.variational\_rnn module | |||||
----------------------------------------------- | |||||
.. automodule:: fastNLP.modules.encoder.variational_rnn | .. automodule:: fastNLP.modules.encoder.variational_rnn | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.modules.encoder | .. automodule:: fastNLP.modules.encoder | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,5 +1,8 @@ | |||||
fastNLP.modules | |||||
================ | |||||
fastNLP.modules package | |||||
======================= | |||||
Subpackages | |||||
----------- | |||||
.. toctree:: | .. toctree:: | ||||
@@ -7,24 +10,38 @@ fastNLP.modules | |||||
fastNLP.modules.decoder | fastNLP.modules.decoder | ||||
fastNLP.modules.encoder | fastNLP.modules.encoder | ||||
fastNLP.modules.dropout | |||||
------------------------ | |||||
Submodules | |||||
---------- | |||||
fastNLP.modules.dropout module | |||||
------------------------------ | |||||
.. automodule:: fastNLP.modules.dropout | .. automodule:: fastNLP.modules.dropout | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.other\_modules | |||||
------------------------------- | |||||
fastNLP.modules.other\_modules module | |||||
------------------------------------- | |||||
.. automodule:: fastNLP.modules.other_modules | .. automodule:: fastNLP.modules.other_modules | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
fastNLP.modules.utils | |||||
---------------------- | |||||
fastNLP.modules.utils module | |||||
---------------------------- | |||||
.. automodule:: fastNLP.modules.utils | .. automodule:: fastNLP.modules.utils | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: | |||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP.modules | .. automodule:: fastNLP.modules | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,13 +1,22 @@ | |||||
fastNLP | |||||
======== | |||||
fastNLP package | |||||
=============== | |||||
Subpackages | |||||
----------- | |||||
.. toctree:: | .. toctree:: | ||||
fastNLP.api | fastNLP.api | ||||
fastNLP.automl | |||||
fastNLP.core | fastNLP.core | ||||
fastNLP.io | fastNLP.io | ||||
fastNLP.models | fastNLP.models | ||||
fastNLP.modules | fastNLP.modules | ||||
Module contents | |||||
--------------- | |||||
.. automodule:: fastNLP | .. automodule:: fastNLP | ||||
:members: | :members: | ||||
:undoc-members: | |||||
:show-inheritance: |
@@ -1,3 +1,41 @@ | |||||
""" | |||||
api.api的介绍文档 | |||||
直接缩进会把上面的文字变成标题 | |||||
空行缩进的写法比较合理 | |||||
比较合理 | |||||
*这里是斜体内容* | |||||
**这里是粗体内容** | |||||
数学公式块 | |||||
.. math:: | |||||
E = mc^2 | |||||
.. note:: | |||||
注解型提示。 | |||||
.. warning:: | |||||
警告型提示。 | |||||
.. seealso:: | |||||
`参考与超链接 <https://willqvq.github.io/doc_guide/%E6%B3%A8%E9%87%8A%E6%8C%87%E5%AF%BC>`_ | |||||
普通代码块需要空一行, Example:: | |||||
from fitlog import fitlog | |||||
fitlog.commit() | |||||
普通下标和上标: | |||||
H\ :sub:`2`\ O | |||||
E = mc\ :sup:`2` | |||||
""" | |||||
import warnings | import warnings | ||||
import torch | import torch | ||||
@@ -103,6 +141,9 @@ class ConllxDataLoader(ConllLoader): | |||||
class API: | class API: | ||||
""" | |||||
这是 API 类的文档 | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
self.pipeline = None | self.pipeline = None | ||||
self._dict = None | self._dict = None | ||||
@@ -148,8 +189,9 @@ class POS(API): | |||||
self.load(model_path, device) | self.load(model_path, device) | ||||
def predict(self, content): | def predict(self, content): | ||||
""" | |||||
"""predict函数的介绍, | |||||
函数介绍的第二句,这句话不会换行 | |||||
:param content: list of list of str. Each string is a token(word). | :param content: list of list of str. Each string is a token(word). | ||||
:return answer: list of list of str. Each string is a tag. | :return answer: list of list of str. Each string is a tag. | ||||
""" | """ | ||||
@@ -215,13 +257,14 @@ class POS(API): | |||||
class CWS(API): | class CWS(API): | ||||
def __init__(self, model_path=None, device='cpu'): | |||||
""" | |||||
中文分词高级接口。 | |||||
""" | |||||
中文分词高级接口。 | |||||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 | |||||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 | |||||
""" | |||||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 | |||||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 | |||||
""" | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(CWS, self).__init__() | super(CWS, self).__init__() | ||||
if model_path is None: | if model_path is None: | ||||
model_path = model_urls['cws'] | model_path = model_urls['cws'] | ||||
@@ -262,18 +305,20 @@ class CWS(API): | |||||
def test(self, filepath): | def test(self, filepath): | ||||
""" | """ | ||||
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 | 传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 | ||||
分词文件应该为: | |||||
分词文件应该为:: | |||||
1 编者按 编者按 NN O 11 nmod:topic | 1 编者按 编者按 NN O 11 nmod:topic | ||||
2 : : PU O 11 punct | 2 : : PU O 11 punct | ||||
3 7月 7月 NT DATE 4 compound:nn | 3 7月 7月 NT DATE 4 compound:nn | ||||
4 12日 12日 NT DATE 11 nmod:tmod | 4 12日 12日 NT DATE 11 nmod:tmod | ||||
5 , , PU O 11 punct | 5 , , PU O 11 punct | ||||
1 这 这 DT O 3 det | 1 这 这 DT O 3 det | ||||
2 款 款 M O 1 mark:clf | 2 款 款 M O 1 mark:clf | ||||
3 飞行 飞行 NN O 8 nsubj | 3 飞行 飞行 NN O 8 nsubj | ||||
4 从 从 P O 5 case | 4 从 从 P O 5 case | ||||
5 外型 外型 NN O 8 nmod:prep | 5 外型 外型 NN O 8 nmod:prep | ||||
以空行分割两个句子,有内容的每行有7列。 | 以空行分割两个句子,有内容的每行有7列。 | ||||
:param filepath: str, 文件路径路径。 | :param filepath: str, 文件路径路径。 | ||||
@@ -62,13 +62,14 @@ class ENASTrainer(fastNLP.Trainer): | |||||
""" | """ | ||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | ||||
最好的模型参数。 | 最好的模型参数。 | ||||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
:return results: 返回一个字典类型的数据, | |||||
内含以下内容:: | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
""" | """ | ||||
results = {} | results = {} | ||||
@@ -1,5 +1,5 @@ | |||||
from .batch import Batch | from .batch import Batch | ||||
# from .dataset import DataSet | |||||
from .dataset import DataSet | |||||
from .fieldarray import FieldArray | from .fieldarray import FieldArray | ||||
from .instance import Instance | from .instance import Instance | ||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | ||||
@@ -9,5 +9,5 @@ from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSample | |||||
from .tester import Tester | from .tester import Tester | ||||
from .trainer import Trainer | from .trainer import Trainer | ||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
from ..io.dataset_loader import DataSet | |||||
from .callback import Callback | from .callback import Callback | ||||
from .utils import cache_results |
@@ -21,15 +21,17 @@ class Batch(object): | |||||
:param DataSet dataset: a DataSet object | :param DataSet dataset: a DataSet object | ||||
:param int batch_size: the size of the batch | :param int batch_size: the size of the batch | ||||
:param Sampler sampler: a Sampler object | |||||
:param Sampler sampler: a Sampler object. If None, use fastNLP.sampler.RandomSampler | |||||
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | ||||
:param bool prefetch: If True, use multiprocessing to fetch next batch when training. | :param bool prefetch: If True, use multiprocessing to fetch next batch when training. | ||||
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. | :param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. | ||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False): | |||||
def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
if sampler is None: | |||||
sampler = RandomSampler() | |||||
self.sampler = sampler | self.sampler = sampler | ||||
self.as_numpy = as_numpy | self.as_numpy = as_numpy | ||||
self.idx_list = None | self.idx_list = None | ||||
@@ -61,6 +61,10 @@ class Callback(object): | |||||
"""If use_tqdm, return trainer's tqdm print bar, else return None.""" | """If use_tqdm, return trainer's tqdm print bar, else return None.""" | ||||
return self._trainer.pbar | return self._trainer.pbar | ||||
@property | |||||
def update_every(self): | |||||
"""The model in trainer will update parameters every `update_every` batches.""" | |||||
return self._trainer.update_every | |||||
def on_train_begin(self): | def on_train_begin(self): | ||||
# before the main training loop | # before the main training loop | ||||
pass | pass | ||||
@@ -94,12 +98,14 @@ class Callback(object): | |||||
def on_valid_begin(self): | def on_valid_begin(self): | ||||
pass | pass | ||||
def on_valid_end(self, eval_result, metric_key): | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
""" | """ | ||||
每次执行验证机的evaluation后会调用。传入eval_result | 每次执行验证机的evaluation后会调用。传入eval_result | ||||
:param eval_result: Dict[str: Dict[str: float]], evaluation的结果 | :param eval_result: Dict[str: Dict[str: float]], evaluation的结果 | ||||
:param metric_key: str | :param metric_key: str | ||||
:param optimizer: optimizer passed to trainer | |||||
:param is_better_eval: bool, 当前dev结果是否比之前的好 | |||||
:return: | :return: | ||||
""" | """ | ||||
pass | pass | ||||
@@ -206,7 +212,7 @@ class CallbackManager(Callback): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def on_valid_end(self, eval_result, metric_key): | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
@@ -307,8 +313,8 @@ class EarlyStopCallback(Callback): | |||||
self.patience = patience | self.patience = patience | ||||
self.wait = 0 | self.wait = 0 | ||||
def on_valid_end(self, eval_result, metric_key): | |||||
if not self.trainer._better_eval_result(eval_result): | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
if not is_better_eval: | |||||
# current result is getting worse | # current result is getting worse | ||||
if self.wait == self.patience: | if self.wait == self.patience: | ||||
raise EarlyStopError("Early stopping raised.") | raise EarlyStopError("Early stopping raised.") | ||||
@@ -484,7 +490,7 @@ class TensorboardCallback(Callback): | |||||
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), | self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), | ||||
global_step=self.trainer.step) | global_step=self.trainer.step) | ||||
def on_valid_end(self, eval_result, metric_key): | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
if "metric" in self.options: | if "metric" in self.options: | ||||
for name, metric in eval_result.items(): | for name, metric in eval_result.items(): | ||||
for metric_key, metric_val in metric.items(): | for metric_key, metric_val in metric.items(): | ||||
@@ -6,7 +6,6 @@ from fastNLP.core.fieldarray import AutoPadder | |||||
from fastNLP.core.fieldarray import FieldArray | from fastNLP.core.fieldarray import FieldArray | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
from fastNLP.io.base_loader import DataLoaderRegister | |||||
class DataSet(object): | class DataSet(object): | ||||
@@ -90,7 +89,7 @@ class DataSet(object): | |||||
data_set = DataSet() | data_set = DataSet() | ||||
for field in self.field_arrays.values(): | for field in self.field_arrays.values(): | ||||
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | ||||
is_input=field.is_input, is_target=field.is_target) | |||||
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) | |||||
return data_set | return data_set | ||||
elif isinstance(idx, str): | elif isinstance(idx, str): | ||||
if idx not in self: | if idx not in self: | ||||
@@ -105,11 +104,6 @@ class DataSet(object): | |||||
raise AttributeError | raise AttributeError | ||||
if isinstance(item, str) and item in self.field_arrays: | if isinstance(item, str) and item in self.field_arrays: | ||||
return self.field_arrays[item] | return self.field_arrays[item] | ||||
try: | |||||
reader = DataLoaderRegister.get_reader(item) | |||||
return reader | |||||
except AttributeError: | |||||
raise | |||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
self.__dict__ = state | self.__dict__ = state | ||||
@@ -278,12 +272,22 @@ class DataSet(object): | |||||
:param func: a function that takes an instance as input. | :param func: a function that takes an instance as input. | ||||
:param str new_field_name: If not None, results of the function will be stored as a new field. | :param str new_field_name: If not None, results of the function will be stored as a new field. | ||||
:param **kwargs: Accept parameters will be | |||||
:param kwargs: Accept parameters will be | |||||
(1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. | (1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. | ||||
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. | (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. | ||||
:return results: if new_field_name is not passed, returned values of the function over all instances. | :return results: if new_field_name is not passed, returned values of the function over all instances. | ||||
""" | """ | ||||
results = [func(ins) for ins in self._inner_iter()] | |||||
assert len(self)!=0, "Null dataset cannot use .apply()." | |||||
results = [] | |||||
idx = -1 | |||||
try: | |||||
for idx, ins in enumerate(self._inner_iter()): | |||||
results.append(func(ins)) | |||||
except Exception as e: | |||||
if idx!=-1: | |||||
print("Exception happens at the `{}`th instance.".format(idx)) | |||||
raise e | |||||
# results = [func(ins) for ins in self._inner_iter()] | |||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | ||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | raise ValueError("{} always return None.".format(get_func_signature(func=func))) | ||||
@@ -313,16 +317,23 @@ class DataSet(object): | |||||
else: | else: | ||||
return results | return results | ||||
def drop(self, func): | |||||
def drop(self, func, inplace=True): | |||||
"""Drop instances if a condition holds. | """Drop instances if a condition holds. | ||||
:param func: a function that takes an Instance object as input, and returns bool. | :param func: a function that takes an Instance object as input, and returns bool. | ||||
The instance will be dropped if the function returns True. | The instance will be dropped if the function returns True. | ||||
:param inplace: bool, whether to drop inpalce. Otherwise a new dataset will be returned. | |||||
""" | """ | ||||
results = [ins for ins in self._inner_iter() if not func(ins)] | |||||
for name, old_field in self.field_arrays.items(): | |||||
self.field_arrays[name].content = [ins[name] for ins in results] | |||||
if inplace: | |||||
results = [ins for ins in self._inner_iter() if not func(ins)] | |||||
for name, old_field in self.field_arrays.items(): | |||||
self.field_arrays[name].content = [ins[name] for ins in results] | |||||
else: | |||||
results = [ins for ins in self if not func(ins)] | |||||
data = DataSet(results) | |||||
for field_name, field in self.field_arrays.items(): | |||||
data.field_arrays[field_name].to(field) | |||||
def split(self, dev_ratio): | def split(self, dev_ratio): | ||||
"""Split the dataset into training and development(validation) set. | """Split the dataset into training and development(validation) set. | ||||
@@ -346,19 +357,8 @@ class DataSet(object): | |||||
for idx in train_indices: | for idx in train_indices: | ||||
train_set.append(self[idx]) | train_set.append(self[idx]) | ||||
for field_name in self.field_arrays: | for field_name in self.field_arrays: | ||||
train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||||
train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||||
train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||||
train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||||
dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||||
dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||||
dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||||
dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||||
return train_set, dev_set | return train_set, dev_set | ||||
@@ -376,7 +376,7 @@ class DataSet(object): | |||||
import warnings | import warnings | ||||
warnings.warn('read_csv is deprecated, use CSVLoader instead', | warnings.warn('read_csv is deprecated, use CSVLoader instead', | ||||
category=DeprecationWarning) | category=DeprecationWarning) | ||||
with open(csv_path, "r") as f: | |||||
with open(csv_path, "r", encoding='utf-8') as f: | |||||
start_idx = 0 | start_idx = 0 | ||||
if headers is None: | if headers is None: | ||||
headers = f.readline().rstrip('\r\n') | headers = f.readline().rstrip('\r\n') | ||||
@@ -48,12 +48,16 @@ class PadderBase: | |||||
class AutoPadder(PadderBase): | class AutoPadder(PadderBase): | ||||
""" | """ | ||||
根据contents的数据自动判定是否需要做padding。 | 根据contents的数据自动判定是否需要做padding。 | ||||
(1) 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding | |||||
(2) 如果元素类型为(np.int64, np.float64), | |||||
(2.1) 如果该field的内容只有一个,比如为sequence_length, 则不进行padding | |||||
(2.2) 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||||
如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad | |||||
1 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding | |||||
2 如果元素类型为(np.int64, np.float64), | |||||
2.1 如果该field的内容只有一个,比如为sequence_length, 则不进行padding | |||||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||||
如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad | |||||
""" | """ | ||||
def __init__(self, pad_val=0): | def __init__(self, pad_val=0): | ||||
""" | """ | ||||
@@ -383,6 +387,23 @@ class FieldArray(object): | |||||
""" | """ | ||||
return len(self.content) | return len(self.content) | ||||
def to(self, other): | |||||
""" | |||||
将other的属性复制给本fieldarray(必须通过fieldarray类型). 包含 is_input, is_target, padder, dtype, pytype, content_dim | |||||
ignore_type | |||||
:param other: FieldArray | |||||
:return: | |||||
""" | |||||
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) | |||||
self.is_input = other.is_input | |||||
self.is_target = other.is_target | |||||
self.padder = other.padder | |||||
self.dtype = other.dtype | |||||
self.pytype = other.pytype | |||||
self.content_dim = other.content_dim | |||||
self.ignore_type = other.ignore_type | |||||
def is_iterable(content): | def is_iterable(content): | ||||
try: | try: | ||||
@@ -1,13 +1,12 @@ | |||||
class Instance(object): | class Instance(object): | ||||
"""An Instance is an example of data. | """An Instance is an example of data. | ||||
Example:: | |||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||||
ins["field_1"] | |||||
>>[1, 1, 1] | |||||
ins.add_field("field_3", [3, 3, 3]) | |||||
:param fields: a dict of (str: list). | |||||
Example:: | |||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||||
ins["field_1"] | |||||
>>[1, 1, 1] | |||||
ins.add_field("field_3", [3, 3, 3]) | |||||
""" | """ | ||||
def __init__(self, **fields): | def __init__(self, **fields): | ||||
@@ -251,7 +251,8 @@ class LossInForward(LossBase): | |||||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | ||||
if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") | raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") | ||||
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | |||||
loss = torch.sum(loss) / (loss.view(-1)).size(0) | |||||
# raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | |||||
return loss | return loss | ||||
@@ -271,7 +272,7 @@ def squash(predict, truth, **kwargs): | |||||
:param predict: Tensor, model output | :param predict: Tensor, model output | ||||
:param truth: Tensor, truth from dataset | :param truth: Tensor, truth from dataset | ||||
:param **kwargs: extra arguments | |||||
:param kwargs: extra arguments | |||||
:return predict , truth: predict & truth after processing | :return predict , truth: predict & truth after processing | ||||
""" | """ | ||||
return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | ||||
@@ -315,7 +316,7 @@ def mask(predict, truth, **kwargs): | |||||
:param predict: Tensor, [batch_size , max_len , tag_size] | :param predict: Tensor, [batch_size , max_len , tag_size] | ||||
:param truth: Tensor, [batch_size , max_len] | :param truth: Tensor, [batch_size , max_len] | ||||
:param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. | |||||
:param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. | |||||
:return predict , truth: predict & truth after processing | :return predict , truth: predict & truth after processing | ||||
""" | """ | ||||
@@ -17,66 +17,72 @@ class MetricBase(object): | |||||
"""Base class for all metrics. | """Base class for all metrics. | ||||
所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 | 所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 | ||||
evaluate(xxx)中传入的是一个batch的数据。 | evaluate(xxx)中传入的是一个batch的数据。 | ||||
get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 | get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 | ||||
以分类问题中,Accuracy计算为例 | 以分类问题中,Accuracy计算为例 | ||||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy | |||||
class Model(nn.Module): | |||||
def __init__(xxx): | |||||
# do something | |||||
def forward(self, xxx): | |||||
# do something | |||||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy:: | |||||
class Model(nn.Module): | |||||
def __init__(xxx): | |||||
# do something | |||||
def forward(self, xxx): | |||||
# do something | |||||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||||
假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target | 假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target | ||||
对应的AccMetric可以按如下的定义 | |||||
# version1, 只使用这一次 | |||||
class AccMetric(MetricBase): | |||||
def __init__(self): | |||||
super().__init__() | |||||
# 根据你的情况自定义指标 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||||
self.total += label.size(0) | |||||
self.corr_num += label.eq(pred).sum().item() | |||||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||||
acc = self.corr_num/self.total | |||||
if reset: # 是否清零以便重新计算 | |||||
对应的AccMetric可以按如下的定义, version1, 只使用这一次:: | |||||
class AccMetric(MetricBase): | |||||
def __init__(self): | |||||
super().__init__() | |||||
# 根据你的情况自定义指标 | |||||
self.corr_num = 0 | self.corr_num = 0 | ||||
self.total = 0 | self.total = 0 | ||||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||||
# version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred | |||||
class AccMetric(MetricBase): | |||||
def __init__(self, label=None, pred=None): | |||||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||||
# 应的的值 | |||||
super().__init__() | |||||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||||
# 如果没有注册该则效果与version1就是一样的 | |||||
# 根据你的情况自定义指标 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||||
self.total += label.size(0) | |||||
self.corr_num += label.eq(pred).sum().item() | |||||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||||
acc = self.corr_num/self.total | |||||
if reset: # 是否清零以便重新计算 | |||||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||||
self.total += label.size(0) | |||||
self.corr_num += label.eq(pred).sum().item() | |||||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||||
acc = self.corr_num/self.total | |||||
if reset: # 是否清零以便重新计算 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||||
version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred:: | |||||
class AccMetric(MetricBase): | |||||
def __init__(self, label=None, pred=None): | |||||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||||
# 应的的值 | |||||
super().__init__() | |||||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||||
# 如果没有注册该则效果与version1就是一样的 | |||||
# 根据你的情况自定义指标 | |||||
self.corr_num = 0 | self.corr_num = 0 | ||||
self.total = 0 | self.total = 0 | ||||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||||
self.total += label.size(0) | |||||
self.corr_num += label.eq(pred).sum().item() | |||||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||||
acc = self.corr_num/self.total | |||||
if reset: # 是否清零以便重新计算 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||||
``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. | ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. | ||||
@@ -84,14 +90,13 @@ class MetricBase(object): | |||||
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. | ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. | ||||
``MetricBase`` will do the following type checks: | ``MetricBase`` will do the following type checks: | ||||
1. whether self.evaluate has varargs, which is not supported. | |||||
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | |||||
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | |||||
1. whether self.evaluate has varargs, which is not supported. | |||||
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | |||||
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | |||||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | ||||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||||
target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering | |||||
will be conducted.) | will be conducted.) | ||||
However, in some cases where type check is not necessary, ``_fast_param_map`` will be used. | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
@@ -146,21 +151,6 @@ class MetricBase(object): | |||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
raise NotImplemented | raise NotImplemented | ||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||||
such as pred_dict has one element, target_dict has one element | |||||
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||||
""" | |||||
fast_param = {} | |||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict): | def __call__(self, pred_dict, target_dict): | ||||
""" | """ | ||||
@@ -172,7 +162,6 @@ class MetricBase(object): | |||||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | ||||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | ||||
will be conducted.) | will be conducted.) | ||||
This function also support _fast_param_map. | |||||
:param pred_dict: usually the output of forward or prediction function | :param pred_dict: usually the output of forward or prediction function | ||||
:param target_dict: usually features set as target.. | :param target_dict: usually features set as target.. | ||||
:return: | :return: | ||||
@@ -180,11 +169,6 @@ class MetricBase(object): | |||||
if not callable(self.evaluate): | if not callable(self.evaluate): | ||||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | ||||
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) | |||||
if fast_param: | |||||
self.evaluate(**fast_param) | |||||
return | |||||
if not self._checked: | if not self._checked: | ||||
# 1. check consistence between signature and param_map | # 1. check consistence between signature and param_map | ||||
func_spect = inspect.getfullargspec(self.evaluate) | func_spect = inspect.getfullargspec(self.evaluate) | ||||
@@ -262,50 +246,14 @@ class AccuracyMetric(MetricBase): | |||||
self.total = 0 | self.total = 0 | ||||
self.acc_count = 0 | self.acc_count = 0 | ||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||||
such as pred_dict has one element, target_dict has one element | |||||
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping. | |||||
""" | |||||
fast_param = {} | |||||
targets = list(target_dict.values()) | |||||
if len(targets) == 1 and isinstance(targets[0], torch.Tensor): | |||||
if len(pred_dict) == 1: | |||||
pred = list(pred_dict.values())[0] | |||||
fast_param['pred'] = pred | |||||
elif len(pred_dict) == 2: | |||||
pred1 = list(pred_dict.values())[0] | |||||
pred2 = list(pred_dict.values())[1] | |||||
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | |||||
return fast_param | |||||
if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1: | |||||
seq_lens = pred1 | |||||
pred = pred2 | |||||
elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1: | |||||
seq_lens = pred2 | |||||
pred = pred1 | |||||
else: | |||||
return fast_param | |||||
fast_param['pred'] = pred | |||||
fast_param['seq_lens'] = seq_lens | |||||
else: | |||||
return fast_param | |||||
fast_param['target'] = targets[0] | |||||
# TODO need to make sure they all have same batch_size | |||||
return fast_param | |||||
def evaluate(self, pred, target, seq_lens=None): | def evaluate(self, pred, target, seq_lens=None): | ||||
""" | """ | ||||
:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: | |||||
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) | |||||
:param target: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||||
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) | |||||
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||||
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. | |||||
:param pred: . Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), | |||||
torch.Size([B, max_len, n_classes]) | |||||
:param target: Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), | |||||
torch.Size([B, max_len]) | |||||
:param seq_lens: Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. | |||||
""" | """ | ||||
# TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value | # TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value | ||||
@@ -321,7 +269,7 @@ class AccuracyMetric(MetricBase): | |||||
f"got {type(seq_lens)}.") | f"got {type(seq_lens)}.") | ||||
if seq_lens is not None: | if seq_lens is not None: | ||||
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) | |||||
masks = seq_lens_to_masks(seq_lens=seq_lens) | |||||
else: | else: | ||||
masks = None | masks = None | ||||
@@ -334,14 +282,12 @@ class AccuracyMetric(MetricBase): | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | f"size:{pred.size()}, target should have size: {pred.size()} or " | ||||
f"{pred.size()[:-1]}, got {target.size()}.") | f"{pred.size()[:-1]}, got {target.size()}.") | ||||
pred = pred.float() | |||||
target = target.float() | |||||
target = target.to(pred) | |||||
if masks is not None: | if masks is not None: | ||||
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item() | |||||
self.total += torch.sum(masks.float()).item() | |||||
self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks, 0)).item() | |||||
self.total += torch.sum(masks).item() | |||||
else: | else: | ||||
self.acc_count += torch.sum(torch.eq(pred, target).float()).item() | |||||
self.acc_count += torch.sum(torch.eq(pred, target)).item() | |||||
self.total += np.prod(list(pred.size())) | self.total += np.prod(list(pred.size())) | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
@@ -350,7 +296,7 @@ class AccuracyMetric(MetricBase): | |||||
:param bool reset: whether to recount next time. | :param bool reset: whether to recount next time. | ||||
:return evaluate_result: {"acc": float} | :return evaluate_result: {"acc": float} | ||||
""" | """ | ||||
evaluate_result = {'acc': round(self.acc_count / self.total, 6)} | |||||
evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)} | |||||
if reset: | if reset: | ||||
self.acc_count = 0 | self.acc_count = 0 | ||||
self.total = 0 | self.total = 0 | ||||
@@ -441,31 +387,33 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||||
prev_bio_tag = bio_tag | prev_bio_tag = bio_tag | ||||
return [(span[0], (span[1][0], span[1][1]+1)) | return [(span[0], (span[1][0], span[1][1]+1)) | ||||
for span in spans | for span in spans | ||||
if span[0] not in ignore_labels | |||||
] | |||||
if span[0] not in ignore_labels] | |||||
class SpanFPreRecMetric(MetricBase): | class SpanFPreRecMetric(MetricBase): | ||||
""" | """ | ||||
在序列标注问题中,以span的方式计算F, pre, rec. | 在序列标注问题中,以span的方式计算F, pre, rec. | ||||
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | 比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | ||||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||||
最后得到的metric结果为 | |||||
{ | |||||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||||
'pre': xxx, | |||||
'rec':xxx | |||||
} | |||||
若only_gross=False, 即还会返回各个label的metric统计值 | |||||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||||
最后得到的metric结果为:: | |||||
{ | { | ||||
'f': xxx, | |||||
'pre': xxx, | |||||
'rec':xxx, | |||||
'f-label': xxx, | |||||
'pre-label': xxx, | |||||
'rec-label':xxx, | |||||
... | |||||
} | |||||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||||
'pre': xxx, | |||||
'rec':xxx | |||||
} | |||||
若only_gross=False, 即还会返回各个label的metric统计值:: | |||||
{ | |||||
'f': xxx, | |||||
'pre': xxx, | |||||
'rec':xxx, | |||||
'f-label': xxx, | |||||
'pre-label': xxx, | |||||
'rec-label':xxx, | |||||
... | |||||
} | |||||
""" | """ | ||||
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, | def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, | ||||
@@ -634,13 +582,21 @@ class BMESF1PreRecMetric(MetricBase): | |||||
""" | """ | ||||
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | ||||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | ||||
+-------+---------+----------+----------+---------+---------+ | |||||
| | next_B | next_M | next_E | next_S | end | | | | next_B | next_M | next_E | next_S | end | | ||||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||||
| start | 合法 | next_M=B | next_E=S | 合法 | - | | |||||
+=======+=========+==========+==========+=========+=========+ | |||||
| start | 合法 | next_M=B | next_E=S | 合法 | -- | | |||||
+-------+---------+----------+----------+---------+---------+ | |||||
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | ||||
+-------+---------+----------+----------+---------+---------+ | |||||
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | | cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | ||||
+-------+---------+----------+----------+---------+---------+ | |||||
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | ||||
+-------+---------+----------+----------+---------+---------+ | |||||
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | ||||
+-------+---------+----------+----------+---------+---------+ | |||||
举例: | 举例: | ||||
prediction为BSEMS,会被认为是SSSSS. | prediction为BSEMS,会被认为是SSSSS. | ||||
@@ -34,7 +34,7 @@ class Trainer(object): | |||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | ||||
validate_every=-1, dev_data=None, save_path=None, optimizer=None, | validate_every=-1, dev_data=None, save_path=None, optimizer=None, | ||||
check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True, | check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True, | ||||
use_cuda=False, callbacks=None): | |||||
use_cuda=False, callbacks=None, update_every=1): | |||||
""" | """ | ||||
:param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
:param torch.nn.modules.module model: a PyTorch model | :param torch.nn.modules.module model: a PyTorch model | ||||
@@ -62,6 +62,8 @@ class Trainer(object): | |||||
:param bool use_tqdm: whether to use tqdm to show train progress. | :param bool use_tqdm: whether to use tqdm to show train progress. | ||||
:param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | :param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | ||||
通过callback机制实现。 | 通过callback机制实现。 | ||||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128会导致内存 | |||||
不足,通过设置batch_size=32, update_every=4达到目的 | |||||
""" | """ | ||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
@@ -76,6 +78,10 @@ class Trainer(object): | |||||
if metrics and (dev_data is None): | if metrics and (dev_data is None): | ||||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | ||||
# check update every | |||||
assert update_every >= 1, "update_every must be no less than 1." | |||||
self.update_every = int(update_every) | |||||
# check save_path | # check save_path | ||||
if not (save_path is None or isinstance(save_path, str)): | if not (save_path is None or isinstance(save_path, str)): | ||||
raise ValueError("save_path can only be None or `str`.") | raise ValueError("save_path can only be None or `str`.") | ||||
@@ -114,7 +120,7 @@ class Trainer(object): | |||||
self.use_cuda = bool(use_cuda) | self.use_cuda = bool(use_cuda) | ||||
self.save_path = save_path | self.save_path = save_path | ||||
self.print_every = int(print_every) | self.print_every = int(print_every) | ||||
self.validate_every = int(validate_every) if validate_every!=0 else -1 | |||||
self.validate_every = int(validate_every) if validate_every != 0 else -1 | |||||
self.best_metric_indicator = None | self.best_metric_indicator = None | ||||
self.best_dev_epoch = None | self.best_dev_epoch = None | ||||
self.best_dev_step = None | self.best_dev_step = None | ||||
@@ -123,7 +129,7 @@ class Trainer(object): | |||||
self.prefetch = prefetch | self.prefetch = prefetch | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | ||||
self.n_steps = (len(self.train_data) // self.batch_size + int( | self.n_steps = (len(self.train_data) // self.batch_size + int( | ||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
@@ -147,6 +153,8 @@ class Trainer(object): | |||||
self.step = 0 | self.step = 0 | ||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||||
callbacks=callbacks) | |||||
def train(self, load_best_model=True): | def train(self, load_best_model=True): | ||||
""" | """ | ||||
@@ -176,14 +184,15 @@ class Trainer(object): | |||||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 | 根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 | ||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | ||||
最好的模型参数。 | |||||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||||
最好的模型参数。 | |||||
:return results: 返回一个字典类型的数据, | |||||
内含以下内容:: | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
""" | """ | ||||
results = {} | results = {} | ||||
@@ -209,8 +218,9 @@ class Trainer(object): | |||||
self.callback_manager.on_exception(e) | self.callback_manager.on_exception(e) | ||||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | ||||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||||
self.tester._format_eval_results(self.best_dev_perf),) | |||||
print( | |||||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||||
self.tester._format_eval_results(self.best_dev_perf), ) | |||||
results['best_eval'] = self.best_dev_perf | results['best_eval'] = self.best_dev_perf | ||||
results['best_epoch'] = self.best_dev_epoch | results['best_epoch'] = self.best_dev_epoch | ||||
results['best_step'] = self.best_dev_step | results['best_step'] = self.best_dev_step | ||||
@@ -241,7 +251,7 @@ class Trainer(object): | |||||
avg_loss = 0 | avg_loss = 0 | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | ||||
prefetch=self.prefetch) | prefetch=self.prefetch) | ||||
for epoch in range(1, self.n_epochs+1): | |||||
for epoch in range(1, self.n_epochs + 1): | |||||
self.epoch = epoch | self.epoch = epoch | ||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | ||||
# early stopping | # early stopping | ||||
@@ -256,8 +266,9 @@ class Trainer(object): | |||||
# edit prediction | # edit prediction | ||||
self.callback_manager.on_loss_begin(batch_y, prediction) | self.callback_manager.on_loss_begin(batch_y, prediction) | ||||
loss = self._compute_loss(prediction, batch_y) | |||||
loss = self._compute_loss(prediction, batch_y).mean() | |||||
avg_loss += loss.item() | avg_loss += loss.item() | ||||
loss = loss / self.update_every | |||||
# Is loss NaN or inf? requires_grad = False | # Is loss NaN or inf? requires_grad = False | ||||
self.callback_manager.on_backward_begin(loss) | self.callback_manager.on_backward_begin(loss) | ||||
@@ -288,7 +299,7 @@ class Trainer(object): | |||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | ||||
self.n_steps) + \ | self.n_steps) + \ | ||||
self.tester._format_eval_results(eval_res) | self.tester._format_eval_results(eval_res) | ||||
pbar.write(eval_str) | |||||
pbar.write(eval_str + '\n') | |||||
# ================= mini-batch end ==================== # | # ================= mini-batch end ==================== # | ||||
@@ -303,17 +314,19 @@ class Trainer(object): | |||||
self.callback_manager.on_valid_begin() | self.callback_manager.on_valid_begin() | ||||
res = self.tester.test() | res = self.tester.test() | ||||
is_better_eval = False | |||||
if self._better_eval_result(res): | if self._better_eval_result(res): | ||||
if self.save_path is not None: | if self.save_path is not None: | ||||
self._save_model(self.model, | self._save_model(self.model, | ||||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||||
else: | else: | ||||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} | self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} | ||||
self.best_dev_perf = res | self.best_dev_perf = res | ||||
self.best_dev_epoch = epoch | self.best_dev_epoch = epoch | ||||
self.best_dev_step = step | self.best_dev_step = step | ||||
is_better_eval = True | |||||
# get validation results; adjust optimizer | # get validation results; adjust optimizer | ||||
self.callback_manager.on_valid_end(res, self.metric_key) | |||||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | |||||
return res | return res | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
@@ -332,7 +345,8 @@ class Trainer(object): | |||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
""" | """ | ||||
self.optimizer.step() | |||||
if (self.step + 1) % self.update_every == 0: | |||||
self.optimizer.step() | |||||
def _data_forward(self, network, x): | def _data_forward(self, network, x): | ||||
x = _build_args(network.forward, **x) | x = _build_args(network.forward, **x) | ||||
@@ -348,7 +362,8 @@ class Trainer(object): | |||||
For PyTorch, just do "loss.backward()" | For PyTorch, just do "loss.backward()" | ||||
""" | """ | ||||
self.model.zero_grad() | |||||
if self.step % self.update_every == 0: | |||||
self.model.zero_grad() | |||||
loss.backward() | loss.backward() | ||||
def _compute_loss(self, predict, truth): | def _compute_loss(self, predict, truth): | ||||
@@ -423,6 +438,7 @@ class Trainer(object): | |||||
DEFAULT_CHECK_BATCH_SIZE = 2 | DEFAULT_CHECK_BATCH_SIZE = 2 | ||||
DEFAULT_CHECK_NUM_BATCH = 2 | DEFAULT_CHECK_NUM_BATCH = 2 | ||||
def _get_value_info(_dict): | def _get_value_info(_dict): | ||||
# given a dict value, return information about this dict's value. Return list of str | # given a dict value, return information about this dict's value. Return list of str | ||||
strs = [] | strs = [] | ||||
@@ -439,6 +455,7 @@ def _get_value_info(_dict): | |||||
strs.append(_str) | strs.append(_str) | ||||
return strs | return strs | ||||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | ||||
dev_data=None, metric_key=None, | dev_data=None, metric_key=None, | ||||
check_level=0): | check_level=0): | ||||
@@ -449,17 +466,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
for batch_count, (batch_x, batch_y) in enumerate(batch): | for batch_count, (batch_x, batch_y) in enumerate(batch): | ||||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | _move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | ||||
# forward check | # forward check | ||||
if batch_count==0: | |||||
if batch_count == 0: | |||||
info_str = "" | info_str = "" | ||||
input_fields = _get_value_info(batch_x) | input_fields = _get_value_info(batch_x) | ||||
target_fields = _get_value_info(batch_y) | target_fields = _get_value_info(batch_y) | ||||
if len(input_fields)>0: | |||||
if len(input_fields) > 0: | |||||
info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) | info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) | ||||
info_str += "\n".join(input_fields) | info_str += "\n".join(input_fields) | ||||
info_str += '\n' | info_str += '\n' | ||||
else: | else: | ||||
raise RuntimeError("There is no input field.") | raise RuntimeError("There is no input field.") | ||||
if len(target_fields)>0: | |||||
if len(target_fields) > 0: | |||||
info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) | info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) | ||||
info_str += "\n".join(target_fields) | info_str += "\n".join(target_fields) | ||||
info_str += '\n' | info_str += '\n' | ||||
@@ -467,7 +484,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
info_str += 'There is no target field.' | info_str += 'There is no target field.' | ||||
print(info_str) | print(info_str) | ||||
_check_forward_error(forward_func=model.forward, dataset=dataset, | _check_forward_error(forward_func=model.forward, dataset=dataset, | ||||
batch_x=batch_x, check_level=check_level) | |||||
batch_x=batch_x, check_level=check_level) | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | refined_batch_x = _build_args(model.forward, **batch_x) | ||||
pred_dict = model(**refined_batch_x) | pred_dict = model(**refined_batch_x) | ||||
@@ -11,6 +11,64 @@ import torch | |||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs']) | 'varargs']) | ||||
def _prepare_cache_filepath(filepath): | |||||
""" | |||||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | |||||
:param filepath: str. | |||||
:return: None, if not, this function will raise error | |||||
""" | |||||
_cache_filepath = os.path.abspath(filepath) | |||||
if os.path.isdir(_cache_filepath): | |||||
raise RuntimeError("The cache_file_path must be a file, not a directory.") | |||||
cache_dir = os.path.dirname(_cache_filepath) | |||||
if not os.path.exists(cache_dir): | |||||
os.makedirs(cache_dir) | |||||
def cache_results(cache_filepath, refresh=False, verbose=1): | |||||
def wrapper_(func): | |||||
signature = inspect.signature(func) | |||||
for key, _ in signature.parameters.items(): | |||||
if key in ('cache_filepath', 'refresh', 'verbose'): | |||||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | |||||
def wrapper(*args, **kwargs): | |||||
if 'cache_filepath' in kwargs: | |||||
_cache_filepath = kwargs.pop('cache_filepath') | |||||
assert isinstance(_cache_filepath, str), "cache_filepath can only be str." | |||||
else: | |||||
_cache_filepath = cache_filepath | |||||
if 'refresh' in kwargs: | |||||
_refresh = kwargs.pop('refresh') | |||||
assert isinstance(_refresh, bool), "refresh can only be bool." | |||||
else: | |||||
_refresh = refresh | |||||
if 'verbose' in kwargs: | |||||
_verbose = kwargs.pop('verbose') | |||||
assert isinstance(_verbose, int), "verbose can only be integer." | |||||
refresh_flag = True | |||||
if _cache_filepath is not None and _refresh is False: | |||||
# load data | |||||
if os.path.exists(_cache_filepath): | |||||
with open(_cache_filepath, 'rb') as f: | |||||
results = _pickle.load(f) | |||||
if verbose==1: | |||||
print("Read cache from {}.".format(_cache_filepath)) | |||||
refresh_flag = False | |||||
if refresh_flag: | |||||
results = func(*args, **kwargs) | |||||
if _cache_filepath is not None: | |||||
if results is None: | |||||
raise RuntimeError("The return value is None. Delete the decorator.") | |||||
_prepare_cache_filepath(_cache_filepath) | |||||
with open(_cache_filepath, 'wb') as f: | |||||
_pickle.dump(results, f) | |||||
print("Save cache to {}.".format(_cache_filepath)) | |||||
return results | |||||
return wrapper | |||||
return wrapper_ | |||||
def save_pickle(obj, pickle_path, file_name): | def save_pickle(obj, pickle_path, file_name): | ||||
"""Save an object into a pickle file. | """Save an object into a pickle file. | ||||
@@ -139,17 +197,22 @@ def get_func_signature(func): | |||||
Given a function or method, return its signature. | Given a function or method, return its signature. | ||||
For example: | For example: | ||||
(1) function | |||||
1 function:: | |||||
def func(a, b='a', *args): | def func(a, b='a', *args): | ||||
xxxx | xxxx | ||||
get_func_signature(func) # 'func(a, b='a', *args)' | get_func_signature(func) # 'func(a, b='a', *args)' | ||||
(2) method | |||||
2 method:: | |||||
class Demo: | class Demo: | ||||
def __init__(self): | def __init__(self): | ||||
xxx | xxx | ||||
def forward(self, a, b='a', **args) | def forward(self, a, b='a', **args) | ||||
demo = Demo() | demo = Demo() | ||||
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' | get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' | ||||
:param func: a function or a method | :param func: a function or a method | ||||
:return: str or None | :return: str or None | ||||
""" | """ | ||||
@@ -1,5 +1,5 @@ | |||||
from collections import Counter | from collections import Counter | ||||
from fastNLP.core.dataset import DataSet | |||||
def check_build_vocab(func): | def check_build_vocab(func): | ||||
"""A decorator to make sure the indexing is built before used. | """A decorator to make sure the indexing is built before used. | ||||
@@ -151,6 +151,77 @@ class Vocabulary(object): | |||||
else: | else: | ||||
raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
@check_build_vocab | |||||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||||
""" | |||||
example: | |||||
# remember to use `field_name` | |||||
vocab.index_dataset(tr_data, dev_data, te_data, field_name='words') | |||||
:param datasets: fastNLP Dataset type. you can pass multiple datasets | |||||
:param field_name: str, what field to index. Only support 0,1,2 dimension. | |||||
:param new_field_name: str. What the indexed field should be named, default is to overwrite field_name | |||||
:return: | |||||
""" | |||||
def index_instance(ins): | |||||
""" | |||||
有几种情况, str, 1d-list, 2d-list | |||||
:param ins: | |||||
:return: | |||||
""" | |||||
field = ins[field_name] | |||||
if isinstance(field, str): | |||||
return self.to_index(field) | |||||
elif isinstance(field, list): | |||||
if not isinstance(field[0], list): | |||||
return [self.to_index(w) for w in field] | |||||
else: | |||||
if isinstance(field[0][0], list): | |||||
raise RuntimeError("Only support field with 2 dimensions.") | |||||
return[[self.to_index(c) for c in w] for w in field] | |||||
if new_field_name is None: | |||||
new_field_name = field_name | |||||
for idx, dataset in enumerate(datasets): | |||||
if isinstance(dataset, DataSet): | |||||
try: | |||||
dataset.apply(index_instance, new_field_name=new_field_name) | |||||
except Exception as e: | |||||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||||
raise e | |||||
else: | |||||
raise RuntimeError("Only DataSet type is allowed.") | |||||
def from_dataset(self, *datasets, field_name): | |||||
""" | |||||
Construct vocab from dataset. | |||||
:param datasets: DataSet. | |||||
:param field_name: str, what field is used to construct dataset. | |||||
:return: | |||||
""" | |||||
def construct_vocab(ins): | |||||
field = ins[field_name] | |||||
if isinstance(field, str): | |||||
self.add_word(field) | |||||
elif isinstance(field, list): | |||||
if not isinstance(field[0], list): | |||||
self.add_word_lst(field) | |||||
else: | |||||
if isinstance(field[0][0], list): | |||||
raise RuntimeError("Only support field with 2 dimensions.") | |||||
[self.add_word_lst(w) for w in field] | |||||
for idx, dataset in enumerate(datasets): | |||||
if isinstance(dataset, DataSet): | |||||
try: | |||||
dataset.apply(construct_vocab) | |||||
except Exception as e: | |||||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||||
raise e | |||||
else: | |||||
raise RuntimeError("Only DataSet type is allowed.") | |||||
return self | |||||
def to_index(self, w): | def to_index(self, w): | ||||
""" Turn a word to an index. If w is not in Vocabulary, return the unknown label. | """ Turn a word to an index. If w is not in Vocabulary, return the unknown label. | ||||
@@ -0,0 +1 @@ | |||||
from .embed_loader import EmbedLoader |
@@ -26,10 +26,10 @@ class ConfigLoader(BaseLoader): | |||||
:param str file_path: the path of config file | :param str file_path: the path of config file | ||||
:param dict sections: the dict of ``{section_name(string): ConfigSection object}`` | :param dict sections: the dict of ``{section_name(string): ConfigSection object}`` | ||||
Example:: | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
Example:: | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
""" | """ | ||||
assert isinstance(sections, dict) | assert isinstance(sections, dict) | ||||
@@ -1,9 +1,12 @@ | |||||
import os | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.io.base_loader import BaseLoader | from fastNLP.io.base_loader import BaseLoader | ||||
import warnings | |||||
class EmbedLoader(BaseLoader): | class EmbedLoader(BaseLoader): | ||||
"""docstring for EmbedLoader""" | """docstring for EmbedLoader""" | ||||
@@ -124,3 +127,137 @@ class EmbedLoader(BaseLoader): | |||||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | size=(len(vocab) - np.sum(hit_flags), emb_dim)) | ||||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | ||||
return embedding_matrix | return embedding_matrix | ||||
@staticmethod | |||||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): | |||||
""" | |||||
load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining | |||||
embedding are initialized from a normal distribution which has the mean and std of the found words vectors. | |||||
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). | |||||
:param embed_filepath: str, where to read pretrain embedding | |||||
:param vocab: Vocabulary. | |||||
:param dtype: the dtype of the embedding matrix | |||||
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1. | |||||
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will | |||||
raise | |||||
:return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain | |||||
embedding | |||||
""" | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||||
if not os.path.exists(embed_filepath): | |||||
raise FileNotFoundError("`{}` does not exist.".format(embed_filepath)) | |||||
with open(embed_filepath, 'r', encoding='utf-8') as f: | |||||
hit_flags = np.zeros(len(vocab), dtype=bool) | |||||
line = f.readline().strip() | |||||
parts = line.split() | |||||
start_idx = 0 | |||||
if len(parts)==2: | |||||
dim = int(parts[1]) | |||||
start_idx += 1 | |||||
else: | |||||
dim = len(parts)-1 | |||||
f.seek(0) | |||||
matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||||
for idx, line in enumerate(f, start_idx): | |||||
try: | |||||
parts = line.strip().split() | |||||
if parts[0] in vocab: | |||||
index = vocab.to_index(parts[0]) | |||||
matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim) | |||||
hit_flags[index] = True | |||||
except Exception as e: | |||||
if error == 'ignore': | |||||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||||
else: | |||||
raise e | |||||
total_hits = sum(hit_flags) | |||||
print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab))) | |||||
found_vectors = matrix[hit_flags] | |||||
if len(found_vectors)!=0: | |||||
mean = np.mean(found_vectors, axis=0, keepdims=True) | |||||
std = np.std(found_vectors, axis=0, keepdims=True) | |||||
unfound_vec_num = len(vocab) - total_hits | |||||
r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype)*std + mean | |||||
matrix[hit_flags==False] = r_vecs | |||||
if normalize: | |||||
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||||
return matrix | |||||
@staticmethod | |||||
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||||
error='ignore'): | |||||
""" | |||||
load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding. | |||||
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). | |||||
:param embed_filepath: str, where to read pretrain embedding | |||||
:param dtype: the dtype of the embedding matrix | |||||
:param padding: the padding tag for vocabulary. | |||||
:param unknown: the unknown tag for vocabulary. | |||||
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1. | |||||
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will | |||||
:raise | |||||
:return: np.ndarray() is determined by the pretraining embeddings | |||||
Vocabulary: contain all pretraining words and two special tag[<pad>, <unk>] | |||||
""" | |||||
vocab = Vocabulary(padding=padding, unknown=unknown) | |||||
vec_dict = {} | |||||
found_unknown = False | |||||
found_pad = False | |||||
with open(embed_filepath, 'r', encoding='utf-8') as f: | |||||
line = f.readline() | |||||
start = 1 | |||||
dim = -1 | |||||
if len(line.strip().split())!=2: | |||||
f.seek(0) | |||||
start = 0 | |||||
for idx, line in enumerate(f, start=start): | |||||
try: | |||||
parts = line.strip().split() | |||||
word = parts[0] | |||||
if dim==-1: | |||||
dim = len(parts)-1 | |||||
vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim) | |||||
vec_dict[word] = vec | |||||
vocab.add_word(word) | |||||
if unknown is not None and unknown==word: | |||||
found_unknown = True | |||||
if found_pad is not None and padding==word: | |||||
found_pad = True | |||||
except Exception as e: | |||||
if error=='ignore': | |||||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||||
pass | |||||
else: | |||||
raise e | |||||
if dim==-1: | |||||
raise RuntimeError("{} is an empty file.".format(embed_filepath)) | |||||
matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||||
# TODO 需要保证unk其它数据同分布的吗? | |||||
if (unknown is not None and not found_unknown) or (padding is not None and not found_pad): | |||||
start_idx = 0 | |||||
if padding is not None: | |||||
start_idx += 1 | |||||
if unknown is not None: | |||||
start_idx += 1 | |||||
mean = np.mean(matrix[start_idx:], axis=0, keepdims=True) | |||||
std = np.std(matrix[start_idx:], axis=0, keepdims=True) | |||||
if (unknown is not None and not found_unknown): | |||||
matrix[start_idx-1] = np.random.randn(1, dim).astype(dtype)*std + mean | |||||
if (padding is not None and not found_pad): | |||||
matrix[0] = np.random.randn(1, dim).astype(dtype)*std + mean | |||||
for key, vec in vec_dict.items(): | |||||
index = vocab.to_index(key) | |||||
matrix[index] = vec | |||||
if normalize: | |||||
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||||
return matrix, vocab |
@@ -1,35 +0,0 @@ | |||||
import logging | |||||
import os | |||||
def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO): | |||||
"""Create a logger. | |||||
:param str logger_name: | |||||
:param str log_path: | |||||
:param log_format: | |||||
:param log_level: | |||||
:return: logger | |||||
To use a logger:: | |||||
logger.debug("this is a debug message") | |||||
logger.info("this is a info message") | |||||
logger.warning("this is a warning message") | |||||
logger.error("this is an error message") | |||||
""" | |||||
logger = logging.getLogger(logger_name) | |||||
logger.setLevel(log_level) | |||||
if log_path is None: | |||||
handler = logging.StreamHandler() | |||||
else: | |||||
os.stat(os.path.dirname(os.path.abspath(log_path))) | |||||
handler = logging.FileHandler(log_path) | |||||
handler.setLevel(log_level) | |||||
if log_format is None: | |||||
log_format = "[%(asctime)s %(name)-13s %(levelname)s %(process)d %(thread)d " \ | |||||
"%(filename)s:%(lineno)-5d] %(message)s" | |||||
formatter = logging.Formatter(log_format) | |||||
handler.setFormatter(formatter) | |||||
logger.addHandler(handler) | |||||
return logger |
@@ -31,16 +31,18 @@ class ModelLoader(BaseLoader): | |||||
class ModelSaver(object): | class ModelSaver(object): | ||||
"""Save a model | """Save a model | ||||
Example:: | |||||
:param str save_path: the path to the saving directory. | |||||
Example:: | |||||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||||
saver.save_pytorch(model) | |||||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||||
saver.save_pytorch(model) | |||||
""" | """ | ||||
def __init__(self, save_path): | def __init__(self, save_path): | ||||
""" | |||||
:param save_path: the path to the saving directory. | |||||
""" | |||||
self.save_path = save_path | self.save_path = save_path | ||||
def save_pytorch(self, model, param_only=True): | def save_pytorch(self, model, param_only=True): | ||||
@@ -20,16 +20,23 @@ class Highway(nn.Module): | |||||
class CharLM(nn.Module): | class CharLM(nn.Module): | ||||
"""CNN + highway network + LSTM | """CNN + highway network + LSTM | ||||
# Input: | |||||
# Input:: | |||||
4D tensor with shape [batch_size, in_channel, height, width] | 4D tensor with shape [batch_size, in_channel, height, width] | ||||
# Output: | |||||
# Output:: | |||||
2D Tensor with shape [batch_size, vocab_size] | 2D Tensor with shape [batch_size, vocab_size] | ||||
# Arguments: | |||||
# Arguments:: | |||||
char_emb_dim: the size of each character's attention | char_emb_dim: the size of each character's attention | ||||
word_emb_dim: the size of each word's attention | word_emb_dim: the size of each word's attention | ||||
vocab_size: num of unique words | vocab_size: num of unique words | ||||
num_char: num of characters | num_char: num of characters | ||||
use_gpu: True or False | use_gpu: True or False | ||||
""" | """ | ||||
def __init__(self, char_emb_dim, word_emb_dim, | def __init__(self, char_emb_dim, word_emb_dim, | ||||
@@ -65,13 +65,14 @@ class ENASTrainer(fastNLP.Trainer): | |||||
""" | """ | ||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | ||||
最好的模型参数。 | 最好的模型参数。 | ||||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
:return results: 返回一个字典类型的数据, | |||||
内含以下内容:: | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
""" | """ | ||||
results = {} | results = {} | ||||
@@ -1,6 +1,5 @@ | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules import decoder as Decoder | from fastNLP.modules import decoder as Decoder | ||||
@@ -40,7 +39,7 @@ class ESIM(BaseModel): | |||||
batch_first=self.batch_first, bidirectional=True | batch_first=self.batch_first, bidirectional=True | ||||
) | ) | ||||
self.bi_attention = Aggregator.Bi_Attention() | |||||
self.bi_attention = Aggregator.BiAttention() | |||||
self.mean_pooling = Aggregator.MeanPoolWithMask() | self.mean_pooling = Aggregator.MeanPoolWithMask() | ||||
self.max_pooling = Aggregator.MaxPoolWithMask() | self.max_pooling = Aggregator.MaxPoolWithMask() | ||||
@@ -53,23 +52,23 @@ class ESIM(BaseModel): | |||||
self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) | self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) | ||||
def forward(self, premise, hypothesis, premise_len, hypothesis_len): | |||||
def forward(self, words1, words2, seq_len1, seq_len2): | |||||
""" Forward function | """ Forward function | ||||
:param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)]. | |||||
:param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. | |||||
:param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. | |||||
:param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. | |||||
:param words1: A Tensor represents premise: [batch size(B), premise seq len(PL)]. | |||||
:param words2: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. | |||||
:param seq_len1: A Tensor record which is a real word and which is a padding word in premise: [B]. | |||||
:param seq_len2: A Tensor record which is a real word and which is a padding word in hypothesis: [B]. | |||||
:return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)]. | :return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)]. | ||||
""" | """ | ||||
premise0 = self.embedding_layer(self.embedding(premise)) | |||||
hypothesis0 = self.embedding_layer(self.embedding(hypothesis)) | |||||
premise0 = self.embedding_layer(self.embedding(words1)) | |||||
hypothesis0 = self.embedding_layer(self.embedding(words2)) | |||||
_BP, _PSL, _HP = premise0.size() | _BP, _PSL, _HP = premise0.size() | ||||
_BH, _HSL, _HH = hypothesis0.size() | _BH, _HSL, _HH = hypothesis0.size() | ||||
_BPL, _PLL = premise_len.size() | |||||
_HPL, _HLL = hypothesis_len.size() | |||||
_BPL, _PLL = seq_len1.size() | |||||
_HPL, _HLL = seq_len2.size() | |||||
assert _BP == _BH and _BPL == _HPL and _BP == _BPL | assert _BP == _BH and _BPL == _HPL and _BP == _BPL | ||||
assert _HP == _HH | assert _HP == _HH | ||||
@@ -84,7 +83,7 @@ class ESIM(BaseModel): | |||||
a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] | a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] | ||||
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] | b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] | ||||
ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len) | |||||
ai, bi = self.bi_attention(a, b, seq_len1, seq_len2) | |||||
ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] | ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] | ||||
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] | mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] | ||||
@@ -98,17 +97,18 @@ class ESIM(BaseModel): | |||||
va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] | va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] | ||||
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] | vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] | ||||
va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] | |||||
va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] | |||||
vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] | |||||
vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1) # vb_max: [B, H] | |||||
va_ave = self.mean_pooling(va, seq_len1, dim=1) # va_ave: [B, H] | |||||
va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1) # va_max: [B, H] | |||||
vb_ave = self.mean_pooling(vb, seq_len2, dim=1) # vb_ave: [B, H] | |||||
vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1) # vb_max: [B, H] | |||||
v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] | v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] | ||||
prediction = F.tanh(self.output(v)) # prediction: [B, N] | |||||
prediction = torch.tanh(self.output(v)) # prediction: [B, N] | |||||
return {'pred': prediction} | return {'pred': prediction} | ||||
def predict(self, premise, hypothesis, premise_len, hypothesis_len): | |||||
return self.forward(premise, hypothesis, premise_len, hypothesis_len) | |||||
def predict(self, words1, words2, seq_len1, seq_len2): | |||||
prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred'] | |||||
return torch.argmax(prediction, dim=-1) | |||||
@@ -5,6 +5,6 @@ from .avg_pool import MeanPoolWithMask | |||||
from .kmax_pool import KMaxPool | from .kmax_pool import KMaxPool | ||||
from .attention import Attention | from .attention import Attention | ||||
from .attention import Bi_Attention | |||||
from .attention import BiAttention | |||||
from .self_attention import SelfAttention | from .self_attention import SelfAttention | ||||
@@ -23,9 +23,9 @@ class Attention(torch.nn.Module): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
class DotAtte(nn.Module): | |||||
class DotAttention(nn.Module): | |||||
def __init__(self, key_size, value_size, dropout=0.1): | def __init__(self, key_size, value_size, dropout=0.1): | ||||
super(DotAtte, self).__init__() | |||||
super(DotAttention, self).__init__() | |||||
self.key_size = key_size | self.key_size = key_size | ||||
self.value_size = value_size | self.value_size = value_size | ||||
self.scale = math.sqrt(key_size) | self.scale = math.sqrt(key_size) | ||||
@@ -48,7 +48,7 @@ class DotAtte(nn.Module): | |||||
return torch.matmul(output, V) | return torch.matmul(output, V) | ||||
class MultiHeadAtte(nn.Module): | |||||
class MultiHeadAttention(nn.Module): | |||||
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | ||||
""" | """ | ||||
@@ -58,7 +58,7 @@ class MultiHeadAtte(nn.Module): | |||||
:param num_head: int,head的数量。 | :param num_head: int,head的数量。 | ||||
:param dropout: float。 | :param dropout: float。 | ||||
""" | """ | ||||
super(MultiHeadAtte, self).__init__() | |||||
super(MultiHeadAttention, self).__init__() | |||||
self.input_size = input_size | self.input_size = input_size | ||||
self.key_size = key_size | self.key_size = key_size | ||||
self.value_size = value_size | self.value_size = value_size | ||||
@@ -68,7 +68,7 @@ class MultiHeadAtte(nn.Module): | |||||
self.q_in = nn.Linear(input_size, in_size) | self.q_in = nn.Linear(input_size, in_size) | ||||
self.k_in = nn.Linear(input_size, in_size) | self.k_in = nn.Linear(input_size, in_size) | ||||
self.v_in = nn.Linear(input_size, in_size) | self.v_in = nn.Linear(input_size, in_size) | ||||
self.attention = DotAtte(key_size=key_size, value_size=value_size) | |||||
self.attention = DotAttention(key_size=key_size, value_size=value_size) | |||||
self.out = nn.Linear(value_size * num_head, input_size) | self.out = nn.Linear(value_size * num_head, input_size) | ||||
self.drop = TimestepDropout(dropout) | self.drop = TimestepDropout(dropout) | ||||
self.reset_parameters() | self.reset_parameters() | ||||
@@ -109,16 +109,34 @@ class MultiHeadAtte(nn.Module): | |||||
return output | return output | ||||
class Bi_Attention(nn.Module): | |||||
class BiAttention(nn.Module): | |||||
"""Bi Attention module | |||||
Calculate Bi Attention matrix `e` | |||||
.. math:: | |||||
\begin{array}{ll} \\ | |||||
e_ij = {a}^{\mathbf{T}}_{i}{b}_{j} \\ | |||||
a_i = | |||||
b_j = | |||||
\end{array} | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
super(Bi_Attention, self).__init__() | |||||
super(BiAttention, self).__init__() | |||||
self.inf = 10e12 | self.inf = 10e12 | ||||
def forward(self, in_x1, in_x2, x1_len, x2_len): | def forward(self, in_x1, in_x2, x1_len, x2_len): | ||||
# in_x1: [batch_size, x1_seq_len, hidden_size] | |||||
# in_x2: [batch_size, x2_seq_len, hidden_size] | |||||
# x1_len: [batch_size, x1_seq_len] | |||||
# x2_len: [batch_size, x2_seq_len] | |||||
""" | |||||
:param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] 第一句的特征表示 | |||||
:param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] 第二句的特征表示 | |||||
:param torch.Tensor x1_len: [batch_size, x1_seq_len] 第一句的0/1mask矩阵 | |||||
:param torch.Tensor x2_len: [batch_size, x2_seq_len] 第二句的0/1mask矩阵 | |||||
:return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] 第一句attend到的特征表示 | |||||
torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] 第一句attend到的特征表示 | |||||
""" | |||||
assert in_x1.size()[0] == in_x2.size()[0] | assert in_x1.size()[0] == in_x2.size()[0] | ||||
assert in_x1.size()[2] == in_x2.size()[2] | assert in_x1.size()[2] == in_x2.size()[2] | ||||
@@ -36,6 +36,7 @@ class MLP(nn.Module): | |||||
actives = { | actives = { | ||||
'relu': nn.ReLU(), | 'relu': nn.ReLU(), | ||||
'tanh': nn.Tanh(), | 'tanh': nn.Tanh(), | ||||
'sigmoid': nn.Sigmoid(), | |||||
} | } | ||||
if not isinstance(activation, list): | if not isinstance(activation, list): | ||||
activation = [activation] * (len(size_layer) - 2) | activation = [activation] * (len(size_layer) - 2) | ||||
@@ -1,6 +1,6 @@ | |||||
from torch import nn | from torch import nn | ||||
from ..aggregator.attention import MultiHeadAtte | |||||
from ..aggregator.attention import MultiHeadAttention | |||||
from ..dropout import TimestepDropout | from ..dropout import TimestepDropout | ||||
@@ -18,7 +18,7 @@ class TransformerEncoder(nn.Module): | |||||
class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | ||||
super(TransformerEncoder.SubLayer, self).__init__() | super(TransformerEncoder.SubLayer, self).__init__() | ||||
self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | |||||
self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout) | |||||
self.norm1 = nn.LayerNorm(model_size) | self.norm1 = nn.LayerNorm(model_size) | ||||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | ||||
nn.ReLU(), | nn.ReLU(), | ||||
@@ -8,7 +8,7 @@ | |||||
## Star-Transformer | ## Star-Transformer | ||||
[reference](https://arxiv.org/abs/1902.09113) | [reference](https://arxiv.org/abs/1902.09113) | ||||
### Performance | |||||
### Performance (still in progress) | |||||
|任务| 数据集 | SOTA | 模型表现 | | |任务| 数据集 | SOTA | 模型表现 | | ||||
|------|------| ------| ------| | |------|------| ------| ------| | ||||
|Pos Tagging|CTB 9.0|-|ACC 92.31| | |Pos Tagging|CTB 9.0|-|ACC 92.31| | ||||
@@ -13,12 +13,12 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
setup( | setup( | ||||
name='FastNLP', | name='FastNLP', | ||||
version='0.1.1', | |||||
version='0.4.0', | |||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
long_description=readme, | long_description=readme, | ||||
license=license, | license=license, | ||||
author='FudanNLP', | author='FudanNLP', | ||||
python_requires='>=3.5', | |||||
python_requires='>=3.6', | |||||
packages=find_packages(), | packages=find_packages(), | ||||
install_requires=reqs.strip().split('\n'), | install_requires=reqs.strip().split('\n'), | ||||
) | ) |
@@ -35,7 +35,7 @@ class TestENAS(unittest.TestCase): | |||||
print(dataset[0]) | print(dataset[0]) | ||||
# DataSet.drop(func)筛除数据 | # DataSet.drop(func)筛除数据 | ||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True) | |||||
print(len(dataset)) | print(len(dataset)) | ||||
# 设置DataSet中,哪些field要转为tensor | # 设置DataSet中,哪些field要转为tensor | ||||
@@ -125,7 +125,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
def test_drop(self): | def test_drop(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ||||
ds.drop(lambda ins: len(ins["y"]) < 3) | |||||
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | |||||
self.assertEqual(len(ds), 20) | self.assertEqual(len(ds), 20) | ||||
def test_contains(self): | def test_contains(self): | ||||
@@ -169,7 +169,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | ||||
sep='\t') | sep='\t') | ||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0) | |||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | |||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | dataset.apply(split_sent, new_field_name='words', is_input=True) | ||||
# print(dataset) | # print(dataset) | ||||
@@ -202,11 +202,11 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertTrue(isinstance(ans, FieldArray)) | self.assertTrue(isinstance(ans, FieldArray)) | ||||
self.assertEqual(ans.content, [[5, 6]] * 10) | self.assertEqual(ans.content, [[5, 6]] * 10) | ||||
# def test_add_null(self): | |||||
# # TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||||
# ds = DataSet() | |||||
# ds.add_field('test', []) | |||||
# ds.set_target('test') | |||||
def test_add_null(self): | |||||
# TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||||
ds = DataSet() | |||||
with self.assertRaises(RuntimeError) as RE: | |||||
ds.add_field('test', []) | |||||
class TestDataSetIter(unittest.TestCase): | class TestDataSetIter(unittest.TestCase): | ||||
@@ -15,7 +15,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
def test_AccuracyMetric2(self): | def test_AccuracyMetric2(self): | ||||
@@ -30,7 +30,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
except Exception as e: | except Exception as e: | ||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | |||||
print("No exception catches.") | |||||
def test_AccuracyMetric3(self): | def test_AccuracyMetric3(self): | ||||
# (3) the second batch is corrupted size | # (3) the second batch is corrupted size | ||||
@@ -95,10 +95,9 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
self.assertAlmostEqual(res["acc"], float(ans), places=4) | self.assertAlmostEqual(res["acc"], float(ans), places=4) | ||||
def test_AccuaryMetric8(self): | def test_AccuaryMetric8(self): | ||||
# (8) check map, does not match. use stop_fast_param to stop fast param map | |||||
try: | try: | ||||
metric = AccuracyMetric(pred='predictions', target='targets') | metric = AccuracyMetric(pred='predictions', target='targets') | ||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1} | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2)} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | target_dict = {'targets': torch.zeros(4, 3)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | metric(pred_dict=pred_dict, target_dict=target_dict, ) | ||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | self.assertDictEqual(metric.get_metric(), {'acc': 1}) | ||||
@@ -0,0 +1,115 @@ | |||||
import unittest | |||||
import _pickle | |||||
from fastNLP import cache_results | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
import time | |||||
import os | |||||
@cache_results('test/demo1.pkl') | |||||
def process_data_1(embed_file, cws_train): | |||||
embed, vocab = EmbedLoader.load_without_vocab(embed_file) | |||||
time.sleep(1) # 测试是否通过读取cache获得结果 | |||||
with open(cws_train, 'r', encoding='utf-8') as f: | |||||
d = DataSet() | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line)>0: | |||||
d.append(Instance(raw=line)) | |||||
return embed, vocab, d | |||||
class TestCache(unittest.TestCase): | |||||
def test_cache_save(self): | |||||
try: | |||||
start_time = time.time() | |||||
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train') | |||||
end_time = time.time() | |||||
pre_time = end_time - start_time | |||||
with open('test/demo1.pkl', 'rb') as f: | |||||
_embed, _vocab, _d = _pickle.load(f) | |||||
self.assertEqual(embed.shape, _embed.shape) | |||||
for i in range(embed.shape[0]): | |||||
self.assertListEqual(embed[i].tolist(), _embed[i].tolist()) | |||||
start_time = time.time() | |||||
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train') | |||||
end_time = time.time() | |||||
read_time = end_time - start_time | |||||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||||
self.assertGreater(pre_time-0.5, read_time) | |||||
finally: | |||||
os.remove('test/demo1.pkl') | |||||
def test_cache_save_overwrite_path(self): | |||||
try: | |||||
start_time = time.time() | |||||
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train', | |||||
cache_filepath='test/demo_overwrite.pkl') | |||||
end_time = time.time() | |||||
pre_time = end_time - start_time | |||||
with open('test/demo_overwrite.pkl', 'rb') as f: | |||||
_embed, _vocab, _d = _pickle.load(f) | |||||
self.assertEqual(embed.shape, _embed.shape) | |||||
for i in range(embed.shape[0]): | |||||
self.assertListEqual(embed[i].tolist(), _embed[i].tolist()) | |||||
start_time = time.time() | |||||
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train', | |||||
cache_filepath='test/demo_overwrite.pkl') | |||||
end_time = time.time() | |||||
read_time = end_time - start_time | |||||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||||
self.assertGreater(pre_time-0.5, read_time) | |||||
finally: | |||||
os.remove('test/demo_overwrite.pkl') | |||||
def test_cache_refresh(self): | |||||
try: | |||||
start_time = time.time() | |||||
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train', | |||||
refresh=True) | |||||
end_time = time.time() | |||||
pre_time = end_time - start_time | |||||
with open('test/demo1.pkl', 'rb') as f: | |||||
_embed, _vocab, _d = _pickle.load(f) | |||||
self.assertEqual(embed.shape, _embed.shape) | |||||
for i in range(embed.shape[0]): | |||||
self.assertListEqual(embed[i].tolist(), _embed[i].tolist()) | |||||
start_time = time.time() | |||||
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train', | |||||
refresh=True) | |||||
end_time = time.time() | |||||
read_time = end_time - start_time | |||||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||||
self.assertGreater(0.1, pre_time-read_time) | |||||
finally: | |||||
os.remove('test/demo1.pkl') | |||||
def test_duplicate_keyword(self): | |||||
with self.assertRaises(RuntimeError): | |||||
@cache_results(None) | |||||
def func_verbose(a, verbose): | |||||
pass | |||||
func_verbose(0, 1) | |||||
with self.assertRaises(RuntimeError): | |||||
@cache_results(None) | |||||
def func_cache(a, cache_filepath): | |||||
pass | |||||
func_cache(1, 2) | |||||
with self.assertRaises(RuntimeError): | |||||
@cache_results(None) | |||||
def func_refresh(a, refresh): | |||||
pass | |||||
func_refresh(1, 2) | |||||
def test_create_cache_dir(self): | |||||
@cache_results('test/demo1/demo.pkl') | |||||
def cache(): | |||||
return 1, 2 | |||||
try: | |||||
results = cache() | |||||
print(results) | |||||
finally: | |||||
os.remove('test/demo1/demo.pkl') | |||||
os.rmdir('test/demo1') |
@@ -2,6 +2,8 @@ import unittest | |||||
from collections import Counter | from collections import Counter | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | ||||
"works", "well", "in", "most", "cases", "scales", "well"] | "works", "well", "in", "most", "cases", "scales", "well"] | ||||
@@ -31,6 +33,42 @@ class TestAdd(unittest.TestCase): | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_from_dataset(self): | |||||
start_char = 65 | |||||
num_samples = 10 | |||||
# 0 dim | |||||
dataset = DataSet() | |||||
for i in range(num_samples): | |||||
ins = Instance(char=chr(start_char+i)) | |||||
dataset.append(ins) | |||||
vocab = Vocabulary() | |||||
vocab.from_dataset(dataset, field_name='char') | |||||
for i in range(num_samples): | |||||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||||
vocab.index_dataset(dataset, field_name='char') | |||||
# 1 dim | |||||
dataset = DataSet() | |||||
for i in range(num_samples): | |||||
ins = Instance(char=[chr(start_char+i)]*6) | |||||
dataset.append(ins) | |||||
vocab = Vocabulary() | |||||
vocab.from_dataset(dataset, field_name='char') | |||||
for i in range(num_samples): | |||||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||||
vocab.index_dataset(dataset, field_name='char') | |||||
# 2 dim | |||||
dataset = DataSet() | |||||
for i in range(num_samples): | |||||
ins = Instance(char=[[chr(start_char+i) for _ in range(6)] for _ in range(6)]) | |||||
dataset.append(ins) | |||||
vocab = Vocabulary() | |||||
vocab.from_dataset(dataset, field_name='char') | |||||
for i in range(num_samples): | |||||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||||
vocab.index_dataset(dataset, field_name='char') | |||||
class TestIndexing(unittest.TestCase): | class TestIndexing(unittest.TestCase): | ||||
def test_len(self): | def test_len(self): | ||||
@@ -6,7 +6,7 @@ from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver | |||||
class TestConfigSaver(unittest.TestCase): | class TestConfigSaver(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
config_file_dir = "test/io/" | |||||
config_file_dir = "test/io" | |||||
config_file_name = "config" | config_file_name = "config" | ||||
config_file_path = os.path.join(config_file_dir, config_file_name) | config_file_path = os.path.join(config_file_dir, config_file_name) | ||||
@@ -1,4 +1,5 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.io.embed_loader import EmbedLoader | from fastNLP.io.embed_loader import EmbedLoader | ||||
@@ -10,3 +11,35 @@ class TestEmbedLoader(unittest.TestCase): | |||||
vocab.update(["the", "in", "I", "to", "of", "hahaha"]) | vocab.update(["the", "in", "I", "to", "of", "hahaha"]) | ||||
embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab) | embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab) | ||||
self.assertEqual(tuple(embedding.shape), (len(vocab), 50)) | self.assertEqual(tuple(embedding.shape), (len(vocab), 50)) | ||||
def test_load_with_vocab(self): | |||||
vocab = Vocabulary() | |||||
glove = "test/data_for_tests/glove.6B.50d_test.txt" | |||||
word2vec = "test/data_for_tests/word2vec_test.txt" | |||||
vocab.add_word('the') | |||||
vocab.add_word('none') | |||||
g_m = EmbedLoader.load_with_vocab(glove, vocab) | |||||
self.assertEqual(g_m.shape, (4, 50)) | |||||
w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True) | |||||
self.assertEqual(w_m.shape, (4, 50)) | |||||
self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4) | |||||
def test_load_without_vocab(self): | |||||
words = ['the', 'of', 'in', 'a', 'to', 'and'] | |||||
glove = "test/data_for_tests/glove.6B.50d_test.txt" | |||||
word2vec = "test/data_for_tests/word2vec_test.txt" | |||||
g_m, vocab = EmbedLoader.load_without_vocab(glove) | |||||
self.assertEqual(g_m.shape, (8, 50)) | |||||
for word in words: | |||||
self.assertIn(word, vocab) | |||||
w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True) | |||||
self.assertEqual(w_m.shape, (8, 50)) | |||||
self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8) | |||||
for word in words: | |||||
self.assertIn(word, vocab) | |||||
# no unk | |||||
w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None) | |||||
self.assertEqual(w_m.shape, (7, 50)) | |||||
self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7) | |||||
for word in words: | |||||
self.assertIn(word, vocab) |
@@ -118,7 +118,7 @@ class TestCRF(unittest.TestCase): | |||||
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) | feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) | ||||
crf = ConditionalRandomField(num_tags, include_start_end_trans) | crf = ConditionalRandomField(num_tags, include_start_end_trans) | ||||
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) | optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) | ||||
for _ in range(10000): | |||||
for _ in range(10): | |||||
loss = crf(feats, tags, masks).mean() | loss = crf(feats, tags, masks).mean() | ||||
optimizer.zero_grad() | optimizer.zero_grad() | ||||
loss.backward() | loss.backward() | ||||
@@ -1,9 +0,0 @@ | |||||
import unittest | |||||
class TestUtils(unittest.TestCase): | |||||
def test_case_1(self): | |||||
pass | |||||
def test_case_2(self): | |||||
pass |
@@ -35,7 +35,7 @@ class TestTutorial(unittest.TestCase): | |||||
print(dataset[0]) | print(dataset[0]) | ||||
# DataSet.drop(func)筛除数据 | # DataSet.drop(func)筛除数据 | ||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True) | |||||
print(len(dataset)) | print(len(dataset)) | ||||
# 设置DataSet中,哪些field要转为tensor | # 设置DataSet中,哪些field要转为tensor | ||||
@@ -152,7 +152,7 @@ class TestTutorial(unittest.TestCase): | |||||
train_data=train_data, | train_data=train_data, | ||||
dev_data=dev_data, | dev_data=dev_data, | ||||
loss=CrossEntropyLoss(), | loss=CrossEntropyLoss(), | ||||
metrics=AccuracyMetric() | |||||
metrics=AccuracyMetric(target='label_seq') | |||||
) | ) | ||||
trainer.train() | trainer.train() | ||||
print('Train finished!') | print('Train finished!') | ||||
@@ -296,7 +296,7 @@ class TestTutorial(unittest.TestCase): | |||||
# 筛选数据 | # 筛选数据 | ||||
origin_data_set_len = len(data_set) | origin_data_set_len = len(data_set) | ||||
data_set.drop(lambda x: len(x['premise']) <= 6) | |||||
data_set.drop(lambda x: len(x['premise']) <= 6, inplace=True) | |||||
origin_data_set_len, len(data_set) | origin_data_set_len, len(data_set) | ||||
# In[17]: | # In[17]: | ||||
@@ -407,7 +407,7 @@ class TestTutorial(unittest.TestCase): | |||||
train_data=train_data, | train_data=train_data, | ||||
model=model, | model=model, | ||||
loss=CrossEntropyLoss(pred='pred', target='label'), | loss=CrossEntropyLoss(pred='pred', target='label'), | ||||
metrics=AccuracyMetric(), | |||||
metrics=AccuracyMetric(target='label'), | |||||
n_epochs=3, | n_epochs=3, | ||||
batch_size=16, | batch_size=16, | ||||
print_every=-1, | print_every=-1, | ||||
@@ -424,7 +424,7 @@ class TestTutorial(unittest.TestCase): | |||||
tester = Tester( | tester = Tester( | ||||
data=test_data, | data=test_data, | ||||
model=model, | model=model, | ||||
metrics=AccuracyMetric(), | |||||
metrics=AccuracyMetric(target='label'), | |||||
batch_size=args["batch_size"], | batch_size=args["batch_size"], | ||||
) | ) | ||||
tester.test() | tester.test() | ||||
@@ -20,16 +20,7 @@ | |||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 1, | "execution_count": 1, | ||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"/remote-home/ygxu/anaconda3/envs/no-fastnlp/lib/python3.7/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | |||||
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n" | |||||
] | |||||
} | |||||
], | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 声明部件\n", | "# 声明部件\n", | ||||
"import torch\n", | "import torch\n", | ||||
@@ -179,11 +170,11 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"DataSet({'image': tensor([[ 2.1747, -1.0147, -1.3853, 0.0216, -0.4957],\n", | |||||
" [ 0.8138, -0.2933, -0.1217, -0.6027, 0.3932],\n", | |||||
" [ 0.6750, -1.1136, -1.3371, -0.0185, -0.3206],\n", | |||||
" [-0.5076, -0.3822, 0.1719, -0.6447, -0.5702],\n", | |||||
" [ 0.3804, 0.0889, 0.8027, -0.7121, -0.7320]]) type=torch.Tensor,\n", | |||||
"DataSet({'image': tensor([[ 4.7106e-01, -1.2246e+00, 3.1234e-01, -1.6781e+00, -8.7967e-01],\n", | |||||
" [ 1.1454e+00, 1.2236e-01, 3.0258e-01, -1.5454e+00, 8.9201e-01],\n", | |||||
" [-5.7143e-03, 3.9488e-01, 2.0287e-01, -1.5726e+00, 9.3171e-01],\n", | |||||
" [ 6.8914e-01, -2.6302e-01, -8.2694e-01, 9.5942e-01, -5.2589e-01],\n", | |||||
" [-5.7798e-03, -9.1621e-03, 1.0077e-03, 9.1716e-02, 1.0565e+00]]) type=torch.Tensor,\n", | |||||
"'label': 0 type=int})" | "'label': 0 type=int})" | ||||
] | ] | ||||
}, | }, | ||||
@@ -644,20 +635,20 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"({'premise': [2, 145, 146, 80, 147, 26, 148, 2, 104, 149, 150, 2, 151, 5, 55, 152, 105, 3] type=list,\n", | |||||
" 'hypothesis': [22, 80, 8, 1, 1, 20, 1, 3] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'label': 2 type=int},\n", | |||||
" {'premise': [11, 5, 18, 5, 24, 6, 2, 10, 59, 52, 14, 9, 2, 53, 29, 60, 54, 45, 6, 46, 5, 7, 61, 3] type=list,\n", | |||||
" 'hypothesis': [22, 11, 1, 45, 3] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1] type=list,\n", | |||||
"({'premise': [2, 10, 9, 2, 15, 115, 6, 11, 5, 132, 17, 2, 76, 9, 77, 55, 3] type=list,\n", | |||||
" 'hypothesis': [1, 2, 56, 17, 1, 4, 13, 49, 123, 12, 6, 11, 3] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'label': 0 type=int},\n", | |||||
" {'premise': [50, 124, 10, 7, 68, 91, 92, 38, 2, 55, 3] type=list,\n", | |||||
" 'hypothesis': [21, 10, 5, 2, 55, 7, 99, 64, 48, 1, 22, 1, 3] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'label': 1 type=int},\n", | " 'label': 1 type=int},\n", | ||||
" {'premise': [2, 11, 8, 14, 16, 7, 15, 50, 2, 66, 4, 76, 2, 10, 8, 98, 9, 58, 67, 3] type=list,\n", | |||||
" 'hypothesis': [22, 27, 50, 3] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1] type=list,\n", | |||||
" {'premise': [13, 24, 4, 14, 29, 5, 25, 4, 8, 39, 9, 14, 34, 4, 40, 41, 4, 16, 12, 2, 11, 4, 30, 28, 2, 42, 8, 2, 43, 44, 17, 2, 45, 35, 26, 31, 27, 5, 6, 32, 3] type=list,\n", | |||||
" 'hypothesis': [37, 49, 123, 30, 28, 2, 55, 12, 2, 11, 3] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'label': 0 type=int})" | " 'label': 0 type=int})" | ||||
] | ] | ||||
}, | }, | ||||
@@ -718,15 +709,15 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"({'premise': [1037, 2210, 2223, 2136, 5363, 2000, 4608, 1037, 5479, 8058, 2046, 1037, 2918, 1999, 2019, 5027, 2208, 1012] type=list,\n", | |||||
" 'hypothesis': [100, 2136, 2003, 2652, 3598, 2006, 100, 1012] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'label': 2 type=int},\n", | |||||
" {'premise': [2450, 1999, 2317, 1999, 100, 1998, 1037, 2158, 3621, 2369, 3788, 2007, 1037, 3696, 2005, 2198, 100, 10733, 1998, 100, 1999, 1996, 4281, 1012] type=list,\n", | |||||
" 'hypothesis': [100, 2450, 13063, 10733, 1012] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1] type=list,\n", | |||||
"({'premise': [1037, 2158, 1998, 1037, 2450, 2892, 1996, 2395, 1999, 2392, 1997, 1037, 10733, 1998, 100, 4825, 1012] type=list,\n", | |||||
" 'hypothesis': [100, 1037, 3232, 1997, 7884, 1010, 2048, 2111, 3328, 2408, 1996, 2395, 1012] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'label': 0 type=int},\n", | |||||
" {'premise': [2019, 3080, 2158, 2003, 5948, 4589, 10869, 2012, 1037, 4825, 1012] type=list,\n", | |||||
" 'hypothesis': [100, 2158, 1999, 1037, 4825, 2003, 3403, 2005, 2010, 7954, 2000, 7180, 1012] type=list,\n", | |||||
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n", | |||||
" 'label': 1 type=int})" | " 'label': 1 type=int})" | ||||
] | ] | ||||
}, | }, | ||||
@@ -769,7 +760,7 @@ | |||||
" 'num_classes': 3,\n", | " 'num_classes': 3,\n", | ||||
" 'gpu': True,\n", | " 'gpu': True,\n", | ||||
" 'batch_size': 32,\n", | " 'batch_size': 32,\n", | ||||
" 'vocab_size': 165}" | |||||
" 'vocab_size': 156}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 26, | "execution_count": 26, | ||||
@@ -797,7 +788,7 @@ | |||||
"ESIM(\n", | "ESIM(\n", | ||||
" (drop): Dropout(p=0.3)\n", | " (drop): Dropout(p=0.3)\n", | ||||
" (embedding): Embedding(\n", | " (embedding): Embedding(\n", | ||||
" (embed): Embedding(165, 300, padding_idx=0)\n", | |||||
" (embed): Embedding(156, 300, padding_idx=0)\n", | |||||
" (dropout): Dropout(p=0.3)\n", | " (dropout): Dropout(p=0.3)\n", | ||||
" )\n", | " )\n", | ||||
" (embedding_layer): Linear(\n", | " (embedding_layer): Linear(\n", | ||||
@@ -821,7 +812,6 @@ | |||||
" )\n", | " )\n", | ||||
" (output): Linear(in_features=300, out_features=3, bias=True)\n", | " (output): Linear(in_features=300, out_features=3, bias=True)\n", | ||||
" (dropout): Dropout(p=0.3)\n", | " (dropout): Dropout(p=0.3)\n", | ||||
" (hidden_active): Tanh()\n", | |||||
" )\n", | " )\n", | ||||
")" | ")" | ||||
] | ] | ||||
@@ -848,7 +838,7 @@ | |||||
"text/plain": [ | "text/plain": [ | ||||
"CNNText(\n", | "CNNText(\n", | ||||
" (embed): Embedding(\n", | " (embed): Embedding(\n", | ||||
" (embed): Embedding(165, 50, padding_idx=0)\n", | |||||
" (embed): Embedding(156, 50, padding_idx=0)\n", | |||||
" (dropout): Dropout(p=0.0)\n", | " (dropout): Dropout(p=0.0)\n", | ||||
" )\n", | " )\n", | ||||
" (conv_pool): ConvMaxpool(\n", | " (conv_pool): ConvMaxpool(\n", | ||||
@@ -1019,43 +1009,49 @@ | |||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"training epochs started 2019-01-09 00-08-17\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.206897\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"/remote-home/ygxu/anaconda3/envs/no-fastnlp/lib/python3.7/site-packages/torch/nn/functional.py:1320: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n", | |||||
" warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.206897\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.206897\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.206897\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.206897\n", | |||||
"training epochs started 2019-04-14-23-22-28\n", | |||||
"[epoch: 1 step: 1] train loss: 1.51372 time: 0:00:00\n", | |||||
"[epoch: 1 step: 2] train loss: 1.26874 time: 0:00:00\n", | |||||
"[epoch: 1 step: 3] train loss: 1.49786 time: 0:00:00\n", | |||||
"[epoch: 1 step: 4] train loss: 1.37505 time: 0:00:00\n", | |||||
"Evaluation at Epoch 1/5. Step:4/20. AccuracyMetric: acc=0.344828\n", | |||||
"\n", | |||||
"[epoch: 2 step: 5] train loss: 1.21877 time: 0:00:00\n", | |||||
"[epoch: 2 step: 6] train loss: 1.14183 time: 0:00:00\n", | |||||
"[epoch: 2 step: 7] train loss: 1.15934 time: 0:00:00\n", | |||||
"[epoch: 2 step: 8] train loss: 1.55148 time: 0:00:00\n", | |||||
"Evaluation at Epoch 2/5. Step:8/20. AccuracyMetric: acc=0.344828\n", | |||||
"\n", | "\n", | ||||
"In Epoch:1/Step:4, got best dev performance:AccuracyMetric: acc=0.206897\n", | |||||
"[epoch: 3 step: 9] train loss: 1.1457 time: 0:00:00\n", | |||||
"[epoch: 3 step: 10] train loss: 1.0547 time: 0:00:00\n", | |||||
"[epoch: 3 step: 11] train loss: 1.40139 time: 0:00:00\n", | |||||
"[epoch: 3 step: 12] train loss: 0.551445 time: 0:00:00\n", | |||||
"Evaluation at Epoch 3/5. Step:12/20. AccuracyMetric: acc=0.275862\n", | |||||
"\n", | |||||
"[epoch: 4 step: 13] train loss: 1.07965 time: 0:00:00\n", | |||||
"[epoch: 4 step: 14] train loss: 1.04118 time: 0:00:00\n", | |||||
"[epoch: 4 step: 15] train loss: 1.11719 time: 0:00:00\n", | |||||
"[epoch: 4 step: 16] train loss: 1.09861 time: 0:00:00\n", | |||||
"Evaluation at Epoch 4/5. Step:16/20. AccuracyMetric: acc=0.275862\n", | |||||
"\n", | |||||
"[epoch: 5 step: 17] train loss: 1.10795 time: 0:00:00\n", | |||||
"[epoch: 5 step: 18] train loss: 1.26715 time: 0:00:00\n", | |||||
"[epoch: 5 step: 19] train loss: 1.19875 time: 0:00:00\n", | |||||
"[epoch: 5 step: 20] train loss: 1.09862 time: 0:00:00\n", | |||||
"Evaluation at Epoch 5/5. Step:20/20. AccuracyMetric: acc=0.37931\n", | |||||
"\n", | |||||
"\n", | |||||
"In Epoch:5/Step:20, got best dev performance:AccuracyMetric: acc=0.37931\n", | |||||
"Reloaded the best model.\n" | "Reloaded the best model.\n" | ||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'best_eval': {'AccuracyMetric': {'acc': 0.206897}},\n", | |||||
" 'best_epoch': 1,\n", | |||||
" 'best_step': 4,\n", | |||||
" 'seconds': 0.79}" | |||||
"{'best_eval': {'AccuracyMetric': {'acc': 0.37931}},\n", | |||||
" 'best_epoch': 5,\n", | |||||
" 'best_step': 20,\n", | |||||
" 'seconds': 0.5}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 29, | "execution_count": 29, | ||||
@@ -1070,8 +1066,8 @@ | |||||
"trainer = Trainer(\n", | "trainer = Trainer(\n", | ||||
" train_data=train_data,\n", | " train_data=train_data,\n", | ||||
" model=model,\n", | " model=model,\n", | ||||
" loss=CrossEntropyLoss(pred='pred', target='label'),\n", | |||||
" metrics=AccuracyMetric(),\n", | |||||
" loss=CrossEntropyLoss(pred='pred', target='label'), # 模型预测值通过'pred'来取得,目标值(ground truth)由'label'取得\n", | |||||
" metrics=AccuracyMetric(target='label'), # 目标值(ground truth)由'label'取得\n", | |||||
" n_epochs=5,\n", | " n_epochs=5,\n", | ||||
" batch_size=16,\n", | " batch_size=16,\n", | ||||
" print_every=-1,\n", | " print_every=-1,\n", | ||||
@@ -1113,13 +1109,13 @@ | |||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"[tester] \n", | "[tester] \n", | ||||
"AccuracyMetric: acc=0.263158\n" | |||||
"AccuracyMetric: acc=0.368421\n" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'AccuracyMetric': {'acc': 0.263158}}" | |||||
"{'AccuracyMetric': {'acc': 0.368421}}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 30, | "execution_count": 30, | ||||
@@ -1131,12 +1127,33 @@ | |||||
"tester = Tester(\n", | "tester = Tester(\n", | ||||
" data=test_data,\n", | " data=test_data,\n", | ||||
" model=model,\n", | " model=model,\n", | ||||
" metrics=AccuracyMetric(),\n", | |||||
" metrics=AccuracyMetric(target='label'),\n", | |||||
" batch_size=args[\"batch_size\"],\n", | " batch_size=args[\"batch_size\"],\n", | ||||
")\n", | ")\n", | ||||
"tester.test()" | "tester.test()" | ||||
] | ] | ||||
}, | }, | ||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": null, | "execution_count": null, | ||||
@@ -1161,7 +1178,7 @@ | |||||
"name": "python", | "name": "python", | ||||
"nbconvert_exporter": "python", | "nbconvert_exporter": "python", | ||||
"pygments_lexer": "ipython3", | "pygments_lexer": "ipython3", | ||||
"version": "3.6.7" | |||||
"version": "3.7.0" | |||||
} | } | ||||
}, | }, | ||||
"nbformat": 4, | "nbformat": 4, | ||||