
conflict merge

tags/v0.4.10
yh, 6 years ago
parent commit 7d94d50c14
31 changed files with 685 additions and 375 deletions
  1. docs/Makefile (+4, -0)
  2. docs/source/conf.py (+3, -3)
  3. docs/source/fastNLP.api.rst (+39, -13)
  4. docs/source/fastNLP.core.rst (+70, -28)
  5. docs/source/fastNLP.io.rst (+30, -18)
  6. docs/source/fastNLP.models.rst (+82, -14)
  7. docs/source/fastNLP.modules.aggregator.rst (+30, -12)
  8. docs/source/fastNLP.modules.decoder.rst (+18, -6)
  9. docs/source/fastNLP.modules.encoder.rst (+54, -20)
  10. docs/source/fastNLP.modules.rst (+25, -8)
  11. docs/source/fastNLP.rst (+11, -2)
  12. fastNLP/api/__init__.py (+3, -0)
  13. fastNLP/api/api.py (+15, -11)
  14. fastNLP/automl/enas_trainer.py (+8, -7)
  15. fastNLP/core/dataset.py (+1, -0)
  16. fastNLP/core/fieldarray.py (+10, -6)
  17. fastNLP/core/instance.py (+7, -8)
  18. fastNLP/core/losses.py (+2, -2)
  19. fastNLP/core/metrics.py (+91, -74)
  20. fastNLP/core/trainer.py (+71, -68)
  21. fastNLP/core/utils.py (+7, -2)
  22. fastNLP/io/config_io.py (+4, -4)
  23. fastNLP/io/dataset_loader.py (+27, -20)
  24. fastNLP/io/embed_loader.py (+1, -1)
  25. fastNLP/io/model_io.py (+7, -5)
  26. fastNLP/models/char_language_model.py (+10, -3)
  27. fastNLP/models/enas_trainer.py (+8, -7)
  28. fastNLP/models/snli.py (+19, -19)
  29. fastNLP/modules/aggregator/__init__.py (+1, -1)
  30. fastNLP/modules/aggregator/attention.py (+25, -11)
  31. fastNLP/modules/encoder/transformer.py (+2, -2)

docs/Makefile (+4, -0)

@@ -3,6 +3,7 @@
 # You can set these variables from the command line.
 SPHINXOPTS    =
+SPHINXAPIDOC  = sphinx-apidoc
 SPHINXBUILD   = sphinx-build
 SPHINXPROJ    = fastNLP
 SOURCEDIR     = source
@@ -12,6 +13,9 @@ BUILDDIR      = build
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

+apidoc:
+	@$(SPHINXAPIDOC) -f -o source ../fastNLP
+
 .PHONY: help Makefile

 # Catch-all target: route all unknown targets to Sphinx using the new


docs/source/conf.py (+3, -3)

@@ -23,9 +23,9 @@ copyright = '2018, xpqiu'
 author = 'xpqiu'

 # The short X.Y version
-version = '0.2'
+version = '0.4'
 # The full version, including alpha/beta/rc tags
-release = '0.2'
+release = '0.4'


 # -- General configuration ---------------------------------------------------
@@ -67,7 +67,7 @@ language = None
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path .
-exclude_patterns = []
+exclude_patterns = ['modules.rst']

 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = 'sphinx'


docs/source/fastNLP.api.rst (+39, -13)

@@ -1,36 +1,62 @@
-fastNLP.api
-============
+fastNLP.api package
+===================
+
+Submodules
+----------

-fastNLP.api.api
-----------------
+fastNLP.api.api module
+----------------------

 .. automodule:: fastNLP.api.api
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.api.converter
-----------------------
+fastNLP.api.converter module
+----------------------------

 .. automodule:: fastNLP.api.converter
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.api.model\_zoo
------------------------
+fastNLP.api.examples module
+---------------------------

-.. automodule:: fastNLP.api.model_zoo
+.. automodule:: fastNLP.api.examples
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.api.pipeline
---------------------
+fastNLP.api.pipeline module
+---------------------------

 .. automodule:: fastNLP.api.pipeline
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.api.processor
----------------------
+fastNLP.api.processor module
+----------------------------

 .. automodule:: fastNLP.api.processor
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.api.utils module
+------------------------
+
+.. automodule:: fastNLP.api.utils
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.api
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.core.rst (+70, -28)

@@ -1,84 +1,126 @@
-fastNLP.core
-=============
+fastNLP.core package
+====================
+
+Submodules
+----------

-fastNLP.core.batch
--------------------
+fastNLP.core.batch module
+-------------------------

 .. automodule:: fastNLP.core.batch
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.core.callback module
+----------------------------
+
+.. automodule:: fastNLP.core.callback
+    :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.dataset
---------------------
+fastNLP.core.dataset module
+---------------------------

 .. automodule:: fastNLP.core.dataset
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.fieldarray
------------------------
+fastNLP.core.fieldarray module
+------------------------------

 .. automodule:: fastNLP.core.fieldarray
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.instance
----------------------
+fastNLP.core.instance module
+----------------------------

 .. automodule:: fastNLP.core.instance
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.losses
--------------------
+fastNLP.core.losses module
+--------------------------

 .. automodule:: fastNLP.core.losses
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.metrics
---------------------
+fastNLP.core.metrics module
+---------------------------

 .. automodule:: fastNLP.core.metrics
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.optimizer
-----------------------
+fastNLP.core.optimizer module
+-----------------------------

 .. automodule:: fastNLP.core.optimizer
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.predictor
-----------------------
+fastNLP.core.predictor module
+-----------------------------

 .. automodule:: fastNLP.core.predictor
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.sampler
---------------------
+fastNLP.core.sampler module
+---------------------------

 .. automodule:: fastNLP.core.sampler
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.tester
--------------------
+fastNLP.core.tester module
+--------------------------

 .. automodule:: fastNLP.core.tester
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.trainer
---------------------
+fastNLP.core.trainer module
+---------------------------

 .. automodule:: fastNLP.core.trainer
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.utils
-------------------
+fastNLP.core.utils module
+-------------------------

 .. automodule:: fastNLP.core.utils
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.core.vocabulary
------------------------
+fastNLP.core.vocabulary module
+------------------------------

 .. automodule:: fastNLP.core.vocabulary
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.core
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.io.rst (+30, -18)

@@ -1,42 +1,54 @@
-fastNLP.io
-===========
+fastNLP.io package
+==================
+
+Submodules
+----------

-fastNLP.io.base\_loader
------------------------
+fastNLP.io.base\_loader module
+------------------------------

 .. automodule:: fastNLP.io.base_loader
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.io.config\_io
----------------------
+fastNLP.io.config\_io module
+----------------------------

 .. automodule:: fastNLP.io.config_io
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.io.dataset\_loader
---------------------------
+fastNLP.io.dataset\_loader module
+---------------------------------

 .. automodule:: fastNLP.io.dataset_loader
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.io.embed\_loader
-------------------------
+fastNLP.io.embed\_loader module
+-------------------------------

 .. automodule:: fastNLP.io.embed_loader
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.io.logger
------------------
-
-.. automodule:: fastNLP.io.logger
-    :members:
-
-fastNLP.io.model\_io
---------------------
+fastNLP.io.model\_io module
+---------------------------

 .. automodule:: fastNLP.io.model_io
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.io
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.models.rst (+82, -14)

@@ -1,42 +1,110 @@
-fastNLP.models
-===============
+fastNLP.models package
+======================
+
+Submodules
+----------

-fastNLP.models.base\_model
---------------------------
+fastNLP.models.base\_model module
+---------------------------------

 .. automodule:: fastNLP.models.base_model
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.bert module
+--------------------------
+
+.. automodule:: fastNLP.models.bert
+    :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.models.biaffine\_parser
--------------------------------
+fastNLP.models.biaffine\_parser module
+--------------------------------------

 .. automodule:: fastNLP.models.biaffine_parser
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.models.char\_language\_model
-------------------------------------
+fastNLP.models.char\_language\_model module
+-------------------------------------------

 .. automodule:: fastNLP.models.char_language_model
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.models.cnn\_text\_classification
-----------------------------------------
+fastNLP.models.cnn\_text\_classification module
+-----------------------------------------------

 .. automodule:: fastNLP.models.cnn_text_classification
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.enas\_controller module
+--------------------------------------
+
+.. automodule:: fastNLP.models.enas_controller
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.enas\_model module
+---------------------------------
+
+.. automodule:: fastNLP.models.enas_model
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.enas\_trainer module
+-----------------------------------
+
+.. automodule:: fastNLP.models.enas_trainer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.enas\_utils module
+---------------------------------
+
+.. automodule:: fastNLP.models.enas_utils
+    :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.models.sequence\_modeling
----------------------------------
+fastNLP.models.sequence\_modeling module
+----------------------------------------

 .. automodule:: fastNLP.models.sequence_modeling
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.models.snli
--------------------
+fastNLP.models.snli module
+--------------------------

 .. automodule:: fastNLP.models.snli
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.models.star\_transformer module
+---------------------------------------
+
+.. automodule:: fastNLP.models.star_transformer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.models
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.modules.aggregator.rst (+30, -12)

@@ -1,36 +1,54 @@
-fastNLP.modules.aggregator
-===========================
+fastNLP.modules.aggregator package
+==================================
+
+Submodules
+----------

-fastNLP.modules.aggregator.attention
-------------------------------------
+fastNLP.modules.aggregator.attention module
+-------------------------------------------

 .. automodule:: fastNLP.modules.aggregator.attention
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.aggregator.avg\_pool
-------------------------------------
+fastNLP.modules.aggregator.avg\_pool module
+-------------------------------------------

 .. automodule:: fastNLP.modules.aggregator.avg_pool
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.aggregator.kmax\_pool
--------------------------------------
+fastNLP.modules.aggregator.kmax\_pool module
+--------------------------------------------

 .. automodule:: fastNLP.modules.aggregator.kmax_pool
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.aggregator.max\_pool
-------------------------------------
+fastNLP.modules.aggregator.max\_pool module
+-------------------------------------------

 .. automodule:: fastNLP.modules.aggregator.max_pool
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.aggregator.self\_attention
-------------------------------------------
+fastNLP.modules.aggregator.self\_attention module
+-------------------------------------------------

 .. automodule:: fastNLP.modules.aggregator.self_attention
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.modules.aggregator
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.modules.decoder.rst (+18, -6)

@@ -1,18 +1,30 @@
-fastNLP.modules.decoder
-========================
+fastNLP.modules.decoder package
+===============================
+
+Submodules
+----------

-fastNLP.modules.decoder.CRF
----------------------------
+fastNLP.modules.decoder.CRF module
+----------------------------------

 .. automodule:: fastNLP.modules.decoder.CRF
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.decoder.MLP
----------------------------
+fastNLP.modules.decoder.MLP module
+----------------------------------

 .. automodule:: fastNLP.modules.decoder.MLP
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.modules.decoder
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.modules.encoder.rst (+54, -20)

@@ -1,60 +1,94 @@
-fastNLP.modules.encoder
-========================
+fastNLP.modules.encoder package
+===============================
+
+Submodules
+----------

-fastNLP.modules.encoder.char\_embedding
----------------------------------------
+fastNLP.modules.encoder.char\_embedding module
+----------------------------------------------

 .. automodule:: fastNLP.modules.encoder.char_embedding
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.conv
-----------------------------
+fastNLP.modules.encoder.conv module
+-----------------------------------

 .. automodule:: fastNLP.modules.encoder.conv
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.conv\_maxpool
-------------------------------------
+fastNLP.modules.encoder.conv\_maxpool module
+--------------------------------------------

 .. automodule:: fastNLP.modules.encoder.conv_maxpool
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.embedding
---------------------------------
+fastNLP.modules.encoder.embedding module
+----------------------------------------

 .. automodule:: fastNLP.modules.encoder.embedding
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.linear
------------------------------
+fastNLP.modules.encoder.linear module
+-------------------------------------

 .. automodule:: fastNLP.modules.encoder.linear
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.lstm
----------------------------
+fastNLP.modules.encoder.lstm module
+-----------------------------------

 .. automodule:: fastNLP.modules.encoder.lstm
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.masked\_rnn
-----------------------------------
+fastNLP.modules.encoder.masked\_rnn module
+------------------------------------------

 .. automodule:: fastNLP.modules.encoder.masked_rnn
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+fastNLP.modules.encoder.star\_transformer module
+------------------------------------------------
+
+.. automodule:: fastNLP.modules.encoder.star_transformer
+    :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.transformer
------------------------------------
+fastNLP.modules.encoder.transformer module
+------------------------------------------

 .. automodule:: fastNLP.modules.encoder.transformer
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.encoder.variational\_rnn
-----------------------------------------
+fastNLP.modules.encoder.variational\_rnn module
+-----------------------------------------------

 .. automodule:: fastNLP.modules.encoder.variational_rnn
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.modules.encoder
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.modules.rst (+25, -8)

@@ -1,5 +1,8 @@
-fastNLP.modules
-================
+fastNLP.modules package
+=======================
+
+Subpackages
+-----------

 .. toctree::

@@ -7,24 +10,38 @@ fastNLP.modules
     fastNLP.modules.decoder
     fastNLP.modules.encoder

-fastNLP.modules.dropout
------------------------
+Submodules
+----------
+
+fastNLP.modules.dropout module
+------------------------------

 .. automodule:: fastNLP.modules.dropout
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.other\_modules
-------------------------------
+fastNLP.modules.other\_modules module
+-------------------------------------

 .. automodule:: fastNLP.modules.other_modules
     :members:
+    :undoc-members:
+    :show-inheritance:

-fastNLP.modules.utils
---------------------
+fastNLP.modules.utils module
+----------------------------

 .. automodule:: fastNLP.modules.utils
     :members:
+    :undoc-members:
+    :show-inheritance:
+
+Module contents
+---------------

 .. automodule:: fastNLP.modules
     :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.rst (+11, -2)

@@ -1,13 +1,22 @@
-fastNLP
-========
+fastNLP package
+===============
+
+Subpackages
+-----------

 .. toctree::

     fastNLP.api
+    fastNLP.automl
     fastNLP.core
     fastNLP.io
     fastNLP.models
     fastNLP.modules

+Module contents
+---------------
+
 .. automodule:: fastNLP
     :members:
+    :undoc-members:
+    :show-inheritance:

fastNLP/api/__init__.py (+3, -0)

@@ -1 +1,4 @@
+"""
+Docstring for the API subpackage.
+"""
 from .api import CWS, POS, Parser
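
The re-exported entry points above are unchanged by this commit; only the import line is confirmed here. A minimal usage sketch, assuming the API classes load a default pretrained model and expose a predict() call (both assumptions, not shown in this diff)::

    # Hypothetical usage; construction arguments and predict() behavior
    # are assumptions, only the import itself appears in this commit.
    from fastNLP.api import CWS, POS, Parser

    cws = CWS()  # assumed: loads a default pretrained segmentation model
    print(cws.predict(['编者按:7月12日']))  # assumed: returns segmented text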

fastNLP/api/api.py (+15, -11)

@@ -1,3 +1,7 @@
+"""
+Documentation for api.API
+
+"""
 import warnings

 import torch
@@ -184,17 +188,17 @@ class CWS(API):
        """
        Takes the path of a word-segmentation file and returns the segmentation f1, precision and recall on that dataset.
        The segmentation file should look like:
-           1	编者按	编者按	NN	O	11	nmod:topic
-           2	:	:	PU	O	11	punct
-           3	7月	7月	NT	DATE	4	compound:nn
-           4	12日	12日	NT	DATE	11	nmod:tmod
-           5	,	,	PU	O	11	punct
-           1	这	这	DT	O	3	det
-           2	款	款	M	O	1	mark:clf
-           3	飞行	飞行	NN	O	8	nsubj
-           4	从	从	P	O	5	case
-           5	外型	外型	NN	O	8	nmod:prep
+               1	编者按	编者按	NN	O	11	nmod:topic
+               2	:	:	PU	O	11	punct
+               3	7月	7月	NT	DATE	4	compound:nn
+               4	12日	12日	NT	DATE	11	nmod:tmod
+               5	,	,	PU	O	11	punct
+               1	这	这	DT	O	3	det
+               2	款	款	M	O	1	mark:clf
+               3	飞行	飞行	NN	O	8	nsubj
+               4	从	从	P	O	5	case
+               5	外型	外型	NN	O	8	nmod:prep
        Two sentences are separated by an empty line; every non-empty line has 7 columns.

        :param filepath: str, path to the file.
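
A hedged sketch of calling the method documented above; the file path is a placeholder, and the return value is assumed to be the f1/precision/recall named in the docstring::

    # Sketch only: evaluate a CWS model on a file in the 7-column format
    # described above. The path is a placeholder, not part of this commit.
    cws = CWS()
    scores = cws.test("path/to/segmentation_test_file")
    print(scores)  # assumed: segmentation f1, precision, recall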


fastNLP/automl/enas_trainer.py (+8, -7)

@@ -62,13 +62,14 @@ class ENASTrainer(fastNLP.Trainer):
        """
        :param bool load_best_model: only effective if dev_data was provided at initialization; if True, the trainer
            reloads the model parameters that performed best on dev before returning.
-       :return results: a dict with the following content::
-
-           seconds: float, the training duration
-           the following three items are present only if dev_data was provided.
-           best_eval: Dict of Dict, the evaluation results
-           best_epoch: int, the epoch in which the best value was reached
-           best_step: int, the step (batch update) in which the best value was reached
+       :return results: a dict,
+           with the following content::
+
+               seconds: float, the training duration
+               the following three items are present only if dev_data was provided.
+               best_eval: Dict of Dict, the evaluation results
+               best_epoch: int, the epoch in which the best value was reached
+               best_step: int, the step (batch update) in which the best value was reached

        """
        results = {}


fastNLP/core/dataset.py (+1, -0)

@@ -281,6 +281,7 @@ class DataSet(object):
            (2) is_target: bool, if True, the `new_field_name` field is set as target
            (3) ignore_type: bool, if True, ignore_type of the `new_field_name` field is set to True, ignoring its type
        :return: List[], whose elements are the return values of func, so the list length equals the length of the DataSet
+
        """
        assert len(self)!=0, "Null DataSet cannot use apply()."
        if field_name not in self:
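
A short sketch of apply() under the flags documented above; the field names are illustrative, and the imports assume the package layout shown elsewhere in this commit::

    from fastNLP import DataSet, Instance

    ds = DataSet()
    ds.append(Instance(words=['this', 'is', 'fastNLP'], label=0))
    # apply() returns one value per instance; with new_field_name the values
    # are also stored as a new field, here registered as a target.
    lengths = ds.apply(lambda ins: len(ins['words']), new_field_name='seq_len', is_target=True)
    print(lengths)  # [3]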


fastNLP/core/fieldarray.py (+10, -6)

@@ -48,12 +48,16 @@ class PadderBase:
class AutoPadder(PadderBase):
    """
    Decides automatically from the contents whether padding is needed.
-   (1) If the element type (the dtype of the innermost List elements of the field, available via FieldArray.dtype;
-       e.g. the element type of ['This', 'is', ...] is np.str and that of [[1,2], ...] is np.int64) is not
-       (np.int64, np.float64), no padding is performed.
-   (2) If the element type is (np.int64, np.float64):
-       (2.1) if the field holds only a single value, e.g. a sequence_length, no padding is performed;
-       (2.2) if the field holds a List, the Lists within a Batch are padded to the same length. If the inner Lists
-           also need padding, use a different padder. A field of [1, 2, 3] in an instance can be padded;
-           [[1,2], [3,4, ...]] cannot.
+   1 If the element type (the dtype of the innermost List elements of the field, available via FieldArray.dtype;
+       e.g. the element type of ['This', 'is', ...] is np.str and that of [[1,2], ...] is np.int64) is not
+       (np.int64, np.float64), no padding is performed.
+   2 If the element type is (np.int64, np.float64):
+       2.1 if the field holds only a single value, e.g. a sequence_length, no padding is performed;
+       2.2 if the field holds a List, the Lists within a Batch are padded to the same length. If the inner Lists
+           also need padding, use a different padder. A field of [1, 2, 3] in an instance can be padded;
+           [[1,2], [3,4, ...]] cannot.
    """
    def __init__(self, pad_val=0):
        """


fastNLP/core/instance.py (+7, -8)

@@ -1,13 +1,12 @@
 class Instance(object):
     """An Instance is an example of data.
-    Example::
-        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
-        ins["field_1"]
-        >>[1, 1, 1]
-        ins.add_field("field_3", [3, 3, 3])
-
-    :param fields: a dict of (str: list).
+
+    Example::
+        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
+        ins["field_1"]
+        >>[1, 1, 1]
+        ins.add_field("field_3", [3, 3, 3])
     """

     def __init__(self, **fields):
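
Instances are typically collected into a DataSet; a brief sketch following the docstring's example (the DataSet import assumes the package layout shown elsewhere in this commit)::

    from fastNLP import DataSet, Instance

    ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
    ins.add_field("field_3", [3, 3, 3])
    print(ins["field_1"])  # [1, 1, 1]

    ds = DataSet()
    ds.append(ins)  # a DataSet is built from Instances like this one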


fastNLP/core/losses.py (+2, -2)

@@ -272,7 +272,7 @@ def squash(predict, truth, **kwargs):

     :param predict: Tensor, model output
     :param truth: Tensor, truth from dataset
-    :param **kwargs: extra arguments
+    :param kwargs: extra arguments
     :return predict , truth: predict & truth after processing
     """
     return predict.view(-1, predict.size()[-1]), truth.view(-1, )
@@ -316,7 +316,7 @@ def mask(predict, truth, **kwargs):

     :param predict: Tensor, [batch_size , max_len , tag_size]
     :param truth: Tensor, [batch_size , max_len]
-    :param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected.
+    :param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected.

     :return predict , truth: predict & truth after processing
     """


fastNLP/core/metrics.py (+91, -74)

@@ -17,66 +17,72 @@ class MetricBase(object):
    """Base class for all metrics.

    Every Metric passed to Trainer or Tester must inherit from this class and override the evaluate()
    and get_metric() methods. evaluate(xxx) receives one batch of data; get_metric(xxx) is called once
    all data has been processed and returns the final metric value.
    Take Accuracy in a classification problem as an example.
-   Suppose the dict returned by the model's forward contains the key 'pred', and that key is needed for Accuracy
+   Suppose the dict returned by the model's forward contains the key 'pred', and that key is needed for Accuracy::

        class Model(nn.Module):
            def __init__(xxx):
                # do something
            def forward(self, xxx):
                # do something
                return {'pred': pred, 'other_keys': xxx}  # pred's shape: batch_size x num_classes

    Suppose the field 'label' in the dataset is the value to be predicted, and that field has been set as target.
-   The corresponding AccMetric can be defined as follows
-   # version 1, for one-off use
+   The corresponding AccMetric can be defined as follows. Version 1, for one-off use::

        class AccMetric(MetricBase):
            def __init__(self):
                super().__init__()

                # customize the statistics for your own case
                self.corr_num = 0
                self.total = 0

            def evaluate(self, label, pred):  # the names here must match the dataset's target field and the key returned by the model, otherwise the values cannot be found
                # called once per batch during dev/test; implement how the metric accumulates over batches
                self.total += label.size(0)
                self.corr_num += label.eq(pred).sum().item()

            def get_metric(self, reset=True):  # define here how the metric is computed
                acc = self.corr_num/self.total
                if reset:  # whether to reset the counters for recomputation
                    self.corr_num = 0
                    self.total = 0
                return {'acc': acc}  # return a dict; the key is the metric's name and is shown in the Trainer's progress bar

-   # version 2, if the Metric needs to be reused, e.g. the next time AccMetric is used the target field in the dataset is called y rather than label, or the model's output key is not pred
+   Version 2, if the Metric needs to be reused, e.g. the next time AccMetric is used the target field in the dataset is called y rather than label, or the model's output key is not pred::

        class AccMetric(MetricBase):
            def __init__(self, label=None, pred=None):
                # Suppose that in another scenario the target field is called y and the model's key is pred_y;
                # then it suffices to initialize with acc_metric = AccMetric(label='y', pred='pred_y').
                # When initialized as acc_metric = AccMetric(), i.e. label=None, pred=None, fastNLP uses
                # 'label' and 'pred' directly as keys to fetch the corresponding values.
                super().__init__()
                self._init_param_map(label=label, pred=pred)  # registers label and pred; only the parameter names used by evaluate() need to be registered
                # without this registration the behavior is identical to version 1

                # customize the statistics for your own case
                self.corr_num = 0
                self.total = 0

            def evaluate(self, label, pred):  # the parameter names here must match those registered via self._init_param_map()
                # called once per batch during dev/test; implement how the metric accumulates over batches
                self.total += label.size(0)
                self.corr_num += label.eq(pred).sum().item()

            def get_metric(self, reset=True):  # define here how the metric is computed
                acc = self.corr_num/self.total
                if reset:  # whether to reset the counters for recomputation
                    self.corr_num = 0
                    self.total = 0
                return {'acc': acc}  # return a dict; the key is the metric's name and is shown in the Trainer's progress bar

    ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``.
@@ -84,12 +90,12 @@ class MetricBase(object):
    ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``.
    ``MetricBase`` will do the following type checks:

-   1. whether self.evaluate has varargs, which is not supported.
-   2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``.
-   3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``.
+       1. whether self.evaluate has varargs, which is not supported.
+       2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``.
+       3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``.

    Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
-   target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
+   target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering
    will be conducted.)

    """
@@ -388,23 +394,26 @@ class SpanFPreRecMetric(MetricBase):
    """
    In sequence labeling problems, computes F, precision and recall over spans.
    For example, Chinese part-of-speech tagging is labeled per character; the sentence '中国在亚洲' may have the
    POS tags (using BMES) ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']. This metric computes F1 for such cases.
-   The resulting metric is
-   {
-       'f': xxx, # 'f' is used so that an f_beta value can be computed later
-       'pre': xxx,
-       'rec':xxx
-   }
-   If only_gross=False, per-label statistics are also returned
-   {
-       'f': xxx,
-       'pre': xxx,
-       'rec':xxx,
-       'f-label': xxx,
-       'pre-label': xxx,
-       'rec-label':xxx,
-       ...
-   }
+   The resulting metric is::
+
+       {
+           'f': xxx, # 'f' is used so that an f_beta value can be computed later
+           'pre': xxx,
+           'rec':xxx
+       }
+
+   If only_gross=False, per-label statistics are also returned::
+
+       {
+           'f': xxx,
+           'pre': xxx,
+           'rec':xxx,
+           'f-label': xxx,
+           'pre-label': xxx,
+           'rec-label':xxx,
+           ...
+       }

    """
    def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None,
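
Using the constructor signature shown above, a hedged instantiation sketch; the Vocabulary calls are assumptions about fastNLP's API, not part of this diff::

    from fastNLP import Vocabulary

    tag_vocab = Vocabulary(unknown=None, padding=None)  # assumed Vocabulary API
    tag_vocab.add_word_lst(['B-NN', 'M-NN', 'E-NN', 'S-NN', 'S-DET'])
    span_metric = SpanFPreRecMetric(tag_vocab=tag_vocab, encoding_type='bmes',
                                    only_gross=False)  # also report per-label f/pre/rec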
@@ -573,13 +582,21 @@ class BMESF1PreRecMetric(MetricBase):
    """
    Computes f1, precision and recall under the BMES tagging scheme. Since illegal tags such as "BS" may occur,
    the table below is used for conversion; cur_B means the current tag is B and next_B means the next tag is B.
    Then cur_B=S means the tag currently predicted as B is relabeled S, and next_M=B means the next tag predicted
    as M is relabeled B.

-   |       | next_B  | next_M   | next_E   | next_S  | end     |
-   |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:|
-   | start | legal   | next_M=B | next_E=S | legal   | -       |
-   | cur_B | cur_B=S | legal    | legal    | cur_B=S | cur_B=S |
-   | cur_M | cur_M=E | legal    | legal    | cur_M=E | cur_M=E |
-   | cur_E | legal   | next_M=B | next_E=S | legal   | legal   |
-   | cur_S | legal   | next_M=B | next_E=S | legal   | legal   |
+   +-------+---------+----------+----------+---------+---------+
+   |       | next_B  | next_M   | next_E   | next_S  | end     |
+   +=======+=========+==========+==========+=========+=========+
+   | start | legal   | next_M=B | next_E=S | legal   | --      |
+   +-------+---------+----------+----------+---------+---------+
+   | cur_B | cur_B=S | legal    | legal    | cur_B=S | cur_B=S |
+   +-------+---------+----------+----------+---------+---------+
+   | cur_M | cur_M=E | legal    | legal    | cur_M=E | cur_M=E |
+   +-------+---------+----------+----------+---------+---------+
+   | cur_E | legal   | next_M=B | next_E=S | legal   | legal   |
+   +-------+---------+----------+----------+---------+---------+
+   | cur_S | legal   | next_M=B | next_E=S | legal   | legal   |
+   +-------+---------+----------+----------+---------+---------+

    Example: the prediction BSEMS is treated as SSSSS.




fastNLP/core/trainer.py (+71, -68)

@@ -66,28 +66,28 @@ class Trainer(object):
            insufficient memory; the same effect can be achieved by setting batch_size=32, update_every=4
        """
        super(Trainer, self).__init__()
        if not isinstance(train_data, DataSet):
            raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

        # check update every
-       assert update_every>=1, "update_every must be no less than 1."
+       assert update_every >= 1, "update_every must be no less than 1."
        self.update_every = int(update_every)

        # check save_path
        if not (save_path is None or isinstance(save_path, str)):
            raise ValueError("save_path can only be None or `str`.")

        # prepare evaluate
        metrics = _prepare_metrics(metrics)

        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
@@ -97,19 +97,19 @@ class Trainer(object):
            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
        elif len(metrics) > 0:
            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

        # prepare loss
        losser = _prepare_losser(loss)

        # sampler check
        if sampler is not None and not isinstance(sampler, BaseSampler):
            raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))

        if check_code_level > -1:
            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                        metric_key=metric_key, check_level=check_code_level,
                        batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))

        self.train_data = train_data
        self.dev_data = dev_data  # If None, No validation.
        self.model = model
@@ -120,7 +120,7 @@ class Trainer(object):
        self.use_cuda = bool(use_cuda)
        self.save_path = save_path
        self.print_every = int(print_every)
-       self.validate_every = int(validate_every) if validate_every!=0 else -1
+       self.validate_every = int(validate_every) if validate_every != 0 else -1
        self.best_metric_indicator = None
        self.best_dev_epoch = None
        self.best_dev_step = None
@@ -129,19 +129,19 @@ class Trainer(object):
        self.prefetch = prefetch
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
        self.n_steps = (len(self.train_data) // self.batch_size + int(
-           len(self.train_data) % self.batch_size != 0)) * self.n_epochs
+               len(self.train_data) % self.batch_size != 0)) * self.n_epochs

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        else:
            if optimizer is None:
                optimizer = Adam(lr=0.01, weight_decay=0)
            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

        self.use_tqdm = use_tqdm
        self.pbar = None
        self.print_every = abs(self.print_every)

        if self.dev_data is not None:
            self.tester = Tester(model=self.model,
                                 data=self.dev_data,
@@ -149,14 +149,13 @@ class Trainer(object):
                                 batch_size=self.batch_size,
                                 use_cuda=self.use_cuda,
                                 verbose=0)

        self.step = 0
        self.start_time = None  # start timestamp
-
        self.callback_manager = CallbackManager(env={"trainer": self},
                                                callbacks=callbacks)

    def train(self, load_best_model=True):
        """

@@ -185,14 +184,15 @@ class Trainer(object):
        evaluates according to metrics, and decides from save_path whether to store the model

        :param bool load_best_model: only effective if dev_data was provided at initialization; if True, the trainer
            reloads the model parameters that performed best on dev before returning.
-       :return results: a dict with the following content::
+       :return results: a dict,
+           with the following content::

-           seconds: float, the training duration
-           the following three items are present only if dev_data was provided.
-           best_eval: Dict of Dict, the evaluation results
-           best_epoch: int, the epoch in which the best value was reached
-           best_step: int, the step (batch update) in which the best value was reached
+               seconds: float, the training duration
+               the following three items are present only if dev_data was provided.
+               best_eval: Dict of Dict, the evaluation results
+               best_epoch: int, the epoch in which the best value was reached
+               best_step: int, the step (batch update) in which the best value was reached

        """
        results = {}
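
The keys documented above can be read straight off the returned dict; a sketch with a stand-in literal in place of a real trainer.train() call::

    # The literal below stands in for: results = trainer.train()
    results = {'seconds': 12.3, 'best_eval': {'AccuracyMetric': {'acc': 0.9}},
               'best_epoch': 3, 'best_step': 420}
    print("trained for {}s".format(results['seconds']))
    if 'best_eval' in results:  # present only when dev_data was provided
        print("best:", results['best_eval'], "at epoch", results['best_epoch'],
              "step", results['best_step'])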
@@ -205,21 +205,22 @@ class Trainer(object):
self.model = self.model.cuda() self.model = self.model.cuda()
self._model_device = self.model.parameters().__next__().device self._model_device = self.model.parameters().__next__().device
self._mode(self.model, is_test=False) self._mode(self.model, is_test=False)
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
start_time = time.time() start_time = time.time()
print("training epochs started " + self.start_time, flush=True) print("training epochs started " + self.start_time, flush=True)
try: try:
self.callback_manager.on_train_begin() self.callback_manager.on_train_begin()
self._train() self._train()
self.callback_manager.on_train_end() self.callback_manager.on_train_end()
except (CallbackException, KeyboardInterrupt) as e: except (CallbackException, KeyboardInterrupt) as e:
self.callback_manager.on_exception(e) self.callback_manager.on_exception(e)
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
self.tester._format_eval_results(self.best_dev_perf),)
print(
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
self.tester._format_eval_results(self.best_dev_perf), )
results['best_eval'] = self.best_dev_perf results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step results['best_step'] = self.best_dev_step
@@ -233,9 +234,9 @@ class Trainer(object):
finally: finally:
pass pass
results['seconds'] = round(time.time() - start_time, 2) results['seconds'] = round(time.time() - start_time, 2)
return results return results
def _train(self): def _train(self):
if not self.use_tqdm: if not self.use_tqdm:
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
@@ -244,13 +245,13 @@ class Trainer(object):
self.step = 0 self.step = 0
self.epoch = 0 self.epoch = 0
start = time.time() start = time.time()
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
self.pbar = pbar if isinstance(pbar, tqdm) else None self.pbar = pbar if isinstance(pbar, tqdm) else None
avg_loss = 0 avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch) prefetch=self.prefetch)
for epoch in range(1, self.n_epochs+1):
for epoch in range(1, self.n_epochs + 1):
self.epoch = epoch self.epoch = epoch
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping # early stopping
@@ -262,22 +263,22 @@ class Trainer(object):
# negative sampling; replace unknown; re-weight batch_y # negative sampling; replace unknown; re-weight batch_y
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
prediction = self._data_forward(self.model, batch_x) prediction = self._data_forward(self.model, batch_x)
# edit prediction # edit prediction
self.callback_manager.on_loss_begin(batch_y, prediction) self.callback_manager.on_loss_begin(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y).mean() loss = self._compute_loss(prediction, batch_y).mean()
avg_loss += loss.item() avg_loss += loss.item()
loss = loss/self.update_every
loss = loss / self.update_every
# Is loss NaN or inf? requires_grad = False # Is loss NaN or inf? requires_grad = False
self.callback_manager.on_backward_begin(loss) self.callback_manager.on_backward_begin(loss)
self._grad_backward(loss) self._grad_backward(loss)
self.callback_manager.on_backward_end() self.callback_manager.on_backward_end()
self._update() self._update()
self.callback_manager.on_step_end() self.callback_manager.on_step_end()
if (self.step+1) % self.print_every == 0:
if (self.step + 1) % self.print_every == 0:
avg_loss = avg_loss / self.print_every avg_loss = avg_loss / self.print_every
if self.use_tqdm: if self.use_tqdm:
print_output = "loss:{0:<6.5f}".format(avg_loss) print_output = "loss:{0:<6.5f}".format(avg_loss)
@@ -290,34 +291,34 @@ class Trainer(object):
pbar.set_postfix_str(print_output) pbar.set_postfix_str(print_output)
avg_loss = 0 avg_loss = 0
self.callback_manager.on_batch_end() self.callback_manager.on_batch_end()
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
and self.dev_data is not None: and self.dev_data is not None:
eval_res = self._do_validation(epoch=epoch, step=self.step) eval_res = self._do_validation(epoch=epoch, step=self.step)
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
self.n_steps) + \ self.n_steps) + \
self.tester._format_eval_results(eval_res)
self.tester._format_eval_results(eval_res)
pbar.write(eval_str + '\n') pbar.write(eval_str + '\n')
# ================= mini-batch end ==================== # # ================= mini-batch end ==================== #
# lr decay; early stopping # lr decay; early stopping
self.callback_manager.on_epoch_end() self.callback_manager.on_epoch_end()
# =============== epochs end =================== # # =============== epochs end =================== #
pbar.close() pbar.close()
self.pbar = None self.pbar = None
# ============ tqdm end ============== # # ============ tqdm end ============== #
def _do_validation(self, epoch, step): def _do_validation(self, epoch, step):
self.callback_manager.on_valid_begin() self.callback_manager.on_valid_begin()
res = self.tester.test() res = self.tester.test()
is_better_eval = False is_better_eval = False
if self._better_eval_result(res): if self._better_eval_result(res):
if self.save_path is not None: if self.save_path is not None:
self._save_model(self.model, self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
else: else:
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()}
self.best_dev_perf = res self.best_dev_perf = res
@@ -327,7 +328,7 @@ class Trainer(object):
# get validation results; adjust optimizer # get validation results; adjust optimizer
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
return res return res
def _mode(self, model, is_test=False): def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently. """Train mode or Test mode. This is for PyTorch currently.


@@ -339,21 +340,21 @@ class Trainer(object):
model.eval() model.eval()
else: else:
model.train() model.train()
def _update(self): def _update(self):
"""Perform weight update on a model. """Perform weight update on a model.


""" """
if (self.step+1)%self.update_every==0:
if (self.step + 1) % self.update_every == 0:
self.optimizer.step() self.optimizer.step()
def _data_forward(self, network, x): def _data_forward(self, network, x):
x = _build_args(network.forward, **x) x = _build_args(network.forward, **x)
y = network(**x) y = network(**x)
if not isinstance(y, dict): if not isinstance(y, dict):
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y return y
def _grad_backward(self, loss): def _grad_backward(self, loss):
"""Compute gradient with link rules. """Compute gradient with link rules.


@@ -361,10 +362,10 @@ class Trainer(object):


For PyTorch, just do "loss.backward()" For PyTorch, just do "loss.backward()"
""" """
if self.step%self.update_every==0:
if self.step % self.update_every == 0:
self.model.zero_grad() self.model.zero_grad()
loss.backward() loss.backward()
def _compute_loss(self, predict, truth): def _compute_loss(self, predict, truth):
"""Compute loss given prediction and ground truth. """Compute loss given prediction and ground truth.


@@ -373,7 +374,7 @@ class Trainer(object):
:return: a scalar :return: a scalar
""" """
return self.losser(predict, truth) return self.losser(predict, truth)
def _save_model(self, model, model_name, only_param=False): def _save_model(self, model, model_name, only_param=False):
""" 存储不含有显卡信息的state_dict或model """ 存储不含有显卡信息的state_dict或model
:param model: :param model:
@@ -394,7 +395,7 @@ class Trainer(object):
model.cpu() model.cpu()
torch.save(model, model_path) torch.save(model, model_path)
model.to(self._model_device) model.to(self._model_device)
def _load_model(self, model, model_name, only_param=False): def _load_model(self, model, model_name, only_param=False):
# 返回bool值指示是否成功reload模型 # 返回bool值指示是否成功reload模型
if self.save_path is not None: if self.save_path is not None:
@@ -409,7 +410,7 @@ class Trainer(object):
else: else:
return False return False
return True return True
def _better_eval_result(self, metrics): def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results. """Check if the current epoch yields better validation results.


@@ -437,6 +438,7 @@ class Trainer(object):
DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2 DEFAULT_CHECK_NUM_BATCH = 2



def _get_value_info(_dict): def _get_value_info(_dict):
# given a dict value, return information about this dict's value. Return list of str # given a dict value, return information about this dict's value. Return list of str
strs = [] strs = []
@@ -453,27 +455,28 @@ def _get_value_info(_dict):
strs.append(_str) strs.append(_str)
return strs return strs



def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None, dev_data=None, metric_key=None,
check_level=0): check_level=0):
# check get_loss 方法 # check get_loss 方法
model_devcie = model.parameters().__next__().device model_devcie = model.parameters().__next__().device
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch): for batch_count, (batch_x, batch_y) in enumerate(batch):
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
# forward check # forward check
if batch_count==0:
if batch_count == 0:
info_str = "" info_str = ""
input_fields = _get_value_info(batch_x) input_fields = _get_value_info(batch_x)
target_fields = _get_value_info(batch_y) target_fields = _get_value_info(batch_y)
if len(input_fields)>0:
if len(input_fields) > 0:
info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) info_str += "input fields after batch(if batch size is {}):\n".format(batch_size)
info_str += "\n".join(input_fields) info_str += "\n".join(input_fields)
info_str += '\n' info_str += '\n'
else: else:
raise RuntimeError("There is no input field.") raise RuntimeError("There is no input field.")
if len(target_fields)>0:
if len(target_fields) > 0:
info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) info_str += "target fields after batch(if batch size is {}):\n".format(batch_size)
info_str += "\n".join(target_fields) info_str += "\n".join(target_fields)
info_str += '\n' info_str += '\n'
@@ -481,14 +484,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
info_str += 'There is no target field.' info_str += 'There is no target field.'
print(info_str) print(info_str)
_check_forward_error(forward_func=model.forward, dataset=dataset, _check_forward_error(forward_func=model.forward, dataset=dataset,
batch_x=batch_x, check_level=check_level)
batch_x=batch_x, check_level=check_level)
refined_batch_x = _build_args(model.forward, **batch_x) refined_batch_x = _build_args(model.forward, **batch_x)
pred_dict = model(**refined_batch_x) pred_dict = model(**refined_batch_x)
func_signature = get_func_signature(model.forward) func_signature = get_func_signature(model.forward)
if not isinstance(pred_dict, dict): if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")
# loss check # loss check
try: try:
loss = losser(pred_dict, batch_y) loss = losser(pred_dict, batch_y)
@@ -512,7 +515,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
model.zero_grad() model.zero_grad()
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
break break
if dev_data is not None: if dev_data is not None:
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1) batch_size=batch_size, verbose=-1)
@@ -526,7 +529,7 @@ def _check_eval_results(metrics, metric_key, metric_list):
# metric_list: 多个用来做评价的指标,来自Trainer的初始化 # metric_list: 多个用来做评价的指标,来自Trainer的初始化
if isinstance(metrics, tuple): if isinstance(metrics, tuple):
loss, metrics = metrics loss, metrics = metrics
if isinstance(metrics, dict): if isinstance(metrics, dict):
if len(metrics) == 1: if len(metrics) == 1:
# only single metric, just use it # only single metric, just use it
@@ -537,7 +540,7 @@ def _check_eval_results(metrics, metric_key, metric_list):
if metrics_name not in metrics: if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name] metric_dict = metrics[metrics_name]
if len(metric_dict) == 1: if len(metric_dict) == 1:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and metric_key is None: elif len(metric_dict) > 1 and metric_key is None:
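For reference, the evaluation results consumed above have the shape {metric_name: {indicator: value}}; a minimal sketch of how metric_key picks the indicator (metric and indicator names are assumed for illustration)::

    results = {"AccuracyMetric": {"acc": 0.91, "f1": 0.88}}
    metric_dict = results["AccuracyMetric"]
    metric_key = "acc"                       # with >1 indicator, metric_key must name one
    indicator_val = metric_dict[metric_key]  # 0.91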


+ 7
- 2
fastNLP/core/utils.py View File

@@ -197,17 +197,22 @@ def get_func_signature(func):


Given a function or method, return its signature. Given a function or method, return its signature.
For example: For example:
(1) function
1. function::
def func(a, b='a', *args): def func(a, b='a', *args):
xxxx xxxx
get_func_signature(func) # 'func(a, b='a', *args)' get_func_signature(func) # 'func(a, b='a', *args)'
(2) method
2. method::
class Demo: class Demo:
def __init__(self): def __init__(self):
xxx xxx
def forward(self, a, b='a', **args): def forward(self, a, b='a', **args):
demo = Demo() demo = Demo()
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'
:param func: a function or a method :param func: a function or a method
:return: str or None :return: str or None
""" """


+ 4
- 4
fastNLP/io/config_io.py View File

@@ -26,10 +26,10 @@ class ConfigLoader(BaseLoader):


:param str file_path: the path of config file :param str file_path: the path of config file
:param dict sections: the dict of ``{section_name(string): ConfigSection object}`` :param dict sections: the dict of ``{section_name(string): ConfigSection object}``
Example::
test_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
Example::
test_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})


""" """
assert isinstance(sections, dict) assert isinstance(sections, dict)
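The config file itself is INI-style; a sketch of a matching section and its read-back (section and key names are assumed, as is dict-style access on ConfigSection)::

    # ./data_for_tests/config
    # [POS_test]
    # epochs = 20
    # batch_size = 32
    test_args = ConfigSection()
    ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
    print(test_args["epochs"])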


+ 27
- 20
fastNLP/io/dataset_loader.py View File

@@ -9,7 +9,7 @@ from fastNLP.io.base_loader import DataLoaderRegister
def convert_seq_dataset(data): def convert_seq_dataset(data):
"""Create an DataSet instance that contains no labels. """Create an DataSet instance that contains no labels.


:param data: list of list of strings, [num_examples, *].
:param data: list of list of strings, [num_examples, \*].
Example:: Example::


[ [
@@ -28,7 +28,7 @@ def convert_seq_dataset(data):
def convert_seq2tag_dataset(data): def convert_seq2tag_dataset(data):
"""Convert list of data into DataSet. """Convert list of data into DataSet.


:param data: list of list of strings, [num_examples, *].
:param data: list of list of strings, [num_examples, \*].
Example:: Example::


[ [
@@ -48,7 +48,7 @@ def convert_seq2tag_dataset(data):
def convert_seq2seq_dataset(data): def convert_seq2seq_dataset(data):
"""Convert list of data into DataSet. """Convert list of data into DataSet.


:param data: list of list of strings, [num_examples, *].
:param data: list of list of strings, [num_examples, \*].
Example:: Example::


[ [
@@ -177,18 +177,18 @@ DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
class DummyPOSReader(DataSetLoader): class DummyPOSReader(DataSetLoader):
"""A simple reader for a dummy POS tagging dataset. """A simple reader for a dummy POS tagging dataset.


In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second
In these datasets, each line is divided by "\\\\t". The first Col is the word and the second
Col is the label. Different sentences are divided by an empty line. Col is the label. Different sentences are divided by an empty line.
E.g::
E.g::


Tom label1
and label2
Jerry label1
. label3
(separated by an empty line)
Hello label4
world label5
! label3
Tom label1
and label2
Jerry label1
. label3
(separated by an empty line)
Hello label4
world label5
! label3


In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label.
""" """
@@ -200,11 +200,13 @@ class DummyPOSReader(DataSetLoader):
""" """
:return data: three-level list :return data: three-level list
Example:: Example::
[ [
[ [word_11, word_12, ...], [label_1, label_1, ...] ], [ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ], [ [word_21, word_22, ...], [label_2, label_1, ...] ],
... ...
] ]
""" """
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
@@ -550,6 +552,7 @@ class SNLIDataSetReader(DataSetLoader):


:param data: A three-level list. :param data: A three-level list.
Example:: Example::
[ [
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
@@ -647,7 +650,7 @@ class NaiveCWSReader(DataSetLoader):
For example:: For example::


这是 fastNLP , 一个 非常 good 的 包 . 这是 fastNLP , 一个 非常 good 的 包 .
Or, each part is additionally followed by a pos tag Or, each part is additionally followed by a pos tag
For example:: For example::


@@ -661,12 +664,15 @@ class NaiveCWSReader(DataSetLoader):


def load(self, filepath, in_word_splitter=None, cut_long_sent=False): def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
""" """
Accepted formats (by default \t or space is used as the seg)
Accepted formats (by default \\\\t or space is used as the seg)::
这是 fastNLP , 一个 非常 good 的 包 . 这是 fastNLP , 一个 非常 good 的 包 .
and::
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
If splitter is not None, the second format is assumed; each token such as "也/D" is split by splitter and the first part is kept, e.g. "也/D".split('/')[0] If splitter is not None, the second format is assumed; each token such as "也/D" is split by splitter and the first part is kept, e.g. "也/D".split('/')[0]

:param filepath: path to the data file :param filepath: path to the data file
:param in_word_splitter: the in-word splitter (e.g. '/'); if not None, the second format above is assumed :param in_word_splitter: the in-word splitter (e.g. '/'); if not None, the second format above is assumed
:param cut_long_sent: whether to cut long sentences into shorter pieces :param cut_long_sent: whether to cut long sentences into shorter pieces
@@ -737,11 +743,12 @@ class ZhConllPOSReader(object):


def load(self, path): def load(self, path):
""" """
The returned DataSet contains the following fields
The returned DataSet contains the following fields::
words: list of str, words: list of str,
tag: list of str, with BMES tags added; e.g. the original sequence ['VP', 'NN', 'NN', ..] becomes ["S-VP", "B-NN", "M-NN",..] tag: list of str, with BMES tags added; e.g. the original sequence ['VP', 'NN', 'NN', ..] becomes ["S-VP", "B-NN", "M-NN",..]
The input is assumed to be in conll format, with sentences separated by an empty line and 7 columns per line, i.e.
::
The input is assumed to be in conll format, with sentences separated by an empty line and 7 columns per line, i.e.::


1 编者按 编者按 NN O 11 nmod:topic 1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct 2 : : PU O 11 punct


+ 1
- 1
fastNLP/io/embed_loader.py View File

@@ -132,7 +132,7 @@ class EmbedLoader(BaseLoader):
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
""" """
load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining
embedding are initialized from a normal distribution which has the mean and std of the found words vectors.
embedding are initialized from a normal distribution which has the mean and std of the found words vectors.
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). The embedding type is determined automatically, support glove and word2vec(the first line only has two elements).


:param embed_filepath: str, where to read pretrain embedding :param embed_filepath: str, where to read pretrain embedding
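A usage sketch under the documented signature (the file path and vocabulary contents are illustrative, and a numpy array return is assumed)::

    from fastNLP import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader

    vocab = Vocabulary()
    vocab.update(["the", "quick", "brown", "fox"])
    # words absent from the file get vectors drawn from N(mean, std) of the found ones
    matrix = EmbedLoader.load_with_vocab("glove.6B.50d.txt", vocab)
    print(matrix.shape)  # (len(vocab), embedding_dim)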


+ 7
- 5
fastNLP/io/model_io.py View File

@@ -31,16 +31,18 @@ class ModelLoader(BaseLoader):


class ModelSaver(object): class ModelSaver(object):
"""Save a model """Save a model
Example::


:param str save_path: the path to the saving directory.
Example::

saver = ModelSaver("./save/model_ckpt_100.pkl")
saver.save_pytorch(model)
saver = ModelSaver("./save/model_ckpt_100.pkl")
saver.save_pytorch(model)


""" """


def __init__(self, save_path): def __init__(self, save_path):
"""

:param save_path: the path to the saving directory.
"""
self.save_path = save_path self.save_path = save_path


def save_pytorch(self, model, param_only=True): def save_pytorch(self, model, param_only=True):
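A save/load round trip as a sketch; ModelLoader.load_pytorch is assumed to be the counterpart that restores parameters into an already-constructed model::

    saver = ModelSaver("./save/model_ckpt_100.pkl")
    saver.save_pytorch(model)   # model: an initialized torch.nn.Module; param_only=True stores parameters only
    ModelLoader.load_pytorch(model, "./save/model_ckpt_100.pkl")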


+ 10
- 3
fastNLP/models/char_language_model.py View File

@@ -20,16 +20,23 @@ class Highway(nn.Module):


class CharLM(nn.Module): class CharLM(nn.Module):
"""CNN + highway network + LSTM """CNN + highway network + LSTM
# Input:
# Input::
4D tensor with shape [batch_size, in_channel, height, width] 4D tensor with shape [batch_size, in_channel, height, width]
# Output:
# Output::
2D Tensor with shape [batch_size, vocab_size] 2D Tensor with shape [batch_size, vocab_size]
# Arguments:
# Arguments::
char_emb_dim: the size of each character embedding char_emb_dim: the size of each character embedding
word_emb_dim: the size of each word embedding word_emb_dim: the size of each word embedding
vocab_size: num of unique words vocab_size: num of unique words
num_char: num of characters num_char: num of characters
use_gpu: True or False use_gpu: True or False
""" """


def __init__(self, char_emb_dim, word_emb_dim, def __init__(self, char_emb_dim, word_emb_dim,
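A construction sketch from the documented arguments only; the full signature is truncated above, so all values and the input dtype are illustrative::

    import torch

    model = CharLM(char_emb_dim=50, word_emb_dim=300,
                   vocab_size=10000, num_char=128)
    x = torch.zeros(32, 1, 10, 8, dtype=torch.long)  # [batch_size, in_channel, height, width]
    logits = model(x)                                # [batch_size, vocab_size]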


+ 8
- 7
fastNLP/models/enas_trainer.py View File

@@ -65,13 +65,14 @@ class ENASTrainer(fastNLP.Trainer):
""" """
:param bool load_best_model: only takes effect when dev_data was provided at initialization; if True, before returning the trainer reloads the model parameters that performed :param bool load_best_model: only takes effect when dev_data was provided at initialization; if True, before returning the trainer reloads the model parameters that performed
best on dev. best on dev.
:return results: returns a dict with the following contents::

seconds: float, the training time
the following three entries are present only when dev_data was provided.
best_eval: Dict of Dict, the evaluation results
best_epoch: int, the epoch at which the best value was obtained
best_step: int, the step (batch) at which the best value was obtained
:return results: returns a dict,
with the following contents::

seconds: float, the training time
the following three entries are present only when dev_data was provided.
best_eval: Dict of Dict, the evaluation results
best_epoch: int, the epoch at which the best value was obtained
best_step: int, the step (batch) at which the best value was obtained


""" """
results = {} results = {}
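Reading the returned dict, as a sketch (trainer construction omitted)::

    results = trainer.train(load_best_model=True)
    print(results["seconds"])                 # training time in seconds
    # present only when dev_data was provided:
    print(results.get("best_eval"), results.get("best_epoch"), results.get("best_step"))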


+ 19
- 19
fastNLP/models/snli.py View File

@@ -1,6 +1,5 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F


from fastNLP.models.base_model import BaseModel from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder as Decoder from fastNLP.modules import decoder as Decoder
@@ -40,7 +39,7 @@ class ESIM(BaseModel):
batch_first=self.batch_first, bidirectional=True batch_first=self.batch_first, bidirectional=True
) )


self.bi_attention = Aggregator.Bi_Attention()
self.bi_attention = Aggregator.BiAttention()
self.mean_pooling = Aggregator.MeanPoolWithMask() self.mean_pooling = Aggregator.MeanPoolWithMask()
self.max_pooling = Aggregator.MaxPoolWithMask() self.max_pooling = Aggregator.MaxPoolWithMask()


@@ -53,23 +52,23 @@ class ESIM(BaseModel):


self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)


def forward(self, premise, hypothesis, premise_len, hypothesis_len):
def forward(self, words1, words2, seq_len1, seq_len2):
""" Forward function """ Forward function


:param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)].
:param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)].
:param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL].
:param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL].
:param words1: A Tensor representing the premise: [batch size(B), premise seq len(PL)].
:param words2: A Tensor representing the hypothesis: [B, hypothesis seq len(HL)].
:param seq_len1: A Tensor recording which positions are real words and which are padding in the premise: [B].
:param seq_len2: A Tensor recording which positions are real words and which are padding in the hypothesis: [B].
:return: prediction: A dict with a Tensor of classification results: [B, n_labels(N)]. :return: prediction: A dict with a Tensor of classification results: [B, n_labels(N)].
""" """


premise0 = self.embedding_layer(self.embedding(premise))
hypothesis0 = self.embedding_layer(self.embedding(hypothesis))
premise0 = self.embedding_layer(self.embedding(words1))
hypothesis0 = self.embedding_layer(self.embedding(words2))


_BP, _PSL, _HP = premise0.size() _BP, _PSL, _HP = premise0.size()
_BH, _HSL, _HH = hypothesis0.size() _BH, _HSL, _HH = hypothesis0.size()
_BPL, _PLL = premise_len.size()
_HPL, _HLL = hypothesis_len.size()
_BPL, _PLL = seq_len1.size()
_HPL, _HLL = seq_len2.size()


assert _BP == _BH and _BPL == _HPL and _BP == _BPL assert _BP == _BH and _BPL == _HPL and _BP == _BPL
assert _HP == _HH assert _HP == _HH
@@ -84,7 +83,7 @@ class ESIM(BaseModel):
a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H]
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H]


ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len)
ai, bi = self.bi_attention(a, b, seq_len1, seq_len2)


ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H]
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H]
@@ -98,17 +97,18 @@ class ESIM(BaseModel):
va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H]
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H]


va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H]
va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H]
vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H]
vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1) # vb_max: [B, H]
va_ave = self.mean_pooling(va, seq_len1, dim=1) # va_ave: [B, H]
va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1) # va_max: [B, H]
vb_ave = self.mean_pooling(vb, seq_len2, dim=1) # vb_ave: [B, H]
vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1) # vb_max: [B, H]


v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H]


prediction = F.tanh(self.output(v)) # prediction: [B, N]
prediction = torch.tanh(self.output(v)) # prediction: [B, N]


return {'pred': prediction} return {'pred': prediction}


def predict(self, premise, hypothesis, premise_len, hypothesis_len):
return self.forward(premise, hypothesis, premise_len, hypothesis_len)
def predict(self, words1, words2, seq_len1, seq_len2):
prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
return torch.argmax(prediction, dim=-1)
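With the renamed arguments, the two entry points differ only in post-processing; a sketch, assuming the word-id and mask tensors are prepared elsewhere::

    logits = model(words1, words2, seq_len1, seq_len2)["pred"]   # [B, n_labels]
    labels = model.predict(words1, words2, seq_len1, seq_len2)   # [B], argmax over n_labels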



+ 1
- 1
fastNLP/modules/aggregator/__init__.py View File

@@ -5,6 +5,6 @@ from .avg_pool import MeanPoolWithMask
from .kmax_pool import KMaxPool from .kmax_pool import KMaxPool


from .attention import Attention from .attention import Attention
from .attention import Bi_Attention
from .attention import BiAttention
from .self_attention import SelfAttention from .self_attention import SelfAttention



+ 25
- 11
fastNLP/modules/aggregator/attention.py View File

@@ -23,9 +23,9 @@ class Attention(torch.nn.Module):
raise NotImplementedError raise NotImplementedError




class DotAtte(nn.Module):
class DotAttention(nn.Module):
def __init__(self, key_size, value_size, dropout=0.1): def __init__(self, key_size, value_size, dropout=0.1):
super(DotAtte, self).__init__()
super(DotAttention, self).__init__()
self.key_size = key_size self.key_size = key_size
self.value_size = value_size self.value_size = value_size
self.scale = math.sqrt(key_size) self.scale = math.sqrt(key_size)
@@ -48,7 +48,7 @@ class DotAtte(nn.Module):
return torch.matmul(output, V) return torch.matmul(output, V)
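DotAttention computes standard scaled dot-product attention, softmax(Q K^T / sqrt(key_size)) V; a self-contained sketch with masking and dropout omitted::

    import math
    import torch

    Q = torch.randn(2, 5, 64)   # [batch, query_len, key_size]
    K = torch.randn(2, 7, 64)   # [batch, key_len, key_size]
    V = torch.randn(2, 7, 64)   # [batch, key_len, value_size]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(64)
    out = torch.matmul(torch.softmax(scores, dim=-1), V)   # [2, 5, 64]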




class MultiHeadAtte(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
""" """


@@ -58,7 +58,7 @@ class MultiHeadAtte(nn.Module):
:param num_head: int, the number of attention heads. :param num_head: int, the number of attention heads.
:param dropout: float. :param dropout: float.
""" """
super(MultiHeadAtte, self).__init__()
super(MultiHeadAttention, self).__init__()
self.input_size = input_size self.input_size = input_size
self.key_size = key_size self.key_size = key_size
self.value_size = value_size self.value_size = value_size
@@ -68,7 +68,7 @@ class MultiHeadAtte(nn.Module):
self.q_in = nn.Linear(input_size, in_size) self.q_in = nn.Linear(input_size, in_size)
self.k_in = nn.Linear(input_size, in_size) self.k_in = nn.Linear(input_size, in_size)
self.v_in = nn.Linear(input_size, in_size) self.v_in = nn.Linear(input_size, in_size)
self.attention = DotAtte(key_size=key_size, value_size=value_size)
self.attention = DotAttention(key_size=key_size, value_size=value_size)
self.out = nn.Linear(value_size * num_head, input_size) self.out = nn.Linear(value_size * num_head, input_size)
self.drop = TimestepDropout(dropout) self.drop = TimestepDropout(dropout)
self.reset_parameters() self.reset_parameters()
@@ -109,16 +109,30 @@ class MultiHeadAtte(nn.Module):
return output return output
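A usage sketch for the renamed module; the forward argument order (Q, K, V) is assumed::

    import torch

    mha = MultiHeadAttention(input_size=512, key_size=64, value_size=64,
                             num_head=8, dropout=0.1)
    x = torch.randn(2, 10, 512)
    out = mha(x, x, x)   # self-attention: [2, 10, 512]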




class Bi_Attention(nn.Module):
class BiAttention(nn.Module):
"""Bi Attention module
Calculate Bi Attention matrix `e`
.. math::
\begin{array}{ll} \\
e_{ij} = {a}^{\mathbf{T}}_{i}{b}_{j} \\
\hat{a}_{i} = \sum_{j=1}^{l_b}{\frac{\exp(e_{ij})}{\sum_{k=1}^{l_b}{\exp(e_{ik})}}}{b}_{j} \\
\hat{b}_{j} = \sum_{i=1}^{l_a}{\frac{\exp(e_{ij})}{\sum_{k=1}^{l_a}{\exp(e_{kj})}}}{a}_{i}
\end{array}
"""

def __init__(self): def __init__(self):
super(Bi_Attention, self).__init__()
super(BiAttention, self).__init__()
self.inf = 10e12 self.inf = 10e12


def forward(self, in_x1, in_x2, x1_len, x2_len): def forward(self, in_x1, in_x2, x1_len, x2_len):
# in_x1: [batch_size, x1_seq_len, hidden_size]
# in_x2: [batch_size, x2_seq_len, hidden_size]
# x1_len: [batch_size, x1_seq_len]
# x2_len: [batch_size, x2_seq_len]
"""
:param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
:param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] feature representation of the second sentence
:param torch.Tensor x1_len: [batch_size, x1_seq_len] 0/1 mask matrix of the first sentence
:param torch.Tensor x2_len: [batch_size, x2_seq_len] 0/1 mask matrix of the second sentence
:return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] the attended representation of the first sentence
torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] the attended representation of the second sentence
"""


assert in_x1.size()[0] == in_x2.size()[0] assert in_x1.size()[0] == in_x2.size()[0]
assert in_x1.size()[2] == in_x2.size()[2] assert in_x1.size()[2] == in_x2.size()[2]
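A shape-level usage sketch matching the docstring above (all tensors are illustrative)::

    import torch

    bi_att = BiAttention()
    x1, x2 = torch.randn(2, 7, 100), torch.randn(2, 9, 100)
    m1, m2 = torch.ones(2, 7), torch.ones(2, 9)   # 0/1 masks, 1 marks real tokens
    out1, out2 = bi_att(x1, x2, m1, m2)           # [2, 7, 100], [2, 9, 100]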


+ 2
- 2
fastNLP/modules/encoder/transformer.py View File

@@ -1,6 +1,6 @@
from torch import nn from torch import nn


from ..aggregator.attention import MultiHeadAtte
from ..aggregator.attention import MultiHeadAttention
from ..dropout import TimestepDropout from ..dropout import TimestepDropout




@@ -18,7 +18,7 @@ class TransformerEncoder(nn.Module):
class SubLayer(nn.Module): class SubLayer(nn.Module):
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
super(TransformerEncoder.SubLayer, self).__init__() super(TransformerEncoder.SubLayer, self).__init__()
self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout)
self.norm1 = nn.LayerNorm(model_size) self.norm1 = nn.LayerNorm(model_size)
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
nn.ReLU(), nn.ReLU(),
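The sublayer follows the post-norm Transformer pattern: attention and feed-forward are each wrapped in a residual connection plus LayerNorm. A condensed sketch; the residual wiring is assumed from the standard architecture since the hunk above is truncated::

    from torch import nn

    class SubLayerSketch(nn.Module):
        def __init__(self, model_size, inner_size, attention):
            super().__init__()
            self.atte = attention                      # e.g. a MultiHeadAttention instance
            self.norm1 = nn.LayerNorm(model_size)
            self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                     nn.ReLU(),
                                     nn.Linear(inner_size, model_size))
            self.norm2 = nn.LayerNorm(model_size)

        def forward(self, x):
            x = self.norm1(x + self.atte(x, x, x))     # residual around attention
            return self.norm2(x + self.ffn(x))         # residual around feed-forward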

