
Merge branch 'dev0.4.0' of github.com:fastnlp/fastNLP into dev0.4.0

tags/v0.4.10
yh_cc, 6 years ago
commit b80f018c69

64 changed files with 1167 additions and 941 deletions
  1. docs/source/fastNLP.models.base_model.rst (+0, -7)
  2. docs/source/fastNLP.models.bert.rst (+0, -7)
  3. docs/source/fastNLP.models.enas_controller.rst (+0, -7)
  4. docs/source/fastNLP.models.enas_model.rst (+0, -7)
  5. docs/source/fastNLP.models.enas_trainer.rst (+0, -7)
  6. docs/source/fastNLP.models.enas_utils.rst (+0, -7)
  7. docs/source/fastNLP.models.rst (+0, -6)
  8. docs/source/fastNLP.modules.decoder.crf.rst (+1, -1)
  9. docs/source/fastNLP.modules.decoder.mlp.rst (+1, -1)
  10. docs/source/fastNLP.modules.decoder.rst (+2, -2)
  11. fastNLP/__init__.py (+2, -2)
  12. fastNLP/core/batch.py (+11, -8)
  13. fastNLP/core/callback.py (+6, -2)
  14. fastNLP/core/dataset.py (+6, -3)
  15. fastNLP/core/field.py (+53, -45)
  16. fastNLP/core/instance.py (+3, -1)
  17. fastNLP/core/losses.py (+12, -1)
  18. fastNLP/core/metrics.py (+106, -98)
  19. fastNLP/core/optimizer.py (+14, -6)
  20. fastNLP/core/predictor.py (+6, -1)
  21. fastNLP/core/sampler.py (+7, -3)
  22. fastNLP/core/tester.py (+16, -12)
  23. fastNLP/core/trainer.py (+7, -5)
  24. fastNLP/core/utils.py (+86, -72)
  25. fastNLP/core/vocabulary.py (+16, -0)
  26. fastNLP/io/__init__.py (+2, -1)
  27. fastNLP/io/base_loader.py (+13, -6)
  28. fastNLP/io/config_io.py (+36, -28)
  29. fastNLP/io/dataset_loader.py (+1, -0)
  30. fastNLP/io/embed_loader.py (+30, -26)
  31. fastNLP/io/model_io.py (+10, -5)
  32. fastNLP/models/__init__.py (+18, -2)
  33. fastNLP/models/base_model.py (+6, -6)
  34. fastNLP/models/biaffine_parser.py (+87, -70)
  35. fastNLP/models/cnn_text_classification.py (+8, -7)
  36. fastNLP/models/enas_controller.py (+1, -0)
  37. fastNLP/models/enas_model.py (+71, -68)
  38. fastNLP/models/enas_trainer.py (+69, -72)
  39. fastNLP/models/enas_utils.py (+0, -3)
  40. fastNLP/models/sequence_labeling.py (+13, -5)
  41. fastNLP/models/snli.py (+33, -30)
  42. fastNLP/models/star_transformer.py (+41, -28)
  43. fastNLP/modules/__init__.py (+21, -15)
  44. fastNLP/modules/aggregator/__init__.py (+7, -7)
  45. fastNLP/modules/aggregator/attention.py (+36, -26)
  46. fastNLP/modules/aggregator/pooling.py (+7, -2)
  47. fastNLP/modules/decoder/__init__.py (+5, -5)
  48. fastNLP/modules/decoder/crf.py (+52, -47)
  49. fastNLP/modules/decoder/mlp.py (+15, -27)
  50. fastNLP/modules/decoder/utils.py (+13, -10)
  51. fastNLP/modules/dropout.py (+4, -2)
  52. fastNLP/modules/encoder/__init__.py (+25, -8)
  53. fastNLP/modules/encoder/char_encoder.py (+22, -9)
  54. fastNLP/modules/encoder/conv_maxpool.py (+15, -13)
  55. fastNLP/modules/encoder/embedding.py (+11, -7)
  56. fastNLP/modules/encoder/lstm.py (+8, -2)
  57. fastNLP/modules/encoder/star_transformer.py (+39, -32)
  58. fastNLP/modules/encoder/transformer.py (+10, -6)
  59. fastNLP/modules/utils.py (+1, -1)
  60. fastNLP/modules/encoder/variational_rnn.py (+71, -43)
  61. reproduction/Chinese_word_segmentation/models/cws_model.py (+3, -3)
  62. reproduction/Chinese_word_segmentation/models/cws_transformer.py (+2, -2)
  63. reproduction/LSTM+self_attention_sentiment_analysis/main.py (+1, -1)
  64. test/modules/decoder/test_CRF.py (+5, -5)

docs/source/fastNLP.models.base_model.rst (+0, -7)

@@ -1,7 +0,0 @@
-fastNLP.models.base\_model
-==========================
-
-.. automodule:: fastNLP.models.base_model
-    :members:
-    :undoc-members:
-    :show-inheritance:

docs/source/fastNLP.models.bert.rst (+0, -7)

@@ -1,7 +0,0 @@
-fastNLP.models.bert
-===================
-
-.. automodule:: fastNLP.models.bert
-    :members:
-    :undoc-members:
-    :show-inheritance:

docs/source/fastNLP.models.enas_controller.rst (+0, -7)

@@ -1,7 +0,0 @@
-fastNLP.models.enas\_controller
-===============================
-
-.. automodule:: fastNLP.models.enas_controller
-    :members:
-    :undoc-members:
-    :show-inheritance:

docs/source/fastNLP.models.enas_model.rst (+0, -7)

@@ -1,7 +0,0 @@
-fastNLP.models.enas\_model
-==========================
-
-.. automodule:: fastNLP.models.enas_model
-    :members:
-    :undoc-members:
-    :show-inheritance:

docs/source/fastNLP.models.enas_trainer.rst (+0, -7)

@@ -1,7 +0,0 @@
-fastNLP.models.enas\_trainer
-============================
-
-.. automodule:: fastNLP.models.enas_trainer
-    :members:
-    :undoc-members:
-    :show-inheritance:

docs/source/fastNLP.models.enas_utils.rst (+0, -7)

@@ -1,7 +0,0 @@
-fastNLP.models.enas\_utils
-==========================
-
-.. automodule:: fastNLP.models.enas_utils
-    :members:
-    :undoc-members:
-    :show-inheritance:

docs/source/fastNLP.models.rst (+0, -6)

@@ -12,14 +12,8 @@ fastNLP.models
 .. toctree::
     :titlesonly:

-    fastNLP.models.base_model
-    fastNLP.models.bert
     fastNLP.models.biaffine_parser
     fastNLP.models.cnn_text_classification
-    fastNLP.models.enas_controller
-    fastNLP.models.enas_model
-    fastNLP.models.enas_trainer
-    fastNLP.models.enas_utils
     fastNLP.models.sequence_labeling
     fastNLP.models.snli
     fastNLP.models.star_transformer


docs/source/fastNLP.modules.decoder.CRF.rst → docs/source/fastNLP.modules.decoder.crf.rst (+1, -1)

@@ -1,7 +1,7 @@
 fastNLP.modules.decoder.CRF
 ===========================

-.. automodule:: fastNLP.modules.decoder.CRF
+.. automodule:: fastNLP.modules.decoder.crf
     :members:
     :undoc-members:
     :show-inheritance:

docs/source/fastNLP.modules.decoder.MLP.rst → docs/source/fastNLP.modules.decoder.mlp.rst (+1, -1)

@@ -1,7 +1,7 @@
 fastNLP.modules.decoder.MLP
 ===========================

-.. automodule:: fastNLP.modules.decoder.MLP
+.. automodule:: fastNLP.modules.decoder.mlp
     :members:
     :undoc-members:
     :show-inheritance:

docs/source/fastNLP.modules.decoder.rst (+2, -2)

@@ -12,7 +12,7 @@ fastNLP.modules.decoder
 .. toctree::
     :titlesonly:

-    fastNLP.modules.decoder.CRF
-    fastNLP.modules.decoder.MLP
+    fastNLP.modules.decoder.crf
+    fastNLP.modules.decoder.mlp
     fastNLP.modules.decoder.utils



fastNLP/__init__.py (+2, -2)

@@ -52,8 +52,8 @@ __all__ = [
     "cache_results"
 ]
-__version__ = '0.4.0'
-
 from .core import *
 from . import models
 from . import modules
+
+__version__ = '0.4.0'

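The only substantive change in this file moves `__version__` below the star import. A plausible reading (my inference, not stated in the commit): assigning it last guarantees that nothing bound by `from .core import *` can shadow the package's own version string. A runnable sketch of the hazard, with a hypothetical module standing in for `fastNLP.core`:

```python
import sys
import types

# Hypothetical submodule that (contrived) star-exports a __version__ of its own.
core = types.ModuleType("core")
core.__version__ = "9.9.9"
core.__all__ = ["__version__"]   # star import honors __all__, dunders included
sys.modules["core"] = core

namespace = {}
exec("from core import *", namespace)   # what pkg/__init__.py effectively does
namespace["__version__"] = "0.4.0"      # assigned after the import, so it wins
print(namespace["__version__"])          # -> 0.4.0
```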
fastNLP/core/batch.py (+11, -8)

@@ -2,14 +2,18 @@
 batch 模块实现了 fastNLP 所需的 Batch 类。

 """
-__all__ = ["Batch"]
-import numpy as np
-import torch
+__all__ = [
+    "Batch"
+]

 import atexit
-from queue import Empty, Full
-
-from .sampler import RandomSampler, Sampler
+
+import numpy as np
+import torch
 import torch.multiprocessing as mp
+from queue import Empty, Full
+
+from .sampler import RandomSampler

 _python_is_exit = False

@@ -120,7 +124,7 @@ class Batch(object):
    [whitespace-only change before the @staticmethod decorator of _run_fetch]

@@ -145,7 +149,7 @@ class Batch(object):
    [whitespace-only change before the @staticmethod decorator of _run_batch_iter]

@@ -182,4 +186,3 @@ def _to_tensor(batch, dtype):
     except:
         pass
     return batch
-


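For orientation, a short sketch of how this `Batch` class is driven in this release. `DataSet`, `RandomSampler`, and the `Batch(dataset, batch_size, sampler)` call match the imports above; the field names are examples:

```python
from fastNLP import DataSet, Batch, RandomSampler

ds = DataSet({'x': [[1, 2], [3, 4, 5]], 'y': [0, 1]})
ds.set_input('x')   # fields marked as input end up in batch_x
ds.set_target('y')  # fields marked as target end up in batch_y

# Batch pairs a DataSet with a sampler and yields (batch_x, batch_y) dicts.
for batch_x, batch_y in Batch(ds, batch_size=2, sampler=RandomSampler()):
    print(batch_x['x'], batch_y['y'])
```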
fastNLP/core/callback.py (+6, -2)

@@ -60,16 +60,20 @@ __all__ = [
     "CallbackException",
     "EarlyStopError"
 ]

 import os

 import torch
-from ..io.model_io import ModelSaver, ModelLoader

 try:
     from tensorboardX import SummaryWriter
     tensorboardX_flag = True
 except:
     tensorboardX_flag = False

+from ..io.model_io import ModelSaver, ModelLoader
+

 class Callback(object):
     """

@@ -587,7 +591,7 @@ class TensorboardCallback(Callback):
    [whitespace-only change before def on_batch_begin]


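The reordering above groups the optional tensorboardX import and its availability flag ahead of the package-internal import. The idiom generalizes; a sketch of the pattern in isolation (narrowing the bare `except:` of the original to `except ImportError`, which is the safer form):

```python
# Optional-dependency guard: import once, record availability, let callers branch.
try:
    from tensorboardX import SummaryWriter  # heavy, optional dependency
    tensorboardX_flag = True
except ImportError:
    tensorboardX_flag = False

def make_summary_writer(path):
    # Returns a writer when tensorboardX is installed, else None (caller checks).
    return SummaryWriter(path) if tensorboardX_flag else None
```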
fastNLP/core/dataset.py (+6, -3)

@@ -272,11 +272,14 @@

 """
-__all__ = ["DataSet"]
+__all__ = [
+    "DataSet"
+]

 import _pickle as pickle
+import warnings

 import numpy as np
-import warnings

 from .field import AutoPadder
 from .field import FieldArray

@@ -863,4 +866,4 @@ class DataSet(object):
         with open(path, 'rb') as f:
             d = pickle.load(f)
             assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
-        return d
\ No newline at end of file
+        return d

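The second hunk only adds a trailing newline after `return d` in `DataSet.load`. For context, a sketch of the pickle-backed round trip that method supports; the `load` name matches the code shown, the rest is assumed from this version's API:

```python
from fastNLP import DataSet

ds = DataSet({'x': [[1, 2], [3, 4]], 'y': [0, 1]})
ds.save('dataset.pkl')             # pickles the DataSet to disk

ds2 = DataSet.load('dataset.pkl')  # unpickles and type-checks, as in the hunk above
assert isinstance(ds2, DataSet) and len(ds2) == 2
```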
fastNLP/core/field.py (+53, -45)

@@ -3,10 +3,16 @@ field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fas
 原理部分请参考 :doc:`fastNLP.core.dataset`

 """
+__all__ = [
+    "FieldArray",
+    "Padder",
+    "AutoPadder",
+    "EngChar2DPadder"
+]

+from copy import deepcopy
+
 import numpy as np
-from copy import deepcopy


 class FieldArray(object):

(All hunks between @@ -24,6 and @@ -486,10, covering FieldArray, Padder, AutoPadder and most of EngChar2DPadder, contain only whitespace normalization: trailing spaces stripped from blank lines and blank lines adjusted around methods.)

@@ -516,12 +524,12 @@ class EngChar2DPadder(Padder):
         max_sent_length = max(len(word_lst) for word_lst in contents)
         batch_size = len(contents)
         dtype = type(contents[0][0][0])

         padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
-                              dtype=dtype)
+                               dtype=dtype)
         for b_idx, word_lst in enumerate(contents):
             for c_idx, char_lst in enumerate(word_lst):
                 chars = char_lst[:max_char_length]
                 padded_array[b_idx, c_idx, :len(chars)] = chars
-        return padded_array
\ No newline at end of file
+        return padded_array

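The only behavioral content in this file's diff is the `EngChar2DPadder.__call__` body shown above; the rest is whitespace. As a self-contained illustration, here is that padding logic extracted into a standalone function. `max_char_length` is computed from the batch itself, matching the `pad_length=0` default described in the docstring (an assumption, since that line sits outside the hunk):

```python
import numpy as np

def pad_chars(contents, pad_val=0):
    """Pad a batch of [sentence][word][char-id] lists into a 3-D array."""
    max_char_length = max(len(char_lst) for word_lst in contents
                          for char_lst in word_lst)
    max_sent_length = max(len(word_lst) for word_lst in contents)
    batch_size = len(contents)
    dtype = type(contents[0][0][0])
    padded_array = np.full((batch_size, max_sent_length, max_char_length),
                           fill_value=pad_val, dtype=dtype)
    for b_idx, word_lst in enumerate(contents):
        for c_idx, char_lst in enumerate(word_lst):
            chars = char_lst[:max_char_length]
            padded_array[b_idx, c_idx, :len(chars)] = chars
    return padded_array

# Two "sentences" of per-word char ids -> shape (2, 2, 3), zero-padded.
print(pad_chars([[[1, 2], [3]], [[4, 5, 6], [7]]]))
```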
fastNLP/core/instance.py (+3, -1)

@@ -3,7 +3,9 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可
 便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格

 """
-__all__ = ["Instance"]
+__all__ = [
+    "Instance"
+]


 class Instance(object):


fastNLP/core/losses.py (+12, -1)

@@ -2,7 +2,18 @@
 losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。

 """
-__all__ = ["LossBase", "L1Loss", "LossFunc", "LossInForward", "BCELoss", "CrossEntropyLoss", "NLLLoss"]
+__all__ = [
+    "LossBase",
+    "LossFunc",
+    "LossInForward",
+    "CrossEntropyLoss",
+    "BCELoss",
+    "L1Loss",
+    "NLLLoss"
+]

 import inspect
 from collections import defaultdict



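These exports are loss wrappers that map model output names to loss arguments before being handed to the Trainer. A hedged sketch of typical use in this version; the keyword-mapping style is the documented pattern, the field names are examples:

```python
from fastNLP import CrossEntropyLoss, LossInForward

# Map the model's 'pred' output and the dataset's 'label' target to the loss.
loss = CrossEntropyLoss(pred='pred', target='label')

# Alternatively, LossInForward picks up a loss the model itself returns from forward().
loss_in_forward = LossInForward(loss_key='loss')
```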

fastNLP/core/metrics.py (+106, -98)

@@ -2,6 +2,13 @@
 metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 :class:`~fastNLP.Trainer` 的参数使用。

 """
+__all__ = [
+    "MetricBase",
+    "AccuracyMetric",
+    "SpanFPreRecMetric",
+    "SQuADMetric"
+]

 import inspect
 from collections import defaultdict

(The hunks from @@ -106,16 through @@ -350,7, covering MetricBase and AccuracyMetric, contain only whitespace normalization and continuation-line re-indentation.)

@@ -358,14 +367,14 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
         bmes_tag, label = tag[:1], tag[2:]
         if bmes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
+        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
             spans[-1][1][1] = idx
         else:
             spans.append((label, [idx, idx]))
         prev_bmes_tag = bmes_tag
-    return [(span[0], (span[1][0], span[1][1]+1))
-            for span in spans
-            if span[0] not in ignore_labels
+    return [(span[0], (span[1][0], span[1][1] + 1))
+            for span in spans
+            if span[0] not in ignore_labels
     ]

@@ -379,7 +388,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
    [whitespace-only change]

@@ -387,16 +396,16 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
         bmes_tag, label = tag[:1], tag[2:]
         if bmes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
+        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
             spans[-1][1][1] = idx
         elif bmes_tag == 'o':
             pass
         else:
             spans.append((label, [idx, idx]))
         prev_bmes_tag = bmes_tag
-    return [(span[0], (span[1][0], span[1][1]+1))
-            for span in spans
-            if span[0] not in ignore_labels
+    return [(span[0], (span[1][0], span[1][1] + 1))
+            for span in spans
+            if span[0] not in ignore_labels
     ]

@@ -410,7 +419,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
    [whitespace-only change]

@@ -418,14 +427,14 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
         bio_tag, label = tag[:1], tag[2:]
         if bio_tag == 'b':
             spans.append((label, [idx, idx]))
-        elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]:
+        elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
             spans[-1][1][1] = idx
-        elif bio_tag == 'o': # o tag does not count
+        elif bio_tag == 'o':  # o tag does not count
             pass
         else:
             spans.append((label, [idx, idx]))
         prev_bio_tag = bio_tag
-    return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels]
+    return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]




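The operator-spacing cleanups above leave the span-extraction logic unchanged. To make the BIO rule concrete, here it is as a standalone worked example; the `tag.lower()` normalization sits outside the hunk in the original, so I add it here as an assumption:

```python
def bio_tag_to_spans(tags, ignore_labels=None):
    # Standalone restatement of the rule in _bio_tag_to_spans above.
    ignore_labels = set(ignore_labels) if ignore_labels else set()
    spans = []
    prev_bio_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()                      # assumed normalization step
        bio_tag, label = tag[:1], tag[2:]
        if bio_tag == 'b':
            spans.append((label, [idx, idx]))  # start a new span
        elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
            spans[-1][1][1] = idx              # extend the current span
        elif bio_tag == 'o':  # o tag does not count
            pass
        else:
            spans.append((label, [idx, idx]))  # dangling 'i' starts a span too
        prev_bio_tag = bio_tag
    return [(span[0], (span[1][0], span[1][1] + 1))
            for span in spans if span[0] not in ignore_labels]

# End indices are exclusive: -> [('per', (0, 2)), ('loc', (3, 4))]
print(bio_tag_to_spans(['B-PER', 'I-PER', 'O', 'B-LOC']))
```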
 class SpanFPreRecMetric(MetricBase):

@@ -470,16 +479,17 @@
    [whitespace-only changes: blank line added and the __init__ signature's continuation line re-indented]

@@ -489,22 +499,22 @@ class SpanFPreRecMetric(MetricBase):
         self.ignore_labels = ignore_labels
         self.f_type = f_type
         self.beta = beta
-        self.beta_square = self.beta**2
+        self.beta_square = self.beta ** 2
         self.only_gross = only_gross

         super().__init__()
         self._init_param_map(pred=pred, target=target, seq_len=seq_len)

         self.tag_vocab = tag_vocab

         self._true_positives = defaultdict(int)
         self._false_positives = defaultdict(int)
         self._false_negatives = defaultdict(int)

(The hunks at @@ -519,11, @@ -536,20 and @@ -558,7 are whitespace-only.)

@@ -577,19 +587,19 @@ class SpanFPreRecMetric(MetricBase):
             f_sum += f
             pre_sum += pre
             rec_sum + rec
-            if not self.only_gross and tag!='':  # tag!=''防止无tag的情况
+            if not self.only_gross and tag != '':  # tag!=''防止无tag的情况
                 f_key = 'f-{}'.format(tag)
                 pre_key = 'pre-{}'.format(tag)
                 rec_key = 'rec-{}'.format(tag)
                 evaluate_result[f_key] = f
                 evaluate_result[pre_key] = pre
                 evaluate_result[rec_key] = rec

         if self.f_type == 'macro':
-            evaluate_result['f'] = f_sum/len(tags)
-            evaluate_result['pre'] = pre_sum/len(tags)
-            evaluate_result['rec'] = rec_sum/len(tags)
+            evaluate_result['f'] = f_sum / len(tags)
+            evaluate_result['pre'] = pre_sum / len(tags)
+            evaluate_result['rec'] = rec_sum / len(tags)
         if self.f_type == 'micro':
             f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
                                                   sum(self._false_negatives.values()),

@@ -597,17 +607,17 @@ class SpanFPreRecMetric(MetricBase):
    [whitespace-only changes]

@@ -619,11 +629,10 @@ class SpanFPreRecMetric(MetricBase):
         pre = tp / (fp + tp + 1e-13)
         rec = tp / (fn + tp + 1e-13)
         f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)
         return f, pre, rec
-


 def _prepare_metrics(metrics):
     """


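The formula in `_compute_f_pre_rec` above is the standard smoothed F_beta, with the 1e-13 terms guarding against division by zero. A quick numeric check of that code path with the default beta=1 (so beta_square is 1):

```python
# tp=8, fn=2, fp=4 -> precision 8/12, recall 8/10, F1 ~ 0.7273.
tp, fn, fp = 8, 2, 4
beta_square = 1  # beta ** 2 with beta=1
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
print(round(pre, 4), round(rec, 4), round(f, 4))  # 0.6667 0.8 0.7273
```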
@@ -705,33 +714,33 @@ class SQuADMetric(MetricBase):
    [whitespace-only changes in the docstring and __init__ of SQuADMetric]

@@ -745,7 +754,7 @@ class SQuADMetric(MetricBase):
    [whitespace-only change]

@@ -754,12 +763,12 @@ class SQuADMetric(MetricBase):
    [whitespace-only changes]

@@ -779,7 +788,7 @@ class SQuADMetric(MetricBase):
    [whitespace-only change]

@@ -787,29 +796,29 @@ class SQuADMetric(MetricBase):
                 a = [0] * s + [1] * (e - s) + [0] * (max_len - e)
                 b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te)
                 a, b = torch.tensor(a), torch.tensor(b)

                 TP = int(torch.sum(a * b))
                 pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0
                 rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0

                 if pre + rec > 0:
-                    f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec)
+                    f = (1 + (self.f_beta ** 2)) * pre * rec / ((self.f_beta ** 2) * pre + rec)
                 else:
                     f = 0
                 self.has_ans_f += f

     def get_metric(self, reset=True):
         """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
         evaluate_result = {}
         if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0:
             return evaluate_result

         evaluate_result['EM'] = 0
         evaluate_result[f'f_{self.f_beta}'] = 0
         flag = 0

         if self.no_ans_correct + self.no_ans_wrong > 0:
             evaluate_result[f'noAns-f_{self.f_beta}'] = \
                 round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3)

@@ -818,7 +827,7 @@ class SQuADMetric(MetricBase):
    [whitespace-only change]

@@ -827,32 +836,31 @@ class SQuADMetric(MetricBase):
    [whitespace-only changes; one trailing blank line removed at the end of get_metric]


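The `f_beta ** 2` spacing fix above does not change the span-overlap F computation. Extracted as a standalone check, with a predicted span [2, 5) against a gold span [3, 6) over 8 tokens:

```python
import torch

s, e, ts, te, max_len = 2, 5, 3, 6, 8                   # pred and gold boundaries
a = [0] * s + [1] * (e - s) + [0] * (max_len - e)       # predicted-span token mask
b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te)   # gold-span token mask
a, b = torch.tensor(a), torch.tensor(b)

TP = int(torch.sum(a * b))                              # tokens in both spans: 2
pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0
rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0
f_beta = 1
f = (1 + f_beta ** 2) * pre * rec / ((f_beta ** 2) * pre + rec) if pre + rec > 0 else 0
print(round(f, 4))                                      # 0.6667
```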
fastNLP/core/optimizer.py (+14, -6)

@@ -2,6 +2,12 @@
 optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。

 """
+__all__ = [
+    "Optimizer",
+    "SGD",
+    "Adam"
+]

 import torch

@@ -12,15 +18,16 @@ class Optimizer(object):
    [whitespace-only changes around the methods of Optimizer]

@@ -29,6 +36,7 @@ class Optimizer(object):
    [blank line added before class SGD]

@@ -37,12 +45,12 @@ class SGD(Optimizer):
    [whitespace-only changes]

@@ -59,13 +67,13 @@ class Adam(Optimizer):
    [whitespace-only changes]

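A usage sketch for the optimizers above (not part of the diff; it assumes the top-level re-exports): the fastNLP wrappers only hold hyperparameters, and a real torch.optim optimizer is materialized once construct_from_pytorch() receives the model parameters, which is what Trainer does internally.

import torch.nn as nn
from fastNLP import SGD

model = nn.Linear(10, 2)
# the wrapper stores lr/momentum; construct_from_pytorch() builds the torch.optim.SGD
sgd = SGD(lr=0.01, momentum=0.9)
torch_sgd = sgd.construct_from_pytorch(model.parameters())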

+ 6 - 1  fastNLP/core/predictor.py

@@ -1,3 +1,7 @@
+"""
+.. todo::
+    Check whether this class is still needed.
+"""
from collections import defaultdict

import torch
@@ -9,7 +13,8 @@ from .utils import _build_args


class Predictor(object):
-    """An interface for predicting outputs based on trained models.
+    """
+    An interface for predicting outputs based on trained models.

    It does not care about evaluations of the model, which is different from Tester.
    This is a high-level model wrapper to be called by FastNLP.


+ 7 - 3  fastNLP/core/sampler.py

@@ -1,9 +1,13 @@
"""
The sampler subclasses implement the samplers fastNLP needs.

"""
-__all__ = ["Sampler", "BucketSampler", "SequentialSampler", "RandomSampler"]
+__all__ = [
+    "Sampler",
+    "BucketSampler",
+    "SequentialSampler",
+    "RandomSampler"
+]

from itertools import chain

import numpy as np

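A sketch of how these samplers are consumed (the Batch signature is assumed from the 0.4-era API; samplers can also be handed to Trainer):

from fastNLP import DataSet, RandomSampler
from fastNLP.core.batch import Batch

ds = DataSet({"x": [[1, 2], [3, 4], [5, 6]], "y": [0, 1, 0]})
ds.set_input("x")
ds.set_target("y")

# iterate in random order, two instances per batch
for batch_x, batch_y in Batch(ds, batch_size=2, sampler=RandomSampler()):
    print(batch_x, batch_y)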

+ 16 - 12  fastNLP/core/tester.py

@@ -35,7 +35,7 @@ Tester calls model.eval() before validation to signal that evaluation has started
import warnings

import torch
-from torch import nn
+import torch.nn as nn

from .batch import Batch
from .dataset import DataSet
@@ -49,6 +49,10 @@ from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device

+__all__ = [
+    "Tester"
+]
+

class Tester(object):
""" """
@@ -77,29 +81,29 @@ class Tester(object):
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
""" """
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
super(Tester, self).__init__() super(Tester, self).__init__()
if not isinstance(data, DataSet): if not isinstance(data, DataSet):
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
if not isinstance(model, nn.Module): if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
self.metrics = _prepare_metrics(metrics) self.metrics = _prepare_metrics(metrics)
self.data = data self.data = data
self._model = _move_model_to_device(model, device=device) self._model = _move_model_to_device(model, device=device)
self.batch_size = batch_size self.batch_size = batch_size
self.verbose = verbose self.verbose = verbose
# 如果是DataParallel将没有办法使用predict方法 # 如果是DataParallel将没有办法使用predict方法
if isinstance(self._model, nn.DataParallel): if isinstance(self._model, nn.DataParallel):
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'):
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function,"
" while DataParallel has no predict() function.") " while DataParallel has no predict() function.")
self._model = self._model.module self._model = self._model.module
# check predict # check predict
if hasattr(self._model, 'predict'): if hasattr(self._model, 'predict'):
self._predict_func = self._model.predict self._predict_func = self._model.predict
@@ -109,7 +113,7 @@ class Tester(object):
f"for evaluation, not `{type(self._predict_func)}`.") f"for evaluation, not `{type(self._predict_func)}`.")
else: else:
self._predict_func = self._model.forward self._predict_func = self._model.forward
def test(self): def test(self):
"""开始进行验证,并返回验证结果。 """开始进行验证,并返回验证结果。


@@ -144,12 +148,12 @@ class Tester(object):
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0) dataset=self.data, check_level=0)
if self.verbose >= 1: if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results))) print("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False) self._mode(network, is_test=False)
return eval_results return eval_results
def _mode(self, model, is_test=False): def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently. """Train mode or Test mode. This is for PyTorch currently.


@@ -161,13 +165,13 @@ class Tester(object):
model.eval() model.eval()
else: else:
model.train() model.train()
def _data_forward(self, func, x): def _data_forward(self, func, x):
"""A forward pass of the model. """ """A forward pass of the model. """
x = _build_args(func, **x) x = _build_args(func, **x)
y = func(**x) y = func(**x)
return y return y
def _format_eval_results(self, results): def _format_eval_results(self, results):
"""Override this method to support more print formats. """Override this method to support more print formats.



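A minimal sketch of the Tester above (AccuracyMetric stands in for any metric; dev_data and model are assumed to exist):

from fastNLP import Tester, AccuracyMetric

tester = Tester(data=dev_data, model=model, metrics=AccuracyMetric(),
                batch_size=16, device=None, verbose=1)
eval_results = tester.test()  # dict of metric results, printed when verbose >= 1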

+ 7 - 5  fastNLP/core/trainer.py

@@ -295,15 +295,17 @@ Example2.3
fastNLP ships with many ready-made callback functions; see :doc:`fastNLP.core.callback`.

"""
+__all__ = [
+    "Trainer"
+]
+
import os
import time
-from datetime import datetime
-from datetime import timedelta
+from datetime import datetime, timedelta

import numpy as np
import torch
-from torch import nn
+import torch.nn as nn

try:
    from tqdm.auto import tqdm
@@ -315,6 +317,7 @@ from .callback import CallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .metrics import _prepare_metrics
+from .optimizer import Optimizer
from .sampler import Sampler
from .sampler import RandomSampler
from .sampler import SequentialSampler
@@ -326,7 +329,6 @@ from .utils import _check_loss_evaluate
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from .utils import _get_model_device
-from .optimizer import Optimizer
from .utils import _move_model_to_device




@@ -464,7 +466,7 @@ class Trainer(object):
                                                  len(self.train_data) % self.batch_size != 0)) * self.n_epochs
        self.model = _move_model_to_device(self.model, device=device)

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, Optimizer):

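The type dispatch above means both optimizer forms below should be accepted (a sketch; the other Trainer arguments are elided):

import torch
from fastNLP import Trainer, Adam

# 1) a ready-built torch.optim optimizer
trainer = Trainer(train_data=train_data, model=model,
                  optimizer=torch.optim.Adam(model.parameters(), lr=1e-3))

# 2) a fastNLP wrapper, materialized internally via construct_from_pytorch()
trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=1e-3))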

+ 86 - 72  fastNLP/core/utils.py

@@ -1,20 +1,25 @@
"""
The utils module implements many utilities used inside and outside fastNLP. The one meant for users is the :func:`cache_results` decorator.
"""
-__all__ = ["cache_results", "seq_len_to_mask"]
+__all__ = [
+    "cache_results",
+    "seq_len_to_mask"
+]

import _pickle
import inspect
import os
import warnings
-from collections import Counter
-from collections import namedtuple
+from collections import Counter, namedtuple

import numpy as np
import torch
-from torch import nn
+import torch.nn as nn

_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                     'varargs'])


def _prepare_cache_filepath(filepath):
    """
@@ -40,26 +45,28 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
        import time
        import numpy as np
        from fastNLP import cache_results

        @cache_results('cache.pkl')
        def process_data():
            # some time-consuming work such as reading and preprocessing data; time.sleep() stands in for the cost
            time.sleep(1)
-            return np.random.randint(5, size=(10, 20))
+            return np.random.randint(10, size=(5,))

        start_time = time.time()
-        process_data()
+        print("res =",process_data())
        print(time.time() - start_time)

        start_time = time.time()
-        process_data()
+        print("res =",process_data())
        print(time.time() - start_time)

-        # the output looks like:
-        # Save cache to cache.pkl.
-        # 1.0015439987182617
-        # Read cache from cache.pkl.
-        # 0.00013065338134765625
+        # the output looks like the following; both runs return the same result, and the second run takes almost no time
+        # Save cache to cache.pkl.
+        # res = [5 4 9 1 8]
+        # 1.0042750835418701
+        # Read cache from cache.pkl.
+        # res = [5 4 9 1 8]
+        # 0.0040721893310546875

    The second run takes almost no time because it reads the result straight from cache.pkl instead of re-running the preprocessing.


@@ -83,11 +90,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
    :param int _verbose: whether to print cache information.
    :return:
    """

    def wrapper_(func):
        signature = inspect.signature(func)
        for key, _ in signature.parameters.items():
            if key in ('_cache_fp', '_refresh', '_verbose'):
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

        def wrapper(*args, **kwargs):
            if '_cache_fp' in kwargs:
                cache_filepath = kwargs.pop('_cache_fp')
@@ -95,7 +104,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
            else:
                cache_filepath = _cache_fp
            if '_refresh' in kwargs:
                refresh = kwargs.pop('_refresh')
                assert isinstance(refresh, bool), "_refresh can only be bool."
            else:
                refresh = _refresh
@@ -105,16 +114,16 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
            else:
                verbose = _verbose

            refresh_flag = True

            if cache_filepath is not None and refresh is False:
                # load data
                if os.path.exists(cache_filepath):
                    with open(cache_filepath, 'rb') as f:
                        results = _pickle.load(f)
-                    if verbose==1:
+                    if verbose == 1:
                        print("Read cache from {}.".format(cache_filepath))
                    refresh_flag = False

            if refresh_flag:
                results = func(*args, **kwargs)
                if cache_filepath is not None:
@@ -124,11 +133,14 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
                    with open(cache_filepath, 'wb') as f:
                        _pickle.dump(results, f)
                    print("Save cache to {}.".format(cache_filepath))

            return results

        return wrapper

    return wrapper_

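As the wrapper above shows, the decorated function also accepts _cache_fp, _refresh and _verbose at call time; a quick sketch reusing the process_data example:

# recompute and write to a different cache file, silently
res = process_data(_cache_fp='cache2.pkl', _refresh=True, _verbose=0)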


# def save_pickle(obj, pickle_path, file_name):
#     """Save an object into a pickle file.
#
@@ -196,7 +208,7 @@ def _move_model_to_device(model, device):
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")

    if device is None:
        if isinstance(model, torch.nn.DataParallel):
            model.cuda()
@@ -205,34 +217,35 @@ def _move_model_to_device(model, device):
        if not torch.cuda.is_available() and (
                device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
            raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")

        if isinstance(model, torch.nn.DataParallel):
            raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")

        if isinstance(device, int):
-            assert device>-1, "device can only be non-negative integer"
-            assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
-                                                                                                      device)
+            assert device > -1, "device can only be non-negative integer"
+            assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format(
+                torch.cuda.device_count(),
+                device)
            device = torch.device('cuda:{}'.format(device))
        elif isinstance(device, str):
            device = torch.device(device)
            if device.type == 'cuda' and device.index is not None:
-                assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
-                    torch.cuda.device_count(),
-                    device)
+                assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
+                    torch.cuda.device_count(),
+                    device)
        elif isinstance(device, torch.device):
            if device.type == 'cuda' and device.index is not None:
-                assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
-                    torch.cuda.device_count(),
-                    device)
+                assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
+                    torch.cuda.device_count(),
+                    device)
        elif isinstance(device, list):
            types = set([type(d) for d in device])
-            assert len(types)==1, "Mixed type in device, only `int` allowed."
+            assert len(types) == 1, "Mixed type in device, only `int` allowed."
            assert list(types)[0] == int, "Only int supported for multiple devices."
-            assert len(set(device))==len(device), "Duplicated device id found in device."
+            assert len(set(device)) == len(device), "Duplicated device id found in device."
            for d in device:
-                assert d>-1, "Only non-negative device id allowed."
-            if len(device)>1:
+                assert d > -1, "Only non-negative device id allowed."
+            if len(device) > 1:
                output_device = device[0]
                model = nn.DataParallel(model, device_ids=device, output_device=output_device)
            device = torch.device(device[0])
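For reference, the device forms accepted by the code above, as equivalent calls (_move_model_to_device is internal; Trainer and Tester expose the same device argument):

model = _move_model_to_device(model, device=None)      # keep the current placement
model = _move_model_to_device(model, device=0)         # first GPU
model = _move_model_to_device(model, device='cuda:1')  # device string
model = _move_model_to_device(model, device=[0, 1])    # multi-GPU via nn.DataParallel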
@@ -250,9 +263,9 @@ def _get_model_device(model):
    :return: torch.device or None; None means the model has no parameters at all.
    """
    assert isinstance(model, nn.Module)

    parameters = list(model.parameters())
-    if len(parameters)==0:
+    if len(parameters) == 0:
        return None
    else:
        return parameters[0].device
@@ -407,7 +420,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
    if not isinstance(device, torch.device):
        raise TypeError(f"device must be `torch.device`, got `{type(device)}`")

    for arg in args:
        if isinstance(arg, dict):
            for key, value in arg.items():
@@ -422,10 +435,10 @@ class _CheckError(Exception):

    _CheckError. Used in losses.LossBase, metrics.MetricBase.
    """

    def __init__(self, check_res: _CheckRes, func_signature: str):
        errs = [f'Problems occurred when calling `{func_signature}`']

        if check_res.varargs:
            errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
        if check_res.missing:
@@ -434,9 +447,9 @@ class _CheckError(Exception):
            errs.append(f"\tduplicated param: {check_res.duplicated}")
        if check_res.unused:
            errs.append(f"\tunused param: {check_res.unused}")

        Exception.__init__(self, '\n'.join(errs))
        self.check_res = check_res
        self.func_signature = func_signature


@@ -456,7 +469,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
    # if check_res.varargs:
    #     errs.append(f"\tvarargs: *{check_res.varargs}")
    #     suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")

    if check_res.unused:
        for _unused in check_res.unused:
            if _unused in target_dict:
@@ -466,8 +479,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
        if _unused_field:
            unuseds.append(f"\tunused field: {_unused_field}")
        if _unused_param:
            unuseds.append(f"\tunused param: {_unused_param}")  # output from predict or forward

    module_name = func_signature.split('.')[0]
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}")
@@ -488,14 +501,14 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
                mapped_missing.append(_miss)
            else:
                unmapped_missing.append(_miss)

        for _miss in mapped_missing + unmapped_missing:
            if _miss in dataset:
                suggestions.append(f"Set `{_miss}` as target.")
            else:
                _tmp = ''
                if check_res.unused:
-                    _tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
+                    _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
                if _tmp:
                    _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
                else:
@@ -513,25 +526,25 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
    #         else:
    #             _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.'
    #         suggestions.append(_tmp)

    if check_res.duplicated:
        errs.append(f"\tduplicated param: {check_res.duplicated}.")
        suggestions.append(f"Delete {check_res.duplicated} in the output of "
                           f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")

-    if len(errs)>0:
+    if len(errs) > 0:
        errs.extend(unuseds)
    elif check_level == STRICT_CHECK_LEVEL:
        errs.extend(unuseds)

    if len(errs) > 0:
        errs.insert(0, f'Problems occurred when calling {func_signature}')
        sugg_str = ""
        if len(suggestions) > 1:
            for idx, sugg in enumerate(suggestions):
-                if idx>0:
+                if idx > 0:
                    sugg_str += '\t\t\t'
-                sugg_str += f'({idx+1}). {sugg}\n'
+                sugg_str += f'({idx + 1}). {sugg}\n'
            sugg_str = sugg_str[:-1]
        else:
            sugg_str += suggestions[0]
@@ -546,14 +559,15 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
        _unused_warn = f'{check_res.unused} is not used by {module_name}.'
        warnings.warn(message=_unused_warn)



def _check_forward_error(forward_func, batch_x, dataset, check_level):
    check_res = _check_arg_dict_list(forward_func, batch_x)
    func_signature = _get_func_signature(forward_func)

    errs = []
    suggestions = []
    _unused = []

    # if check_res.varargs:
    #     errs.append(f"\tvarargs: {check_res.varargs}")
    #     suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
@@ -574,20 +588,20 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
            # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
            #         f"rename the field in `unused field:`."
            suggestions.append(_tmp)

    if check_res.unused:
        _unused = [f"\tunused field: {check_res.unused}"]
-        if len(errs)>0:
+        if len(errs) > 0:
            errs.extend(_unused)
        elif check_level == STRICT_CHECK_LEVEL:
            errs.extend(_unused)

    if len(errs) > 0:
        errs.insert(0, f'Problems occurred when calling {func_signature}')
        sugg_str = ""
        if len(suggestions) > 1:
            for idx, sugg in enumerate(suggestions):
-                sugg_str += f'({idx+1}). {sugg}'
+                sugg_str += f'({idx + 1}). {sugg}'
        else:
            sugg_str += suggestions[0]
        err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
@@ -622,8 +636,8 @@ def seq_len_to_mask(seq_len):
        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
        max_len = int(seq_len.max())
        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
-        mask = broad_cast_seq_len<seq_len.reshape(-1, 1)
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
    elif isinstance(seq_len, torch.Tensor):
        assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
        batch_size = seq_len.size(0)
@@ -632,7 +646,7 @@ def seq_len_to_mask(seq_len):
        mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
    else:
        raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")

    return mask


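A quick sketch of the mask semantics above, assuming the top-level re-export: entry (i, j) is True iff j < seq_len[i].

import numpy as np
from fastNLP import seq_len_to_mask

mask = seq_len_to_mask(np.array([3, 1, 2]))
# array([[ True,  True,  True],
#        [ True, False, False],
#        [ True,  True, False]])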


@@ -640,24 +654,24 @@ class _pseudo_tqdm:
    """
    Used to print progress when tqdm cannot be imported, or when Trainer sets use_tqdm to False.
    """

    def __init__(self, **kwargs):
        pass

    def write(self, info):
        print(info)

    def set_postfix_str(self, info):
        print(info)

    def __getattr__(self, item):
        def pass_func(*args, **kwargs):
            pass

        return pass_func

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        del self

+ 16 - 0  fastNLP/core/vocabulary.py

@@ -1,5 +1,10 @@
+__all__ = [
+    "Vocabulary"
+]
+
from functools import wraps
from collections import Counter

from .dataset import DataSet
@@ -318,6 +323,17 @@ class Vocabulary(object):
        """
        return self.idx2word[idx]

+    def clear(self):
+        """
+        Clear the vocabulary data; equivalent to re-initializing the object.
+
+        :return:
+        """
+        self.word_count.clear()
+        self.word2idx = None
+        self.idx2word = None
+        self.rebuild = True
+
    def __getstate__(self):
        """Use to prepare data for pickle.



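A usage sketch for the new clear() method (add_word and to_index are existing Vocabulary methods seen elsewhere in this diff):

from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word("apple")
idx = vocab.to_index("apple")

vocab.clear()             # drop the counts and both index maps
vocab.add_word("banana")  # the vocabulary can be rebuilt from scratch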

+ 2 - 1  fastNLP/io/__init__.py

@@ -24,7 +24,8 @@ __all__ = [
    'ModelLoader',
    'ModelSaver',
]

from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
    PeopleDailyCorpusLoader, Conll2003Loader
-from .model_io import ModelLoader as ModelLoader, ModelSaver as ModelSaver
+from .model_io import ModelLoader, ModelSaver

+ 13 - 6  fastNLP/io/base_loader.py

@@ -1,3 +1,7 @@
+__all__ = [
+    "BaseLoader"
+]
+
import _pickle as pickle
import os

@@ -7,9 +11,10 @@ class BaseLoader(object):
    Base class for the various Loaders; serves as the API reference.

    """

    def __init__(self):
        super(BaseLoader, self).__init__()

    @staticmethod
    def load_lines(data_path):
        """
@@ -20,7 +25,7 @@ class BaseLoader(object):
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.readlines()
        return [line.strip() for line in text]

    @classmethod
    def load(cls, data_path):
        """
@@ -31,7 +36,7 @@ class BaseLoader(object):
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.readlines()
        return [[word for word in sent.strip()] for sent in text]

    @classmethod
    def load_with_cache(cls, data_path, cache_path):
        """Cached version of load
@@ -48,16 +53,18 @@ class BaseLoader(object):

class DataLoaderRegister:
    _readers = {}

    @classmethod
    def set_reader(cls, reader_cls, read_fn_name):
        # def wrapper(reader_cls):
        if read_fn_name in cls._readers:
-            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
+            raise KeyError(
+                'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls,
+                                                                       read_fn_name))
        if hasattr(reader_cls, 'load'):
            cls._readers[read_fn_name] = reader_cls().load
        return reader_cls

    @classmethod
    def get_reader(cls, read_fn_name):
        if read_fn_name in cls._readers:


+ 36 - 28  fastNLP/io/config_io.py

@@ -1,8 +1,14 @@
"""
For reading, processing, and saving config files.

+.. todo::
+    The classes in this module may be deprecated?
"""
-__all__ = ["ConfigLoader","ConfigSection","ConfigSaver"]
+__all__ = [
+    "ConfigLoader",
+    "ConfigSection",
+    "ConfigSaver"
+]

import configparser
import json
import os
@@ -19,15 +25,16 @@ class ConfigLoader(BaseLoader):
    :param str data_path: path to the config file

    """

    def __init__(self, data_path=None):
        super(ConfigLoader, self).__init__()
        if data_path is not None:
            self.config = self.parse(super(ConfigLoader, self).load(data_path))

    @staticmethod
    def parse(string):
        raise NotImplementedError

    @staticmethod
    def load_config(file_path, sections):
        """
@@ -81,10 +88,10 @@ class ConfigSection(object):
    ConfigSection stores all key-value pairs of one section; instances of this class are meant to be used together with :meth:`ConfigLoader.load_config`.

    """

    def __init__(self):
        super(ConfigSection, self).__init__()

    def __getitem__(self, key):
        """
        :param key: str, the name of the attribute
@@ -97,7 +104,7 @@ class ConfigSection(object):
        if key in self.__dict__.keys():
            return getattr(self, key)
        raise AttributeError("do NOT have attribute %s" % key)

    def __setitem__(self, key, value):
        """
        :param key: str, the name of the attribute
@@ -112,14 +119,14 @@ class ConfigSection(object):
            raise AttributeError("attr %s except %s but got %s" %
                                 (key, str(type(getattr(self, key))), str(type(value))))
        setattr(self, key, value)

    def __contains__(self, item):
        """
        :param item: The key of item.
        :return: True if the key in self.__dict__.keys() else False.
        """
        return item in self.__dict__.keys()

    def __eq__(self, other):
        """Overwrite the == operator

@@ -131,15 +138,15 @@ class ConfigSection(object):
                return False
            if getattr(self, k) != getattr(self, k):
                return False

        for k in other.__dict__.keys():
            if k not in self.__dict__.keys():
                return False
            if getattr(self, k) != getattr(self, k):
                return False

        return True

    def __ne__(self, other):
        """Overwrite the != operator

@@ -147,7 +154,7 @@ class ConfigSection(object):
        :return:
        """
        return not self.__eq__(other)

    @property
    def data(self):
        return self.__dict__
@@ -162,11 +169,12 @@ class ConfigSaver(object):
    :param str file_path: path to the config file

    """

    def __init__(self, file_path):
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            raise FileNotFoundError("file {} NOT found!".__format__(self.file_path))

    def _get_section(self, sect_name):
        """
        This is the function to get the section with the section name.
@@ -177,7 +185,7 @@ class ConfigSaver(object):
        sect = ConfigSection()
        ConfigLoader().load_config(self.file_path, {sect_name: sect})
        return sect

    def _read_section(self):
        """
        This is the function to read sections from the config file.
@@ -187,16 +195,16 @@ class ConfigSaver(object):
            sect_key_list: A list of names in sect_list.
        """
        sect_name = None

        sect_list = {}
        sect_key_list = []

        single_section = {}
        single_section_key = []

        with open(self.file_path, 'r') as f:
            lines = f.readlines()

        for line in lines:
            if line.startswith('[') and line.endswith(']\n'):
                if sect_name is None:
@@ -208,29 +216,29 @@ class ConfigSaver(object):
                    sect_key_list.append(sect_name)
                sect_name = line[1: -2]
                continue

            if line.startswith('#'):
                single_section[line] = '#'
                single_section_key.append(line)
                continue

            if line.startswith('\n'):
                single_section_key.append('\n')
                continue

            if '=' not in line:
                raise RuntimeError("can NOT load config file {}".__format__(self.file_path))

            key = line.split('=', maxsplit=1)[0].strip()
            value = line.split('=', maxsplit=1)[1].strip() + '\n'
            single_section[key] = value
            single_section_key.append(key)

        if sect_name is not None:
            sect_list[sect_name] = single_section, single_section_key
            sect_key_list.append(sect_name)
        return sect_list, sect_key_list

    def _write_section(self, sect_list, sect_key_list):
        """
        This is the function to write config file with section list and name list.
@@ -252,7 +260,7 @@ class ConfigSaver(object):
                    continue
                f.write(key + ' = ' + single_section[key])
                f.write('\n')

    def save_config_file(self, section_name, section):
        """
        This method modifies and saves a single section of the config file.
@@ -284,11 +292,11 @@ class ConfigSaver(object):
                break

        if not change_file:
            return

        sect_list, sect_key_list = self._read_section()
        if section_name not in sect_key_list:
            raise AttributeError()

        sect, sect_key = sect_list[section_name]
        for k in section.__dict__.keys():
            if k not in sect_key:

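A usage sketch of the pattern visible above (load_config fills the ConfigSection objects passed to it; "train" and "epochs" are hypothetical names):

from fastNLP.io import ConfigLoader, ConfigSection

train_args = ConfigSection()
ConfigLoader().load_config("config.cfg", {"train": train_args})
print(train_args["epochs"])  # section keys are exposed via __getitem__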

+ 1 - 0  fastNLP/io/dataset_loader.py

@@ -20,6 +20,7 @@ __all__ = [
    'PeopleDailyCorpusLoader',
    'Conll2003Loader',
]

from nltk.tree import Tree

from ..core.dataset import DataSet


+ 30 - 26  fastNLP/io/embed_loader.py

@@ -1,11 +1,15 @@
+__all__ = [
+    "EmbedLoader"
+]
+
import os
+import warnings

import numpy as np

from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader

-import warnings


class EmbedLoader(BaseLoader):
    """
@@ -13,10 +17,10 @@ class EmbedLoader(BaseLoader):

    Reads pre-trained embeddings; the result can be loaded directly as model parameters.
    """

    def __init__(self):
        super(EmbedLoader, self).__init__()

    @staticmethod
    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
        """
@@ -40,11 +44,11 @@ class EmbedLoader(BaseLoader):
            line = f.readline().strip()
            parts = line.split()
            start_idx = 0
-            if len(parts)==2:
+            if len(parts) == 2:
                dim = int(parts[1])
                start_idx += 1
            else:
-                dim = len(parts)-1
+                dim = len(parts) - 1
                f.seek(0)
            matrix = np.random.randn(len(vocab), dim).astype(dtype)
            for idx, line in enumerate(f, start_idx):
@@ -63,21 +67,21 @@ class EmbedLoader(BaseLoader):
            total_hits = sum(hit_flags)
            print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
            found_vectors = matrix[hit_flags]
-            if len(found_vectors)!=0:
+            if len(found_vectors) != 0:
                mean = np.mean(found_vectors, axis=0, keepdims=True)
                std = np.std(found_vectors, axis=0, keepdims=True)
                unfound_vec_num = len(vocab) - total_hits
-                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype)*std + mean
-                matrix[hit_flags==False] = r_vecs
+                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
+                matrix[hit_flags == False] = r_vecs

            if normalize:
                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

            return matrix

    @staticmethod
    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
                           error='ignore'):
        """
        Reads pre-trained word vectors from embed_filepath, builds the embedding from the pre-trained vocabulary, and generates a matching Vocabulary.

@@ -96,35 +100,35 @@ class EmbedLoader(BaseLoader):
        vec_dict = {}
        found_unknown = False
        found_pad = False

        with open(embed_filepath, 'r', encoding='utf-8') as f:
            line = f.readline()
            start = 1
            dim = -1
-            if len(line.strip().split())!=2:
+            if len(line.strip().split()) != 2:
                f.seek(0)
                start = 0
            for idx, line in enumerate(f, start=start):
                try:
                    parts = line.strip().split()
                    word = parts[0]
-                    if dim==-1:
-                        dim = len(parts)-1
+                    if dim == -1:
+                        dim = len(parts) - 1
                    vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
                    vec_dict[word] = vec
                    vocab.add_word(word)
-                    if unknown is not None and unknown==word:
+                    if unknown is not None and unknown == word:
                        found_unknown = True
-                    if found_pad is not None and padding==word:
+                    if found_pad is not None and padding == word:
                        found_pad = True
                except Exception as e:
-                    if error=='ignore':
+                    if error == 'ignore':
                        warnings.warn("Error occurred at the {} line.".format(idx))
                        pass
                    else:
                        print("Error occurred at the {} line.".format(idx))
                        raise e
-            if dim==-1:
+            if dim == -1:
                raise RuntimeError("{} is an empty file.".format(embed_filepath))
            matrix = np.random.randn(len(vocab), dim).astype(dtype)
            if (unknown is not None and not found_unknown) or (padding is not None and not found_pad):
@@ -133,19 +137,19 @@ class EmbedLoader(BaseLoader):
                start_idx += 1
            if unknown is not None:
                start_idx += 1

            mean = np.mean(matrix[start_idx:], axis=0, keepdims=True)
            std = np.std(matrix[start_idx:], axis=0, keepdims=True)
            if (unknown is not None and not found_unknown):
-                matrix[start_idx-1] = np.random.randn(1, dim).astype(dtype)*std + mean
+                matrix[start_idx - 1] = np.random.randn(1, dim).astype(dtype) * std + mean
            if (padding is not None and not found_pad):
-                matrix[0] = np.random.randn(1, dim).astype(dtype)*std + mean
+                matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean

            for key, vec in vec_dict.items():
                index = vocab.to_index(key)
                matrix[index] = vec

            if normalize:
                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

            return matrix, vocab

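A sketch matching the two static methods above ("glove.txt" is a placeholder path):

from fastNLP import Vocabulary
from fastNLP.io import EmbedLoader

# align a pre-trained embedding with an existing vocabulary
vocab = Vocabulary()
vocab.add_word("apple")
matrix = EmbedLoader.load_with_vocab("glove.txt", vocab)

# or derive the vocabulary from the embedding file itself
matrix, vocab = EmbedLoader.load_without_vocab("glove.txt")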
+ 10 - 5  fastNLP/io/model_io.py

@@ -1,6 +1,11 @@
"""
For loading and saving models.
"""
+__all__ = [
+    "ModelLoader",
+    "ModelSaver"
+]

import torch

from .base_loader import BaseLoader
@@ -12,10 +17,10 @@ class ModelLoader(BaseLoader):

    For loading models.
    """

    def __init__(self):
        super(ModelLoader, self).__init__()

    @staticmethod
    def load_pytorch(empty_model, model_path):
        """
@@ -25,7 +30,7 @@ class ModelLoader(BaseLoader):
        :param str model_path: path where the model is saved
        """
        empty_model.load_state_dict(torch.load(model_path))

    @staticmethod
    def load_pytorch_model(model_path):
        """
@@ -48,14 +53,14 @@ class ModelSaver(object):
        saver.save_pytorch(model)

    """

    def __init__(self, save_path):
        """

        :param save_path: path to save the model to
        """
        self.save_path = save_path

    def save_pytorch(self, model, param_only=True):
        """
        Save a PyTorch model to a ".pkl" file.

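The matching load side, per the signatures above (paths are placeholders):

from fastNLP.io import ModelLoader, ModelSaver

saver = ModelSaver("./save/model.pkl")
saver.save_pytorch(model)  # saves the state_dict when param_only=True (the default)

ModelLoader.load_pytorch(model, "./save/model.pkl")  # restores the state_dict in place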

+ 18 - 2  fastNLP/models/__init__.py

@@ -7,7 +7,23 @@ fastNLP provides built-in models in :mod:`~fastNLP.models`, such as :class:`~fastNLP.models


"""
-__all__ = ["CNNText", "SeqLabeling", "ESIM", "STSeqLabel", "AdvSeqLabel", "STNLICls", "STSeqCls"]
+__all__ = [
+    "CNNText",
+    "SeqLabeling",
+    "AdvSeqLabel",
+    "ESIM",
+    "StarTransEnc",
+    "STSeqLabel",
+    "STNLICls",
+    "STSeqCls",
+    "BiaffineParser",
+    "GraphParser"
+]

from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
    BertForTokenClassification
@@ -15,4 +31,4 @@ from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
-from .star_transformer import STSeqCls, STNLICls, STSeqLabel
+from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel

+ 6 - 6  fastNLP/models/base_model.py

@@ -1,18 +1,18 @@
import torch import torch


from ..modules.decoder.MLP import MLP
from ..modules.decoder.mlp import MLP




class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
"""Base PyTorch model for all models. """Base PyTorch model for all models.
""" """
def __init__(self): def __init__(self):
super(BaseModel, self).__init__() super(BaseModel, self).__init__()
def fit(self, train_data, dev_data=None, **train_args): def fit(self, train_data, dev_data=None, **train_args):
pass pass
def predict(self, *args, **kwargs): def predict(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError


@@ -21,9 +21,9 @@ class NaiveClassifier(BaseModel):
def __init__(self, in_feature_dim, out_feature_dim): def __init__(self, in_feature_dim, out_feature_dim):
super(NaiveClassifier, self).__init__() super(NaiveClassifier, self).__init__()
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
def forward(self, x): def forward(self, x):
return {"predict": torch.sigmoid(self.mlp(x))} return {"predict": torch.sigmoid(self.mlp(x))}
def predict(self, x): def predict(self, x):
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}

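NaiveClassifier above illustrates the model convention used throughout fastNLP: forward() returns a dict of raw predictions for training, predict() a dict of hard decisions for inference. A quick sketch:

import torch

clf = NaiveClassifier(in_feature_dim=10, out_feature_dim=1)
x = torch.randn(4, 10)
print(clf(x)["predict"])          # probabilities from forward()
print(clf.predict(x)["predict"])  # booleans thresholded at 0.5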
+ 87 - 70  fastNLP/models/biaffine_parser.py

@@ -1,11 +1,17 @@
-"""PyTorch implementation of the Biaffine Dependency Parser.
"""
-from collections import defaultdict
+PyTorch implementation of the Biaffine Dependency Parser.
+"""
+__all__ = [
+    "BiaffineParser",
+    "GraphParser"
+]

import numpy as np
import torch
-from torch import nn
-from torch.nn import functional as F
+import torch.nn as nn
+import torch.nn.functional as F
+
+from collections import defaultdict

from ..core.const import Const as C
from ..core.losses import LossFunc
@@ -18,6 +24,7 @@ from ..modules.utils import get_embeddings
from .base_model import BaseModel
from ..core.utils import seq_len_to_mask



def _mst(scores):
    """
    with some modification to support parser output for MST decoding
@@ -44,7 +51,7 @@ def _mst(scores):
            scores[roots, new_heads] / root_scores)]
        heads[roots] = new_heads
        heads[new_root] = 0

    edges = defaultdict(set)
    vertices = set((0,))
    for dep, head in enumerate(heads[tokens]):
@@ -73,7 +80,7 @@ def _mst(scores):
            heads[changed_cycle] = new_head
            edges[new_head].add(changed_cycle)
            edges[old_head].remove(changed_cycle)

    return heads




@@ -88,7 +95,7 @@ def _find_cycle(vertices, edges):
    _lowlinks = {}
    _onstack = defaultdict(lambda: False)
    _SCCs = []

    def _strongconnect(v):
        nonlocal _index
        _indices[v] = _index
@@ -96,28 +103,28 @@ def _find_cycle(vertices, edges):
        _index += 1
        _stack.append(v)
        _onstack[v] = True

        for w in edges[v]:
            if w not in _indices:
                _strongconnect(w)
                _lowlinks[v] = min(_lowlinks[v], _lowlinks[w])
            elif _onstack[w]:
                _lowlinks[v] = min(_lowlinks[v], _indices[w])

        if _lowlinks[v] == _indices[v]:
            SCC = set()
            while True:
                w = _stack.pop()
                _onstack[w] = False
                SCC.add(w)
-                if not(w != v):
+                if not (w != v):
                    break
            _SCCs.append(SCC)

    for v in vertices:
        if v not in _indices:
            _strongconnect(v)

    return [SCC for SCC in _SCCs if len(SCC) > 1]




@@ -125,9 +132,10 @@ class GraphParser(BaseModel):
    """
    Base class for graph-based parsers; supports greedy decoding and maximum spanning tree (MST) decoding.
    """

    def __init__(self):
        super(GraphParser, self).__init__()

    @staticmethod
    def greedy_decoder(arc_matrix, mask=None):
        """
@@ -146,7 +154,7 @@ class GraphParser(BaseModel):
        if mask is not None:
            heads *= mask.long()
        return heads

    @staticmethod
    def mst_decoder(arc_matrix, mask=None):
        """
@@ -176,6 +184,7 @@ class ArcBiaffine(nn.Module):
    :param hidden_size: dimension of the input features
    :param bias: whether to use a bias. Default: ``True``
    """

    def __init__(self, hidden_size, bias=True):
        super(ArcBiaffine, self).__init__()
        self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
@@ -185,7 +194,7 @@ class ArcBiaffine(nn.Module):
        else:
            self.register_parameter("bias", None)
        initial_parameter(self)

    def forward(self, head, dep):
        """


@@ -209,11 +218,12 @@ class LabelBilinear(nn.Module):
    :param num_label: number of edge label classes
    :param bias: whether to use a bias. Default: ``True``
    """

    def __init__(self, in1_features, in2_features, num_label, bias=True):
        super(LabelBilinear, self).__init__()
        self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
        self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)

    def forward(self, x1, x2):
        """


@@ -225,13 +235,13 @@ class LabelBilinear(nn.Module):
        output += self.lin(torch.cat([x1, x2], dim=2))
        return output


class BiaffineParser(GraphParser):
    """
    Alias: :class:`fastNLP.models.BiaffineParser` :class:`fastNLP.models.biaffine_parser.BiaffineParser`

    Implementation of the Biaffine Dependency Parser.
-    Reference: ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
-    <https://arxiv.org/abs/1611.01734>`_ .
+    Reference: `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ .

    :param init_embed: word embedding; can be a tuple (num_embeddings, embedding_dim), i.e.
        the vocabulary size and the dimension of each word vector, or an nn.Embedding object,
@@ -248,18 +258,19 @@ class BiaffineParser(GraphParser):
    :param use_greedy_infer: whether to use the greedy algorithm at inference time.
        If ``False``, uses the more accurate but slower MST algorithm. Default: ``False``
""" """
def __init__(self, def __init__(self,
init_embed,
pos_vocab_size,
pos_emb_dim,
num_label,
rnn_layers=1,
rnn_hidden_size=200,
arc_mlp_size=100,
label_mlp_size=100,
dropout=0.3,
encoder='lstm',
use_greedy_infer=False):
init_embed,
pos_vocab_size,
pos_emb_dim,
num_label,
rnn_layers=1,
rnn_hidden_size=200,
arc_mlp_size=100,
label_mlp_size=100,
dropout=0.3,
encoder='lstm',
use_greedy_infer=False):
super(BiaffineParser, self).__init__() super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size rnn_out_size = 2 * rnn_hidden_size
word_hid_dim = pos_hid_dim = rnn_hidden_size word_hid_dim = pos_hid_dim = rnn_hidden_size
@@ -295,20 +306,20 @@ class BiaffineParser(GraphParser):
            if (d_k * n_head) != rnn_out_size:
                raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size))
            self.position_emb = nn.Embedding(num_embeddings=self.max_len,
-                                             embedding_dim=rnn_out_size,)
+                                             embedding_dim=rnn_out_size, )
            self.encoder = TransformerEncoder(num_layers=rnn_layers,
                                              model_size=rnn_out_size,
                                              inner_size=1024,
                                              key_size=d_k,
                                              value_size=d_v,
                                              num_head=n_head,
-                                              dropout=dropout,)
+                                              dropout=dropout, )
        else:
            raise ValueError('unsupported encoder type: {}'.format(encoder))

        self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                 nn.ELU(),
-                                 TimestepDropout(p=dropout),)
+                                 TimestepDropout(p=dropout), )
        self.arc_mlp_size = arc_mlp_size
        self.label_mlp_size = label_mlp_size
        self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
@@ -316,7 +327,7 @@ class BiaffineParser(GraphParser):
        self.use_greedy_infer = use_greedy_infer
        self.reset_parameters()
        self.dropout = dropout

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Embedding):
@@ -327,7 +338,7 @@ class BiaffineParser(GraphParser):
            else:
                for p in m.parameters():
                    nn.init.normal_(p, 0, 0.1)
    def forward(self, words1, words2, seq_len, target1=None):
        """The forward pass of the model.

@@ -337,50 +348,52 @@ class BiaffineParser(GraphParser):
        :param target1: [batch_size, seq_len] gold heads, only used during training
            to train the label classifier. If ``None``, the predicted heads are fed
            to the label classifier instead. Default: ``None``
        :return dict: parsing results::

            pred1: [batch_size, seq_len, seq_len] edge prediction logits
            pred2: [batch_size, seq_len, num_label] label prediction logits
            pred3: [batch_size, seq_len] predicted heads, only returned when ``target1=None``

        """
        # prepare embeddings
        batch_size, length = words1.shape
        # print('forward {} {}'.format(batch_size, seq_len))
        # get sequence mask
        mask = seq_len_to_mask(seq_len).long()

        word = self.word_embedding(words1)  # [N,L] -> [N,L,C_0]
        pos = self.pos_embedding(words2)  # [N,L] -> [N,L,C_1]

        word, pos = self.word_fc(word), self.pos_fc(pos)
        word, pos = self.word_norm(word), self.pos_norm(pos)
        x = torch.cat([word, pos], dim=2)  # -> [N,L,C]

        # encoder, extract features
        if self.encoder_name.endswith('lstm'):
            sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
            x = x[sort_idx]
            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
            feat, _ = self.encoder(x)  # -> [N,L,C]
            feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
            _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
            feat = feat[unsort_idx]
        else:
-            seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None,:]
+            seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :]
            x = x + self.position_emb(seq_range)
            feat = self.encoder(x, mask.float())

        # for arc biaffine
        # mlp, reduce dim
        feat = self.mlp(feat)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
-        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
-        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]
+        arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
+        label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:]

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]
        # use gold or predicted arc to predict label
        if target1 is None or not self.training:
            # use greedy decoding in training
@@ -390,22 +403,22 @@ class BiaffineParser(GraphParser):
                heads = self.mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training  # must be training mode
            if target1 is None:
                heads = self.greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = target1

        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep)  # [N, L, num_label]
        res_dict = {C.OUTPUTS(0): arc_pred, C.OUTPUTS(1): label_pred}
        if head_pred is not None:
            res_dict[C.OUTPUTS(2)] = head_pred
        return res_dict
    @staticmethod
    def loss(pred1, pred2, target1, target2, seq_len):
        """
@@ -418,7 +431,7 @@ class BiaffineParser(GraphParser):
        :param seq_len: [batch_size, seq_len] true lengths of the targets
        :return loss: scalar
        """

        batch_size, length, _ = pred1.shape
        mask = seq_len_to_mask(seq_len)
        flip_mask = (mask == 0)
@@ -430,24 +443,26 @@
        child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
        arc_loss = arc_logits[batch_index, child_index, target1]
        label_loss = label_logits[batch_index, child_index, target2]

        byte_mask = flip_mask.byte()
        arc_loss.masked_fill_(byte_mask, 0)
        label_loss.masked_fill_(byte_mask, 0)
        arc_nll = -arc_loss.mean()
        label_nll = -label_loss.mean()

        return arc_nll + label_nll
    def predict(self, words1, words2, seq_len):
        """Prediction API of the model.

        :param words1: [batch_size, seq_len] input word sequence
        :param words2: [batch_size, seq_len] input POS sequence
        :param seq_len: [batch_size, seq_len] input sequence lengths
        :return dict: parsing results::

            pred1: [batch_size, seq_len] predicted heads
            pred2: [batch_size, seq_len, num_label] label prediction logits

        """
        res = self(words1, words2, seq_len)
        output = {}
@@ -470,6 +485,7 @@ class ParserLoss(LossFunc):
    :param seq_len: [batch_size, seq_len] true lengths of the targets
    :return loss: scalar
    """

    def __init__(self, pred1=None, pred2=None,
                 target1=None, target2=None,
                 seq_len=None):
@@ -497,9 +513,10 @@ class ParserMetric(MetricBase):
    UAS: accuracy of edge prediction, ignoring labels
    LAS: accuracy of predicting both the edge and its label
    """

    def __init__(self, pred1=None, pred2=None,
                 target1=None, target2=None, seq_len=None):
        super().__init__()
        self._init_param_map(pred1=pred1, pred2=pred2,
                             target1=target1, target2=target2,
@@ -507,13 +524,13 @@ class ParserMetric(MetricBase):
        self.num_arc = 0
        self.num_label = 0
        self.num_sample = 0

    def get_metric(self, reset=True):
-        res = {'UAS': self.num_arc*1.0 / self.num_sample, 'LAS': self.num_label*1.0 / self.num_sample}
+        res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample}
        if reset:
            self.num_sample = self.num_label = self.num_arc = 0
        return res

    def evaluate(self, pred1, pred2, target1, target2, seq_len=None):
        """Evaluate the performance of prediction.
        """
@@ -522,7 +539,7 @@ class ParserMetric(MetricBase):
        else:
            seq_mask = seq_len_to_mask(seq_len.long()).long()
        # mask out <root> tag
-        seq_mask[:,0] = 0
+        seq_mask[:, 0] = 0
        head_pred_correct = (pred1 == target1).long() * seq_mask
        label_pred_correct = (pred2 == target2).long() * head_pred_correct
        self.num_arc += head_pred_correct.sum().item()
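
To make the forward/predict contract above concrete, here is a minimal smoke-test sketch of the parser API; the vocabulary sizes, batch shapes and constructor arguments are illustrative, and the 'pred1'/'pred2' output keys are taken from the predict docstring above (they may differ across fastNLP versions):

    import torch
    from fastNLP.models import BiaffineParser

    # sizes below are made up for the sketch
    parser = BiaffineParser(init_embed=(1000, 100),  # (num_embeddings, embedding_dim)
                            pos_vocab_size=40,
                            pos_emb_dim=50,
                            num_label=20,
                            encoder='lstm')

    batch_size, max_len = 4, 12
    words1 = torch.randint(1, 1000, (batch_size, max_len))  # word indices
    words2 = torch.randint(1, 40, (batch_size, max_len))    # POS indices
    seq_len = torch.tensor([12, 10, 8, 5])                  # true length of each sample

    out = parser.predict(words1, words2, seq_len)
    # per the docstring: out['pred1'] -> [batch_size, seq_len] predicted heads,
    #                    out['pred2'] -> [batch_size, seq_len, num_label] label logits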


+8 -7  fastNLP/models/cnn_text_classification.py

@@ -1,10 +1,11 @@
-# python: 3.6
-# encoding: utf-8
+__all__ = [
+    "CNNText"
+]

import torch
import torch.nn as nn
-from ..core.const import Const as C

+from ..core.const import Const as C
from ..modules import encoder




@@ -23,7 +24,7 @@ class CNNText(torch.nn.Module):
    :param int padding: amount of zero-padding added before and after each sentence.
    :param float dropout: dropout probability
    """

    def __init__(self, init_embed,
                 num_classes,
                 kernel_nums=(3, 4, 5),
@@ -31,7 +32,7 @@ class CNNText(torch.nn.Module):
                 padding=0,
                 dropout=0.5):
        super(CNNText, self).__init__()

        # no support for pre-trained embedding currently
        self.embed = encoder.Embedding(init_embed)
        self.conv_pool = encoder.ConvMaxpool(
@@ -41,7 +42,7 @@ class CNNText(torch.nn.Module):
            padding=padding)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(sum(kernel_nums), num_classes)

    def forward(self, words, seq_len=None):
        """

@@ -54,7 +55,7 @@ class CNNText(torch.nn.Module):
        x = self.dropout(x)
        x = self.fc(x)  # [N,C] -> [N, N_class]
        return {C.OUTPUT: x}

    def predict(self, words, seq_len=None):
        """
        :param torch.LongTensor words: [batch_size, seq_len], indices of the words in the sentence
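
A minimal usage sketch of the CNNText interface shown above; the vocabulary size and shapes are illustrative, and the output is read back through Const.OUTPUT rather than a hard-coded key so the sketch tracks the constants used in this diff:

    import torch
    from fastNLP.core.const import Const as C
    from fastNLP.models import CNNText

    model = CNNText(init_embed=(5000, 128),  # (num_embeddings, embedding_dim)
                    num_classes=3,
                    kernel_nums=(3, 4, 5))

    words = torch.randint(1, 5000, (8, 20))  # [batch_size, seq_len]
    logits = model(words)[C.OUTPUT]          # [batch_size, num_classes]
    preds = model.predict(words)[C.OUTPUT]   # predicted class indices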


+1 -0  fastNLP/models/enas_controller.py

@@ -5,6 +5,7 @@ import os

import torch
import torch.nn.functional as F
+
from . import enas_utils as utils
from .enas_utils import Node




+71 -68  fastNLP/models/enas_model.py

@@ -1,17 +1,19 @@
-# Code Modified from https://github.com/carpedm20/ENAS-pytorch
-"""Module containing the shared RNN model."""
-import numpy as np
+"""
+Module containing the shared RNN model.
+Code Modified from https://github.com/carpedm20/ENAS-pytorch
+"""
import collections

+import numpy as np
import torch
-from torch import nn
+import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from . import enas_utils as utils
from .base_model import BaseModel



def _get_dropped_weights(w_raw, dropout_p, is_training):
    """Drops out weights to implement DropConnect.

@@ -35,12 +37,13 @@ def _get_dropped_weights(w_raw, dropout_p, is_training):
    The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
    """
    dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)

    if isinstance(dropped_w, torch.nn.Parameter):
        dropped_w = dropped_w.clone()

    return dropped_w


class EmbeddingDropout(torch.nn.Embedding):
    """Class for dropping out embeddings by zero'ing out parameters in the
    embedding matrix.
@@ -53,6 +56,7 @@ class EmbeddingDropout(torch.nn.Embedding):
    See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
    Networks', (Gal and Ghahramani, 2016).
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
@@ -83,14 +87,14 @@ class EmbeddingDropout(torch.nn.Embedding):
        assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 '
                                                      'and < 1.0')
        self.scale = scale

    def forward(self, inputs):  # pylint:disable=arguments-differ
        """Embeds `inputs` with the dropped out embedding weight matrix."""
        if self.training:
            dropout = self.dropout
        else:
            dropout = 0

        if dropout:
            mask = self.weight.data.new(self.weight.size(0), 1)
            mask.bernoulli_(1 - dropout)
@@ -101,7 +105,7 @@ class EmbeddingDropout(torch.nn.Embedding):
            masked_weight = self.weight
        if self.scale and self.scale != 1:
            masked_weight = masked_weight * self.scale

        return F.embedding(inputs,
                           masked_weight,
                           max_norm=self.max_norm,
@@ -114,7 +118,7 @@ class LockedDropout(nn.Module):
    # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
    def __init__(self):
        super().__init__()

    def forward(self, x, dropout=0.5):
        if not self.training or not dropout:
            return x
@@ -126,11 +130,12 @@

class ENASModel(BaseModel):
    """Shared RNN model."""

    def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000):
        super(ENASModel, self).__init__()
        self.use_cuda = cuda

        self.shared_hid = shared_hid
        self.num_blocks = num_blocks
        self.decoder = nn.Linear(self.shared_hid, num_classes)
@@ -139,16 +144,16 @@ class ENASModel(BaseModel):
                                       dropout=0.1)
        self.lockdrop = LockedDropout()
        self.dag = None

        # Tie weights
        # self.decoder.weight = self.encoder.weight

        # Since W^{x, c} and W^{h, c} are always summed, there
        # is no point duplicating their bias offset parameter. Likewise for
        # W^{x, h} and W^{h, h}.
        self.w_xc = nn.Linear(shared_embed, self.shared_hid)
        self.w_xh = nn.Linear(shared_embed, self.shared_hid)

        # The raw weights are stored here because the hidden-to-hidden weights
        # are weight dropped on the forward pass.
        self.w_hc_raw = torch.nn.Parameter(
@@ -157,10 +162,10 @@ class ENASModel(BaseModel):
            torch.Tensor(self.shared_hid, self.shared_hid))
        self.w_hc = None
        self.w_hh = None

        self.w_h = collections.defaultdict(dict)
        self.w_c = collections.defaultdict(dict)

        for idx in range(self.num_blocks):
            for jdx in range(idx + 1, self.num_blocks):
                self.w_h[idx][jdx] = nn.Linear(self.shared_hid,
@@ -169,48 +174,47 @@ class ENASModel(BaseModel):
                self.w_c[idx][jdx] = nn.Linear(self.shared_hid,
                                               self.shared_hid,
                                               bias=False)

        self._w_h = nn.ModuleList([self.w_h[idx][jdx]
                                   for idx in self.w_h
                                   for jdx in self.w_h[idx]])
        self._w_c = nn.ModuleList([self.w_c[idx][jdx]
                                   for idx in self.w_c
                                   for jdx in self.w_c[idx]])

        self.batch_norm = None
        # if args.mode == 'train':
        #     self.batch_norm = nn.BatchNorm1d(self.shared_hid)
        # else:
        #     self.batch_norm = None

        self.reset_parameters()
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

    def setDAG(self, dag):
        if self.dag is None:
            self.dag = dag

    def forward(self, word_seq, hidden=None):
        inputs = torch.transpose(word_seq, 0, 1)

        time_steps = inputs.size(0)
        batch_size = inputs.size(1)

        self.w_hh = _get_dropped_weights(self.w_hh_raw,
                                         0.5,
                                         self.training)
        self.w_hc = _get_dropped_weights(self.w_hc_raw,
                                         0.5,
                                         self.training)

        # hidden = self.static_init_hidden[batch_size] if hidden is None else hidden
        hidden = self.static_init_hidden[batch_size]

        embed = self.encoder(inputs)
        embed = self.lockdrop(embed, 0.65 if self.training else 0)

        # The norm of hidden states are clipped here because
        # otherwise ENAS is especially prone to exploding activations on the
        # forward pass. This could probably be fixed in a more elegant way, but
@@ -226,7 +230,7 @@ class ENASModel(BaseModel):
        for step in range(time_steps):
            x_t = embed[step]
            logit, hidden = self.cell(x_t, hidden, self.dag)

            hidden_norms = hidden.norm(dim=-1)
            max_norm = 25.0
            if hidden_norms.data.max() > max_norm:
@@ -237,60 +241,60 @@ class ENASModel(BaseModel):
                # because the PyTorch slicing and slice assignment is too
                # flaky.
                hidden_norms = hidden_norms.data.cpu().numpy()

                clipped_num += 1
                if hidden_norms.max() > max_clipped_norm:
                    max_clipped_norm = hidden_norms.max()

                clip_select = hidden_norms > max_norm
                clip_norms = hidden_norms[clip_select]

                mask = np.ones(hidden.size())
-                normalizer = max_norm/clip_norms
+                normalizer = max_norm / clip_norms
                normalizer = normalizer[:, np.newaxis]

                mask[clip_select] = normalizer

                if self.use_cuda:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask).cuda(), requires_grad=False)
                else:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask), requires_grad=False)
            logits.append(logit)
            h1tohT.append(hidden)

        h1tohT = torch.stack(h1tohT)
        output = torch.stack(logits)
        raw_output = output
        output = self.lockdrop(output, 0.4 if self.training else 0)

-        #Pooling
+        # Pooling
        output = torch.mean(output, 0)

        decoded = self.decoder(output)

        extra_out = {'dropped': decoded,
                     'hiddens': h1tohT,
                     'raw': raw_output}
        return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out}
    def cell(self, x, h_prev, dag):
        """Computes a single pass through the discovered RNN cell."""
        c = {}
        h = {}
        f = {}

        f[0] = self.get_f(dag[-1][0].name)
        c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None))
-        h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
-                (1 - c[0])*h_prev)
+        h[0] = (c[0] * f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
+                (1 - c[0]) * h_prev)

        leaf_node_ids = []
        q = collections.deque()
        q.append(0)

        # Computes connections from the parent nodes `node_id`
        # to their child nodes `next_id` recursively, skipping leaf nodes. A
        # leaf node is a node whose id == `self.num_blocks`.
@@ -306,10 +310,10 @@ class ENASModel(BaseModel):
        while True:
            if len(q) == 0:
                break

            node_id = q.popleft()
            nodes = dag[node_id]

            for next_node in nodes:
                next_id = next_node.id
                if next_id == self.num_blocks:
@@ -317,38 +321,38 @@ class ENASModel(BaseModel):
                    assert len(nodes) == 1, ('parent of leaf node should have '
                                             'only one child')
                    continue

                w_h = self.w_h[node_id][next_id]
                w_c = self.w_c[node_id][next_id]

                f[next_id] = self.get_f(next_node.name)
                c[next_id] = torch.sigmoid(w_c(h[node_id]))
-                h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) +
-                              (1 - c[next_id])*h[node_id])
+                h[next_id] = (c[next_id] * f[next_id](w_h(h[node_id])) +
+                              (1 - c[next_id]) * h[node_id])

                q.append(next_id)

        # Instead of averaging loose ends, perhaps there should
        # be a set of separate unshared weights for each "loose" connection
        # between each node in a cell and the output.
        #
        # As it stands, all weights W^h_{ij} are doing double duty by
        # connecting both from i to j, as well as from i to the output.

        # average all the loose ends
        leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
        output = torch.mean(torch.stack(leaf_nodes, 2), -1)

        # stabilizing the Updates of omega
        if self.batch_norm is not None:
            output = self.batch_norm(output)

        return output, h[self.num_blocks - 1]

    def init_hidden(self, batch_size):
        zeros = torch.zeros(batch_size, self.shared_hid)
        return utils.get_variable(zeros, self.use_cuda, requires_grad=False)

    def get_f(self, name):
        name = name.lower()
        if name == 'relu':
@@ -360,22 +364,21 @@ class ENASModel(BaseModel):
        elif name == 'sigmoid':
            f = torch.sigmoid
        return f

    @property
    def num_parameters(self):
        def size(p):
            return np.prod(p.size())

        return sum([size(param) for param in self.parameters()])

    def reset_parameters(self):
        init_range = 0.025
        # init_range = 0.025 if self.args.mode == 'train' else 0.04
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.fill_(0)

    def predict(self, word_seq):
        """




+69 -72  fastNLP/models/enas_trainer.py

@@ -1,12 +1,12 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

-import time
-from datetime import datetime
-from datetime import timedelta
+import math
import numpy as np
+import time
import torch
-import math
+
+from datetime import datetime, timedelta
+
+from torch.optim import Adam

try:
    from tqdm.auto import tqdm
@@ -21,8 +21,6 @@ from ..core.utils import _move_dict_value_to_device
from . import enas_utils as utils
from ..core.utils import _build_args

-from torch.optim import Adam



def _get_no_grad_ctx_mgr():
    """Returns the `torch.no_grad` context manager for PyTorch version >=
@@ -33,6 +31,7 @@ def _get_no_grad_ctx_mgr():


class ENASTrainer(Trainer):
    """A class to wrap training code."""

    def __init__(self, train_data, model, controller, **kwargs):
        """Constructor for training algorithm.
        :param DataSet train_data: the training data
@@ -45,19 +44,19 @@ class ENASTrainer(Trainer):
        self.controller_step = 0
        self.shared_step = 0
        self.max_length = 35

        self.shared = model
        self.controller = controller

        self.shared_optim = Adam(
            self.shared.parameters(),
            lr=20.0,
            weight_decay=1e-7)

        self.controller_optim = Adam(
            self.controller.parameters(),
            lr=3.5e-4)

    def train(self, load_best_model=True):
        """
        :param bool load_best_model: only takes effect if dev_data was provided at initialization; if True, before returning the trainer reloads the parameters that performed best on dev
@@ -82,21 +81,22 @@ class ENASTrainer(Trainer):
            self.model = self.model.cuda()
        self._model_device = self.model.parameters().__next__().device
        self._mode(self.model, is_test=False)

        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
        start_time = time.time()
        print("training epochs started " + self.start_time, flush=True)

        try:
            self.callback_manager.on_train_begin()
            self._train()
            self.callback_manager.on_train_end()
        except (CallbackException, KeyboardInterrupt) as e:
            self.callback_manager.on_exception(e)

        if self.dev_data is not None:
-            print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
-                  self.tester._format_eval_results(self.best_dev_perf),)
+            print(
+                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
+                self.tester._format_eval_results(self.best_dev_perf), )
            results['best_eval'] = self.best_dev_perf
            results['best_epoch'] = self.best_dev_epoch
            results['best_step'] = self.best_dev_step
@@ -110,9 +110,9 @@ class ENASTrainer(Trainer):
        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results
    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -126,21 +126,21 @@ class ENASTrainer(Trainer):
            avg_loss = 0
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                  prefetch=self.prefetch)
-            for epoch in range(1, self.n_epochs+1):
+            for epoch in range(1, self.n_epochs + 1):
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
                if epoch == self.n_epochs + 1 - self.final_epochs:
                    print('Entering the final stage. (Only train the selected structure)')
                # early stopping
                self.callback_manager.on_epoch_begin()

                # 1. Training the shared parameters omega of the child models
                self.train_shared(pbar)

                # 2. Training the controller parameters theta
                if not last_stage:
                    self.train_controller()

                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                    (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
@@ -149,16 +149,15 @@ class ENASTrainer(Trainer):
                    eval_res = self._do_validation(epoch=epoch, step=self.step)
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                total_steps) + \
                               self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str)

                # lr decay; early stopping
                self.callback_manager.on_epoch_end()
            # =============== epochs end =================== #
            pbar.close()
        # ============ tqdm end ============== #


    def get_loss(self, inputs, targets, hidden, dags):
        """Computes the loss for the same batch for M models.

@@ -167,7 +166,7 @@ class ENASTrainer(Trainer):
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            self.shared.setDAG(dag)
@@ -175,14 +174,14 @@ class ENASTrainer(Trainer):
            inputs['hidden'] = hidden
            result = self.shared(**inputs)
            output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out']

            self.callback_manager.on_loss_begin(targets, result)
            sample_loss = self._compute_loss(result, targets)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, hidden, extra_out
    def train_shared(self, pbar=None, max_step=None, dag=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.
@@ -200,9 +199,9 @@ class ENASTrainer(Trainer):
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_hidden(self.batch_size)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
@@ -211,15 +210,15 @@ class ENASTrainer(Trainer):
        train_idx = 0
        avg_loss = 0
        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for batch_x, batch_y in data_iterator:
            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
            indices = data_iterator.get_batch_indices()
            # negative sampling; replace unknown; re-weight batch_y
            self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
            # prediction = self._data_forward(self.model, batch_x)

            dags = self.controller.sample(1)
            inputs, targets = batch_x, batch_y
            # self.callback_manager.on_loss_begin(batch_y, prediction)
@@ -228,18 +227,18 @@ class ENASTrainer(Trainer):
                                                  hidden,
                                                  dags)
            hidden.detach_()

            avg_loss += loss.item()

            # Is loss NaN or inf? requires_grad = False
            self.callback_manager.on_backward_begin(loss)
            self._grad_backward(loss)
            self.callback_manager.on_backward_end()

            self._update()
            self.callback_manager.on_step_end()

-            if (self.step+1) % self.print_every == 0:
+            if (self.step + 1) % self.print_every == 0:
                if self.use_tqdm:
                    print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                    pbar.update(self.print_every)
@@ -255,30 +254,29 @@ class ENASTrainer(Trainer):
            self.shared_step += 1
            self.callback_manager.on_batch_end()
        # ================= mini-batch end ==================== #


    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for inputs, targets in data_iterator:
            valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
            valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        R = 80 / valid_ppl

        rewards = R + 1e-4 * entropies

        return rewards, hidden
    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

@@ -296,13 +294,13 @@ class ENASTrainer(Trainer):
        # Why can't we call shared.eval() here? Leads to loss
        # being uniformly zero for the controller.
        # self.shared.eval()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.batch_size)
        total_loss = 0
        valid_idx = 0
@@ -310,7 +308,7 @@ class ENASTrainer(Trainer):
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # No gradients should be backpropagated to the
@@ -320,40 +318,39 @@ class ENASTrainer(Trainer):
                                              np_entropies,
                                              hidden,
                                              valid_idx)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = 0.95
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
-            loss = -log_probs*utils.get_variable(adv,
-                                                 'cuda' in self.device,
-                                                 requires_grad=False)
+            loss = -log_probs * utils.get_variable(adv,
+                                                   'cuda' in self.device,
+                                                   requires_grad=False)

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % 50) == 0) and (step > 0):
                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1
            # prev_valid_idx = valid_idx
            # valid_idx = ((valid_idx + self.max_length) %
@@ -362,16 +359,16 @@ class ENASTrainer(Trainer):
            # #    validation data, we reset the hidden states.
            # if prev_valid_idx > valid_idx:
            #     hidden = self.shared.init_hidden(self.batch_size)

    def derive(self, sample_num=10, valid_idx=0):
        """We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.batch_size)

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
@@ -379,5 +376,5 @@ class ENASTrainer(Trainer):
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        self.model.setDAG(best_dag)
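
The update in `train_controller` above is plain REINFORCE with an exponential-moving-average baseline: the reward is 80/perplexity plus a small entropy bonus, and the policy loss is `-(log_probs * advantage).sum()`. A stripped-down sketch of one such step with made-up numbers (not the fastNLP API):

    import torch

    log_probs = torch.randn(5, requires_grad=True)     # log-probs of the sampled DAG decisions
    rewards = torch.tensor([0.8, 0.7, 0.9, 0.6, 0.8])  # e.g. 80 / ppl + 1e-4 * entropies

    decay, baseline = 0.95, torch.tensor([0.5])        # baseline carried over from earlier steps
    baseline = decay * baseline + (1 - decay) * rewards

    adv = (rewards - baseline).detach()  # no gradient flows through the reward
    loss = -(log_probs * adv).sum()      # loss.sum(), as in the code above
    loss.backward()                      # followed by controller_optim.step()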

+0 -3  fastNLP/models/enas_utils.py

@@ -1,12 +1,9 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

-from __future__ import print_function
-
from collections import defaultdict
import collections

import numpy as np
-
import torch
from torch.autograd import Variable




+13 -5  fastNLP/models/sequence_labeling.py

@@ -1,11 +1,19 @@
+"""
+This module implements two sequence labeling models.
+"""
+__all__ = [
+    "SeqLabeling",
+    "AdvSeqLabel"
+]
+
import torch
+import torch.nn as nn

from .base_model import BaseModel
from ..modules import decoder, encoder
-from ..modules.decoder.CRF import allowed_transitions
+from ..modules.decoder.crf import allowed_transitions
from ..core.utils import seq_len_to_mask
from ..core.const import Const as C
-from torch import nn


class SeqLabeling(BaseModel):
@@ -27,7 +35,7 @@ class SeqLabeling(BaseModel):
        self.Embedding = encoder.embedding.Embedding(init_embed)
        self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size)
        self.Linear = nn.Linear(hidden_size, num_classes)
-        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
+        self.Crf = decoder.crf.ConditionalRandomField(num_classes)
        self.mask = None

    def forward(self, words, seq_len, target):
@@ -133,9 +141,9 @@ class AdvSeqLabel(nn.Module):
        self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes)

        if id2words is None:
-            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+            self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False)
        else:
-            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
+            self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False,
                                                          allowed_transitions=allowed_transitions(id2words,
                                                                                                  encoding_type=encoding_type))
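
The `id2words` branch above constrains the CRF so that only transitions that are legal under the tagging scheme receive a score. A sketch of what gets passed in, with an illustrative BIO tag vocabulary (in practice `id2words` comes from a Vocabulary):

    from fastNLP.modules.decoder.crf import allowed_transitions

    id2words = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-LOC', 4: 'I-LOC'}  # illustrative
    trans = allowed_transitions(id2words, encoding_type='bio')
    # 'trans' lists the allowed (from_tag_id, to_tag_id) pairs: O -> B-PER is
    # present, while an illegal jump such as B-PER -> I-LOC is left out, so
    # the CRF never assigns it a transition score.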


+33 -30  fastNLP/models/snli.py

@@ -1,3 +1,7 @@
+__all__ = [
+    "ESIM"
+]
+
import torch
import torch.nn as nn

@@ -8,7 +12,6 @@ from ..modules import encoder as Encoder
from ..modules import aggregator as Aggregator
from ..core.utils import seq_len_to_mask

-
my_inf = 10e12


@@ -26,7 +29,7 @@ class ESIM(BaseModel):
    :param int num_classes: number of labels. Default: 3
    :param numpy.array init_embedding: initial embedding matrix of shape (vocab_size, embed_dim). Default: None, i.e. the embedding matrix is randomly initialized
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, dropout=0.0, num_classes=3, init_embedding=None):
        super(ESIM, self).__init__()
@@ -35,35 +38,36 @@ class ESIM(BaseModel):
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.n_labels = num_classes

        self.drop = nn.Dropout(self.dropout)

        self.embedding = Encoder.Embedding(
            (self.vocab_size, self.embed_dim), dropout=self.dropout,
        )

        self.embedding_layer = nn.Linear(self.embed_dim, self.hidden_size)

        self.encoder = Encoder.LSTM(
            input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True,
            batch_first=True, bidirectional=True
        )

        self.bi_attention = Aggregator.BiAttention()
        self.mean_pooling = Aggregator.AvgPoolWithMask()
        self.max_pooling = Aggregator.MaxPoolWithMask()

        self.inference_layer = nn.Linear(self.hidden_size * 4, self.hidden_size)

        self.decoder = Encoder.LSTM(
            input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True,
            batch_first=True, bidirectional=True
        )

        self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)
    def forward(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
        """ Forward function

        :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token representation of the premise
        :param torch.Tensor words2: [B, hypothesis seq len(HL)] token representation of the hypothesis
        :param torch.LongTensor seq_len1: [B] lengths of the premises
@@ -71,10 +75,10 @@ class ESIM(BaseModel):
        :param torch.LongTensor target: [B] gold target values
        :return: dict prediction: [B, n_labels(N)] prediction results
        """

        premise0 = self.embedding_layer(self.embedding(words1))
        hypothesis0 = self.embedding_layer(self.embedding(words2))

        if seq_len1 is not None:
            seq_len1 = seq_len_to_mask(seq_len1)
        else:
@@ -85,55 +89,55 @@ class ESIM(BaseModel):
        else:
            seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1))
            seq_len2 = (seq_len2.long()).to(device=hypothesis0.device)

        _BP, _PSL, _HP = premise0.size()
        _BH, _HSL, _HH = hypothesis0.size()
        _BPL, _PLL = seq_len1.size()
        _HPL, _HLL = seq_len2.size()

        assert _BP == _BH and _BPL == _HPL and _BP == _BPL
        assert _HP == _HH
        assert _PSL == _PLL and _HSL == _HLL

        B, PL, H = premise0.size()
        B, HL, H = hypothesis0.size()

        a0 = self.encoder(self.drop(premise0))  # a0: [B, PL, H * 2]
        b0 = self.encoder(self.drop(hypothesis0))  # b0: [B, HL, H * 2]

        a = torch.mean(a0.view(B, PL, -1, H), dim=2)  # a: [B, PL, H]
        b = torch.mean(b0.view(B, HL, -1, H), dim=2)  # b: [B, HL, H]

        ai, bi = self.bi_attention(a, b, seq_len1, seq_len2)

        ma = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 4 * H]
        mb = torch.cat((b, bi, b - bi, b * bi), dim=2)  # mb: [B, HL, 4 * H]

        f_ma = self.inference_layer(ma)
        f_mb = self.inference_layer(mb)

        vat = self.decoder(self.drop(f_ma))
        vbt = self.decoder(self.drop(f_mb))

        va = torch.mean(vat.view(B, PL, -1, H), dim=2)  # va: [B, PL, H]
        vb = torch.mean(vbt.view(B, HL, -1, H), dim=2)  # vb: [B, HL, H]

        va_ave = self.mean_pooling(va, seq_len1, dim=1)  # va_ave: [B, H]
        va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1)  # va_max: [B, H]
        vb_ave = self.mean_pooling(vb, seq_len2, dim=1)  # vb_ave: [B, H]
        vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1)  # vb_max: [B, H]

        v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1)  # v: [B, 4 * H]

        prediction = torch.tanh(self.output(v))  # prediction: [B, N]

        if target is not None:
            func = nn.CrossEntropyLoss()
            loss = func(prediction, target)
            return {Const.OUTPUT: prediction, Const.LOSS: loss}

        return {Const.OUTPUT: prediction}

    def predict(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
        """ Predict function

@@ -146,4 +150,3 @@ class ESIM(BaseModel):
        """
        prediction = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT]
        return {Const.OUTPUT: torch.argmax(prediction, dim=-1)}
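
The enhancement step in forward above, `ma = torch.cat((a, ai, a - ai, a * ai), dim=2)`, is the standard ESIM trick of pairing each encoded token with its attended counterpart plus their difference and elementwise product. A tensor-level sketch of the shapes involved (sizes are made up):

    import torch

    B, L, H = 2, 7, 16        # batch, premise length, hidden size
    a = torch.randn(B, L, H)  # encoded premise
    ai = torch.randn(B, L, H) # premise aligned to the hypothesis via bi-attention

    ma = torch.cat((a, ai, a - ai, a * ai), dim=2)
    print(ma.shape)           # torch.Size([2, 7, 64]) == [B, L, 4 * H]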


+41 -28  fastNLP/models/star_transformer.py

@@ -1,17 +1,25 @@
-"""A PyTorch implementation of the Star-Transformer.
-"""
+"""
+A PyTorch implementation of the Star-Transformer.
+"""
+__all__ = [
+    "StarTransEnc",
+    "STNLICls",
+    "STSeqCls",
+    "STSeqLabel",
+]
+
+import torch
+from torch import nn
+
from ..modules.encoder.star_transformer import StarTransformer
from ..core.utils import seq_len_to_mask
from ..modules.utils import get_embeddings
from ..core.const import Const
-
-import torch
-from torch import nn


class StarTransEnc(nn.Module):
    """
-    Alias: :class:`fastNLP.models.StarTransEnc` :class:`fastNLP.models.start_transformer.StarTransEnc`
+    Alias: :class:`fastNLP.models.StarTransEnc` :class:`fastNLP.models.star_transformer.StarTransEnc`

    Star-Transformer encoder with word embedding.

@@ -28,6 +36,7 @@ class StarTransEnc(nn.Module):
    :param emb_dropout: dropout probability for the word embedding.
    :param dropout: dropout probability for the rest of the model.
    """

    def __init__(self, init_embed,
                 hidden_size,
                 num_layers,
@@ -47,7 +56,7 @@ class StarTransEnc(nn.Module):
                                           head_dim=head_dim,
                                           dropout=dropout,
                                           max_len=max_len)

    def forward(self, x, mask):
        """
        :param FloatTensor data: [batch, length, hidden] the input sequence
@@ -72,7 +81,7 @@ class _Cls(nn.Module):
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x):
        h = self.fc(x)
        return h
@@ -83,20 +92,21 @@ class _NLICls(nn.Module):
        super(_NLICls, self).__init__()
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
-            nn.Linear(in_dim*4, hid_dim),  #4
+            nn.Linear(in_dim * 4, hid_dim),  # 4
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x1, x2):
-        x = torch.cat([x1, x2, torch.abs(x1-x2), x1*x2], 1)
+        x = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], 1)
        h = self.fc(x)
        return h



class STSeqLabel(nn.Module):
"""
Alias: :class:`fastNLP.models.STSeqLabel` :class:`fastNLP.models.start_transformer.STSeqLabel`
Alias: :class:`fastNLP.models.STSeqLabel` :class:`fastNLP.models.star_transformer.STSeqLabel`


Star-Transformer model for sequence labeling


@@ -112,6 +122,7 @@ class STSeqLabel(nn.Module):
:param emb_dropout: dropout probability for the word embedding. Default: 0.1
:param dropout: dropout probability for the model outside the word embedding. Default: 0.1
"""
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
@@ -120,7 +131,7 @@ class STSeqLabel(nn.Module):
max_len=512,
cls_hidden_size=600,
emb_dropout=0.1,
dropout=0.1,):
dropout=0.1, ):
super(STSeqLabel, self).__init__()
self.enc = StarTransEnc(init_embed=init_embed,
hidden_size=hidden_size,
@@ -131,7 +142,7 @@ class STSeqLabel(nn.Module):
emb_dropout=emb_dropout,
dropout=dropout)
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)
def forward(self, words, seq_len):
"""


@@ -142,9 +153,9 @@ class STSeqLabel(nn.Module):
mask = seq_len_to_mask(seq_len)
nodes, _ = self.enc(words, mask)
output = self.cls(nodes)
output = output.transpose(1,2) # make hidden to be dim 1
return {Const.OUTPUT: output} # [bsz, n_cls, seq_len]
output = output.transpose(1, 2)  # make hidden to be dim 1
return {Const.OUTPUT: output}  # [bsz, n_cls, seq_len]
def predict(self, words, seq_len):
"""


@@ -159,7 +170,7 @@ class STSeqLabel(nn.Module):


class STSeqCls(nn.Module):
"""
Alias: :class:`fastNLP.models.STSeqCls` :class:`fastNLP.models.start_transformer.STSeqCls`
Alias: :class:`fastNLP.models.STSeqCls` :class:`fastNLP.models.star_transformer.STSeqCls`


Star-Transformer for classification tasks


@@ -175,7 +186,7 @@ class STSeqCls(nn.Module):
:param emb_dropout: dropout probability for the word embedding. Default: 0.1
:param dropout: dropout probability for the model outside the word embedding. Default: 0.1
"""
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
@@ -184,7 +195,7 @@ class STSeqCls(nn.Module):
max_len=512,
cls_hidden_size=600,
emb_dropout=0.1,
dropout=0.1,):
dropout=0.1, ):
super(STSeqCls, self).__init__()
self.enc = StarTransEnc(init_embed=init_embed,
hidden_size=hidden_size,
@@ -195,7 +206,7 @@ class STSeqCls(nn.Module):
emb_dropout=emb_dropout,
dropout=dropout)
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)
def forward(self, words, seq_len):
"""


@@ -206,9 +217,9 @@ class STSeqCls(nn.Module):
mask = seq_len_to_mask(seq_len)
nodes, relay = self.enc(words, mask)
y = 0.5 * (relay + nodes.max(1)[0])
output = self.cls(y) # [bsz, n_cls]
output = self.cls(y)  # [bsz, n_cls]
return {Const.OUTPUT: output}
def predict(self, words, seq_len):
"""


@@ -223,7 +234,7 @@ class STSeqCls(nn.Module):


class STNLICls(nn.Module):
"""
Alias: :class:`fastNLP.models.STNLICls` :class:`fastNLP.models.start_transformer.STNLICls`
Alias: :class:`fastNLP.models.STNLICls` :class:`fastNLP.models.star_transformer.STNLICls`
Star-Transformer for natural language inference (NLI)


@@ -239,7 +250,7 @@ class STNLICls(nn.Module):
:param emb_dropout: dropout probability for the word embedding. Default: 0.1
:param dropout: dropout probability for the model outside the word embedding. Default: 0.1
"""
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
@@ -248,7 +259,7 @@ class STNLICls(nn.Module):
max_len=512,
cls_hidden_size=600,
emb_dropout=0.1,
dropout=0.1,):
dropout=0.1, ):
super(STNLICls, self).__init__()
self.enc = StarTransEnc(init_embed=init_embed,
hidden_size=hidden_size,
@@ -259,7 +270,7 @@ class STNLICls(nn.Module):
emb_dropout=emb_dropout,
dropout=dropout)
self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size)
def forward(self, words1, words2, seq_len1, seq_len2):
"""


@@ -271,14 +282,16 @@ class STNLICls(nn.Module):
"""
mask1 = seq_len_to_mask(seq_len1)
mask2 = seq_len_to_mask(seq_len2)
def enc(seq, mask):
nodes, relay = self.enc(seq, mask)
return 0.5 * (relay + nodes.max(1)[0])
y1 = enc(words1, mask1)
y2 = enc(words2, mask2)
output = self.cls(y1, y2) # [bsz, n_cls]
output = self.cls(y1, y2)  # [bsz, n_cls]
return {Const.OUTPUT: output}
def predict(self, words1, words2, seq_len1, seq_len2):
"""




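A minimal usage sketch for the classification variant. It assumes init_embed may be given as a (vocab_size, embed_dim) tuple, which get_embeddings (imported above) appears to support; sizes are toy values:

import torch
from fastNLP.models import STSeqCls
from fastNLP.core.const import Const

model = STSeqCls(init_embed=(1000, 300), num_cls=5)   # assumed tuple form of init_embed
words = torch.randint(0, 1000, (4, 20))               # [batch, seq_len] word ids
seq_len = torch.tensor([20, 18, 15, 9])
scores = model(words, seq_len)[Const.OUTPUT]          # [batch, num_cls]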
+21 -15 fastNLP/modules/__init__.py

@@ -22,29 +22,35 @@
+-----------------------+-----------------------+-----------------------+


"""
from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings

__all__ = [
"LSTM",
"Embedding",
# "BertModel",
"ConvolutionCharEncoder",
"LSTMCharEncoder",
"ConvMaxpool",
"BertModel",
"Embedding",
"LSTM",
"StarTransformer",
"TransformerEncoder",
"VarRNN",
"VarLSTM",
"VarGRU",
"MaxPool",
"MaxPoolWithMask",
"AvgPool",
"MultiHeadAttention",
"BiAttention",

"MLP",
"ConditionalRandomField",
"viterbi_decode",
"allowed_transitions",
]
]

from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings

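With the imports moved ahead of __all__, every public name resolves in one hop from the package; a quick smoke test:

from fastNLP.modules import LSTM, Embedding, ConvMaxpool, MultiHeadAttention
from fastNLP.modules import MLP, ConditionalRandomField, viterbi_decode, allowed_transitions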
+7 -7 fastNLP/modules/aggregator/__init__.py

@@ -1,14 +1,14 @@
from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask

from .attention import MultiHeadAttention, BiAttention
__all__ = [
"MaxPool",
"MaxPoolWithMask",
"AvgPool",
"MultiHeadAttention",
"BiAttention"
]

from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask

from .attention import MultiHeadAttention

+36 -26 fastNLP/modules/aggregator/attention.py

@@ -1,4 +1,7 @@
__all__ =["MultiHeadAttention"]
__all__ = [
"MultiHeadAttention"
]

import math


import torch
@@ -15,6 +18,7 @@ class DotAttention(nn.Module):
.. todo::
Fill in the documentation.
"""
def __init__(self, key_size, value_size, dropout=0):
super(DotAttention, self).__init__()
self.key_size = key_size
@@ -22,7 +26,7 @@ class DotAttention(nn.Module):
self.scale = math.sqrt(key_size)
self.drop = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, Q, K, V, mask_out=None):
"""


@@ -41,6 +45,8 @@ class DotAttention(nn.Module):


class MultiHeadAttention(nn.Module):
"""
Alias: :class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.aggregator.attention.MultiHeadAttention`



:param input_size: int, size of the input dimension; this is also the size of the output dimension.
:param key_size: int, dimension of each head.
@@ -48,13 +54,14 @@ class MultiHeadAttention(nn.Module):
:param num_head: int, number of heads.
:param dropout: float.
"""
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.input_size = input_size
self.key_size = key_size
self.value_size = value_size
self.num_head = num_head
in_size = key_size * num_head
self.q_in = nn.Linear(input_size, in_size)
self.k_in = nn.Linear(input_size, in_size)
@@ -64,14 +71,14 @@ class MultiHeadAttention(nn.Module):
self.out = nn.Linear(value_size * num_head, input_size)
self.drop = TimestepDropout(dropout)
self.reset_parameters()
def reset_parameters(self):
sqrt = math.sqrt
nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
nn.init.xavier_normal_(self.out.weight)
def forward(self, Q, K, V, atte_mask_out=None):
"""


@@ -87,7 +94,7 @@ class MultiHeadAttention(nn.Module):
q = self.q_in(Q).view(batch, sq, n_head, d_k)
k = self.k_in(K).view(batch, sk, n_head, d_k)
v = self.v_in(V).view(batch, sk, n_head, d_v)
# transpose q, k and v to do batch attention
q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
@@ -95,7 +102,7 @@ class MultiHeadAttention(nn.Module):
if atte_mask_out is not None:
atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)
# concat all heads, do output linear
atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
output = self.drop(self.out(atte))
@@ -104,6 +111,10 @@ class MultiHeadAttention(nn.Module):


class BiAttention(nn.Module):
r"""Bi Attention module
.. todo::
The owner of this module should continue to polish it.
Calculate Bi Attention matrix `e`
.. math::
@@ -115,11 +126,11 @@ class BiAttention(nn.Module):
\end{array}
"""
def __init__(self):
super(BiAttention, self).__init__()
self.inf = 10e12
def forward(self, in_x1, in_x2, x1_len, x2_len):
"""
:param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
@@ -130,36 +141,36 @@ class BiAttention(nn.Module):
torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] the feature representation attended to by the first sentence
"""
assert in_x1.size()[0] == in_x2.size()[0]
assert in_x1.size()[2] == in_x2.size()[2]
# The batch size and hidden size must be equal.
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
# The seq len in in_x and x_len must be equal.
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]
batch_size = in_x1.size()[0]
x1_max_len = in_x1.size()[1]
x2_max_len = in_x2.size()[1]
in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]
attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]
a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
a_mask = a_mask.view(batch_size, x1_max_len, -1)
a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
b_mask = x2_len.le(0.5).float() * -self.inf
b_mask = b_mask.view(batch_size, -1, x2_max_len)
b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]
attention_a = F.softmax(attention_matrix + a_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
attention_b = F.softmax(attention_matrix + b_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]
out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
attention_b_t = torch.transpose(attention_b, 1, 2)
out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]
return out_x1, out_x2




@@ -173,10 +184,10 @@ class SelfAttention(nn.Module):
:param float drop: dropout probability. Default: 0.5
:param str initial_method: method for initializing the parameters
"""
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None,):
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
super(SelfAttention, self).__init__()
self.attention_hops = attention_hops
self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
@@ -185,7 +196,7 @@ class SelfAttention(nn.Module):
self.drop = nn.Dropout(drop)
self.tanh = nn.Tanh()
initial_parameter(self, initial_method)
def _penalization(self, attention):
"""
compute the penalization term for attention module
@@ -199,7 +210,7 @@ class SelfAttention(nn.Module):
mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)]
ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
return torch.sum(ret) / size[0]
def forward(self, input, input_origin):
"""
:param torch.Tensor input: [baz, senLen, h_dim] the matrix to apply attention to
@@ -209,15 +220,14 @@ class SelfAttention(nn.Module):
"""
input = input.contiguous()
size = input.size()  # [bsz, len, nhid]
input_origin = input_origin.expand(self.attention_hops, -1, -1)  # [hops,baz, len]
input_origin = input_origin.transpose(0, 1).contiguous()  # [baz, hops,len]
y1 = self.tanh(self.ws1(self.drop(input)))  # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1, 2).contiguous()
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
attention = attention + (-999999 * (input_origin == 0).float())  # remove the weight on padding token.
attention = F.softmax(attention, 2)  # [baz ,hop, len]
return torch.bmm(attention, input), self._penalization(attention)  # output1 --> [baz ,hop ,nhid]


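A minimal self-attention sketch against the constructor and forward shown above (toy sizes; atte_mask_out omitted):

import torch
from fastNLP.modules import MultiHeadAttention

atte = MultiHeadAttention(input_size=64, key_size=16, value_size=16, num_head=4)
x = torch.randn(2, 10, 64)    # [batch, seq_len, input_size]
out = atte(x, x, x)           # [2, 10, 64]; the output linear maps back to input_size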
+7 -2 fastNLP/modules/aggregator/pooling.py

@@ -1,4 +1,8 @@
__all__ = ["MaxPool", "MaxPoolWithMask", "AvgPool"]
__all__ = [
"MaxPool",
"MaxPoolWithMask",
"AvgPool"
]
import torch
import torch.nn as nn


@@ -16,6 +20,7 @@ class MaxPool(nn.Module):
:param kernel_size: window size for max pooling; defaults to the last k dimensions of the tensor, where k is `dimension`
:param ceil_mode:
"""
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
super(MaxPool, self).__init__()
@@ -125,7 +130,7 @@ class AvgPoolWithMask(nn.Module):
Given an input of shape [batch_size, max_len, hidden_size], apply avg pooling over the last dimension. The output
is [batch_size, hidden_size]; only positions where the mask is 1 are taken into account.
"""
def __init__(self):
super(AvgPoolWithMask, self).__init__()
self.inf = 10e12


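A sketch of the masked pooling classes. The forward signature (tensor, mask, dim) is only inferred from the call sites in the ESIM model earlier in this diff, so treat it as an assumption:

import torch
from fastNLP.modules import MaxPoolWithMask

pool = MaxPoolWithMask()
x = torch.randn(2, 7, 16)                          # [batch, max_len, hidden]
mask = torch.tensor([[1] * 7, [1] * 4 + [0] * 3])  # 1 marks real tokens
out = pool(x, mask, dim=1)                         # [2, 16]; padded positions ignored (assumed API)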
+5 -5 fastNLP/modules/decoder/__init__.py

@@ -1,11 +1,11 @@
from .CRF import ConditionalRandomField
from .MLP import MLP
from .utils import viterbi_decode
from .CRF import allowed_transitions

__all__ = [
"MLP",
"ConditionalRandomField",
"viterbi_decode",
"allowed_transitions"
]

from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .crf import allowed_transitions

fastNLP/modules/decoder/CRF.py → fastNLP/modules/decoder/crf.py

@@ -1,3 +1,8 @@
__all__ = [
"ConditionalRandomField",
"allowed_transitions"
]

import torch
from torch import nn


@@ -6,7 +11,7 @@ from ..utils import initial_parameter


def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
"""
Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.CRF.allowed_transitions`
Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`


Given a mapping from id to label, return the list of all allowed (from_tag_id, to_tag_id) transitions.


@@ -15,8 +20,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
:param str encoding_type: supports "bio", "bmes", "bmeso".
:param bool include_start_end: whether to include start and end transitions. For example, in BIO, b/o may start a sequence but i may not;
if True, the result will include (start_idx, b_idx) and (start_idx, o_idx), but not (start_idx, i_idx);
start_idx=len(id2label), end_idx=len(id2label)+1.
if False, the result contains nothing related to start and end
start_idx=len(id2label), end_idx=len(id2label)+1. If False, the result contains nothing related to start and end
:return: List[Tuple(int, int)], each inner tuple is an allowed (from_tag_id, to_tag_id) transition.
"""
num_tags = len(id2target)
@@ -27,6 +31,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
id_label_lst = list(id2target.items())
if include_start_end:
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]
def split_tag_label(from_label):
from_label = from_label.lower()
if from_label in ['start', 'end']:
@@ -36,7 +41,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
from_tag = from_label[:1]
from_label = from_label[2:]
return from_tag, from_label
for from_id, from_label in id_label_lst:
if from_label in ['<pad>', '<unk>']:
continue
@@ -60,7 +65,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
:param str to_label: a label such as "PER" or "LOC"
:return: bool, whether the transition is allowed
"""
if to_tag=='start' or from_tag=='end':
if to_tag == 'start' or from_tag == 'end':
return False
encoding_type = encoding_type.lower()
if encoding_type == 'bio':
@@ -83,12 +88,12 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
if from_tag == 'start':
return to_tag in ('b', 'o')
elif from_tag in ['b', 'i']:
return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label])
return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label])
elif from_tag == 'o':
return to_tag in ['end', 'b', 'o']
else:
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))
elif encoding_type == 'bmes':
"""
The first row is to_tag and the first column is from_tag; y means the transition is always allowed, - means it is allowed only when the labels match, and n means it is not allowed.
@@ -111,9 +116,9 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
if from_tag == 'start':
return to_tag in ['b', 's']
elif from_tag == 'b':
return to_tag in ['m', 'e'] and from_label==to_label
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag == 'm':
return to_tag in ['m', 'e'] and from_label==to_label
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag in ['e', 's']:
return to_tag in ['b', 's', 'end']
else:
@@ -122,21 +127,21 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
if from_tag == 'start':
return to_tag in ['b', 's', 'o']
elif from_tag == 'b':
return to_tag in ['m', 'e'] and from_label==to_label
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag == 'm':
return to_tag in ['m', 'e'] and from_label==to_label
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag in ['e', 's', 'o']:
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))
else:
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type))




class ConditionalRandomField(nn.Module):
"""
Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.CRF.ConditionalRandomField`
Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.crf.ConditionalRandomField`


Conditional random field.
Provides two methods, forward() and viterbi_decode(), used for training and inference respectively.
@@ -148,30 +153,31 @@ class ConditionalRandomField(nn.Module):
which can be obtained from the allowed_transitions() function; if None, all transitions are legal
:param str initial_method: initialization method. See initial_parameter
"""
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None,
initial_method=None):
super(ConditionalRandomField, self).__init__()
self.include_start_end_trans = include_start_end_trans
self.num_tags = num_tags
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
if self.include_start_end_trans:
self.start_scores = nn.Parameter(torch.randn(num_tags))
self.end_scores = nn.Parameter(torch.randn(num_tags))
if allowed_transitions is None:
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
constrain = torch.full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float)
constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
self._constrain = nn.Parameter(constrain, requires_grad=False)
initial_parameter(self, initial_method)
def _normalizer_likelihood(self, logits, mask):
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.
@@ -184,21 +190,21 @@ class ConditionalRandomField(nn.Module):
alpha = logits[0]
if self.include_start_end_trans:
alpha = alpha + self.start_scores.view(1, -1)
flip_mask = mask.eq(0)
for i in range(1, seq_len):
emit_score = logits[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
if self.include_start_end_trans:
alpha = alpha + self.end_scores.view(1, -1)
return torch.logsumexp(alpha, 1)
def _gold_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
@@ -210,15 +216,15 @@ class ConditionalRandomField(nn.Module):
seq_len, batch_size, _ = logits.size()
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
# trans_score [L-1, B]
mask = mask.byte()
flip_mask = mask.eq(0)
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
# emit_score [L, B]
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0)
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
# score [L-1, B]
score = trans_score + emit_score[:seq_len-1, :]
score = trans_score + emit_score[:seq_len - 1, :]
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
if self.include_start_end_trans:
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
@@ -227,24 +233,24 @@ class ConditionalRandomField(nn.Module):
score = score + st_scores + ed_scores
# return [B,]
return score
def forward(self, feats, tags, mask):
"""
Computes the forward loss of the CRF. The return value is a FloatTensor of shape (batch_size,); mean() may be needed to obtain the final loss.


:param torch.FloatTensor feats:batch_size x max_len x num_tags, the feature matrix.
:param torch.FloatTensor feats: batch_size x max_len x num_tags, the feature matrix.
:param torch.LongTensor tags: batch_size x max_len, the tag matrix.
:param torch.ByteTensor mask: batch_size x max_len, positions equal to 0 are treated as padding.
:return:torch.FloatTensor, (batch_size,)
:return: torch.FloatTensor, (batch_size,)
"""
feats = feats.transpose(0, 1)
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._gold_score(feats, tags, mask)
return all_path_score - gold_path_score
def viterbi_decode(self, logits, mask, unpad=False):
"""Given a feature matrix and a transition score matrix, compute the best path and its score


@@ -259,9 +265,9 @@ class ConditionalRandomField(nn.Module):


""" """
batch_size, seq_len, n_tags = logits.size() batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.byte() # L, B
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.byte() # L, B
# dp # dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0] vscore = logits[0]
@@ -269,8 +275,8 @@ class ConditionalRandomField(nn.Module):
transitions[:n_tags, :n_tags] += self.trans_m.data
if self.include_start_end_trans:
transitions[n_tags, :n_tags] += self.start_scores.data
transitions[:n_tags, n_tags+1] += self.end_scores.data
transitions[:n_tags, n_tags + 1] += self.end_scores.data
vscore += transitions[n_tags, :n_tags]
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
@@ -280,30 +286,29 @@ class ConditionalRandomField(nn.Module):
best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags+1].view(1, -1)
vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(seq_len - 1):
last_tags = vpath[idxes[i], batch_idx, last_tags]
ans[idxes[i+1], batch_idx] = last_tags
ans[idxes[i + 1], batch_idx] = last_tags
ans = ans.transpose(0, 1)
if unpad:
paths = []
for idx, seq_len in enumerate(lens):
paths.append(ans[idx, :seq_len+1].tolist())
paths.append(ans[idx, :seq_len + 1].tolist())
else:
paths = ans
return paths, ans_score


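End-to-end usage of the renamed module, following the signatures visible in this diff (toy BIO tag set; sizes are illustrative):

import torch
from fastNLP.modules import ConditionalRandomField, allowed_transitions

id2label = {0: 'B-PER', 1: 'I-PER', 2: 'O'}
trans = allowed_transitions(id2label, encoding_type='bio', include_start_end=True)
crf = ConditionalRandomField(num_tags=3, include_start_end_trans=True, allowed_transitions=trans)

feats = torch.randn(2, 5, 3)                    # [batch, max_len, num_tags]
tags = torch.randint(0, 3, (2, 5))              # [batch, max_len]
mask = torch.ones(2, 5, dtype=torch.uint8)      # 0 marks padding
loss = crf(feats, tags, mask).mean()            # forward returns a per-sample loss
paths, scores = crf.viterbi_decode(feats, mask, unpad=True)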
fastNLP/modules/decoder/MLP.py → fastNLP/modules/decoder/mlp.py

@@ -1,3 +1,7 @@
__all__ = [
"MLP"
]

import torch
import torch.nn as nn


@@ -6,17 +10,16 @@ from ..utils import initial_parameter


class MLP(nn.Module):
"""
Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.MLP.MLP`
Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.mlp.MLP`


Multilayer perceptron


:param list size_layer: a list of ints defining the layers of the MLP; each number is the hidden size of that layer. The MLP has len(size_layer) - 1 layers.
:param str or list activation:
a string, a function, or a list of strings/functions defining the activation of each hidden layer; supported strings are relu, tanh and sigmoid. Default: relu
:param str or function output_activation: a string or function defining the activation of the output layer; the default None means the output layer has no activation
:param List[int] size_layer: a list of ints defining the layers of the MLP; each number is the hidden size of that layer. The MLP has len(size_layer) - 1 layers.
:param Union[str,func,List[str]] activation: a string or a list of strings/functions defining the activation of each hidden layer; supported strings are relu, tanh and sigmoid. Default: relu
:param Union[str,func] output_activation: a string or function defining the activation of the output layer; the default None means the output layer has no activation
:param str initial_method: parameter initialization method
:param float dropout: dropout probability. Default: 0
.. note::
The activations of the hidden layers are defined by `activation`; either a single str/function or a list of str/functions can be passed.
If only one str/function is passed, all hidden layers use that activation;
@@ -35,10 +38,8 @@ class MLP(nn.Module):
>>> y = net(x)
>>> print(x)
>>> print(y)
>>>

"""
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0):
super(MLP, self).__init__()
self.hiddens = nn.ModuleList()
@@ -46,12 +47,12 @@ class MLP(nn.Module):
self.output_activation = output_activation
for i in range(1, len(size_layer)):
if i + 1 == len(size_layer):
self.output = nn.Linear(size_layer[i-1], size_layer[i])
self.output = nn.Linear(size_layer[i - 1], size_layer[i])
else:
self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i]))
self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i]))
self.dropout = nn.Dropout(p=dropout)
actives = {
'relu': nn.ReLU(),
'tanh': nn.Tanh(),
@@ -80,7 +81,7 @@ class MLP(nn.Module):
else:
raise ValueError("should set activation correctly: {}".format(activation))
initial_parameter(self, initial_method)
def forward(self, x):
"""
:param torch.Tensor x: input to the MLP
@@ -93,16 +94,3 @@ class MLP(nn.Module):
x = self.output_activation(x)
x = self.dropout(x)
return x


if __name__ == '__main__':
net1 = MLP([5, 10, 5])
net2 = MLP([5, 10, 5], 'tanh')
net3 = MLP([5, 6, 7, 8, 5], 'tanh')
net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh')
net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh')
for net in [net1, net2, net3, net4, net5]:
x = torch.randn(5, 5)
y = net(x)
print(x)
print(y)

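The removed __main__ demo translates directly into a short usage sketch:

import torch
from fastNLP.modules import MLP

net = MLP([5, 10, 5], activation='relu', output_activation='tanh', dropout=0.1)
x = torch.randn(4, 5)     # [batch, size_layer[0]]
y = net(x)                # [batch, size_layer[-1]], tanh applied on the output layer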
+13 -10 fastNLP/modules/decoder/utils.py

@@ -1,10 +1,12 @@
__all__ = ["viterbi_decode"]
__all__ = [
"viterbi_decode"
]
import torch




def viterbi_decode(logits, transitions, mask=None, unpad=False):
"""
Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode
r"""
Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode`


Given a feature matrix and a transition score matrix, compute the best path and its score


@@ -20,18 +22,19 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):


""" """
batch_size, seq_len, n_tags = logits.size() batch_size, seq_len, n_tags = logits.size()
assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \
"compatible."
assert n_tags == transitions.size(0) and n_tags == transitions.size(
1), "The shapes of transitions and feats are not " \
"compatible."
logits = logits.transpose(0, 1).data # L, B, H logits = logits.transpose(0, 1).data # L, B, H
if mask is not None: if mask is not None:
mask = mask.transpose(0, 1).data.byte() # L, B mask = mask.transpose(0, 1).data.byte() # L, B
else: else:
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8) mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
# dp # dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0] vscore = logits[0]
trans_score = transitions.view(1, n_tags, n_tags).data trans_score = transitions.view(1, n_tags, n_tags).data
for i in range(1, seq_len): for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1) prev_score = vscore.view(batch_size, n_tags, 1)
@@ -41,14 +44,14 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
@@ -62,4 +65,4 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
paths.append(ans[idx, :seq_len + 1].tolist())
else:
paths = ans
return paths, ans_score
return paths, ans_score

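The function version mirrors ConditionalRandomField.viterbi_decode but takes an explicit transition matrix; a minimal sketch:

import torch
from fastNLP.modules import viterbi_decode

logits = torch.randn(2, 6, 4)                 # [batch, seq_len, n_tags]
transitions = torch.randn(4, 4)               # [n_tags, n_tags] transition scores
mask = torch.ones(2, 6, dtype=torch.uint8)
paths, scores = viterbi_decode(logits, transitions, mask, unpad=True)
# with unpad=True, paths is a list of per-sample tag-id lists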
+4 -2 fastNLP/modules/dropout.py

@@ -1,6 +1,8 @@
import torch
__all__ = []


import torch


class TimestepDropout(torch.nn.Dropout):
"""
Alias: :class:`fastNLP.modules.TimestepDropout`
@@ -8,7 +10,7 @@ class TimestepDropout(torch.nn.Dropout):
Accepts input with shape ``[batch_size, num_timesteps, embedding_dim]`` and uses a single mask (shape ``(batch_size, embedding_dim)``)
to apply dropout at every timestep.
"""
def forward(self, x):
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)


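As the docstring says, one mask is shared across timesteps; a quick sketch (the zeroed embedding positions should coincide at every timestep):

import torch
from fastNLP.modules import TimestepDropout

drop = TimestepDropout(p=0.5)
drop.train()                          # dropout is only active in training mode
x = torch.ones(2, 7, 16)              # [batch, num_timesteps, embedding_dim]
y = drop(x)                           # same [batch, embedding_dim] mask reused at each timestep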
+25 -8 fastNLP/modules/encoder/__init__.py

@@ -1,11 +1,28 @@
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .bert import BertModel

__all__ = [
"LSTM",
"Embedding",
# "BertModel",
"ConvolutionCharEncoder",
"LSTMCharEncoder",
"ConvMaxpool",
"BertModel"
"Embedding",
"LSTM",
"StarTransformer",
"TransformerEncoder",
"VarRNN",
"VarLSTM",
"VarGRU"
]
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU

+22 -9 fastNLP/modules/encoder/char_encoder.py

@@ -1,5 +1,9 @@
__all__ = [
"ConvolutionCharEncoder",
"LSTMCharEncoder"
]
import torch
from torch import nn
import torch.nn as nn


from ..utils import initial_parameter


@@ -10,20 +14,22 @@ class ConvolutionCharEncoder(nn.Module):
Alias: :class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.ConvolutionCharEncoder`


Char-level convolutional encoder.
:param int char_emb_size: dimension of the char-level embedding. Default: 50
Example: with 26 characters, each embedded as a 50-dimensional vector, the input vector dimension is 50.
:Example: with 26 characters, each embedded as a 50-dimensional vector, the input vector dimension is 50.
:param tuple feature_maps: a tuple of ints. Its length is the number of char-level convolutions; the `i`-th int is the number of filters of the `i`-th convolution.
:param tuple kernels: a tuple of ints. Its length is the number of char-level convolutions; the `i`-th int is the kernel of the `i`-th convolution.
:param initial_method: parameter initialization method. Default: `xavier normal`
"""
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None):
super(ConvolutionCharEncoder, self).__init__()
self.convs = nn.ModuleList([
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
for i in range(len(kernels))])
initial_parameter(self, initial_method)
def forward(self, x):
"""
:param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` embedding of the input characters
@@ -34,7 +40,7 @@ class ConvolutionCharEncoder(nn.Module):
x = x.transpose(2, 3)
# [batch_size*sent_length, channel, height, width]
return self._convolute(x).unsqueeze(2)
def _convolute(self, x):
feats = []
for conv in self.convs:
@@ -50,7 +56,14 @@ class ConvolutionCharEncoder(nn.Module):




class LSTMCharEncoder(nn.Module):
"""Char-level LSTM-based encoder."""
"""
Alias: :class:`fastNLP.modules.LSTMCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.LSTMCharEncoder`

Char-level LSTM-based encoder.
"""
def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
"""
:param int char_emb_size: dimension of the char-level embedding. Default: 50
@@ -60,14 +73,14 @@ class LSTMCharEncoder(nn.Module):
"""
super(LSTMCharEncoder, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
self.lstm = nn.LSTM(input_size=char_emb_size,
hidden_size=self.hidden_size,
num_layers=1,
bias=True,
batch_first=True)
initial_parameter(self, initial_method)
def forward(self, x):
"""
:param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` embedding of the input characters
@@ -78,6 +91,6 @@ class LSTMCharEncoder(nn.Module):
h0 = nn.init.orthogonal_(h0)
c0 = torch.empty(1, batch_size, self.hidden_size)
c0 = nn.init.orthogonal_(c0)
_, hidden = self.lstm(x, (h0, c0))
return hidden[0].squeeze().unsqueeze(2)

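A sketch for the convolutional variant. The input layout follows the forward docstring above; the exact output shape depends on _convolute, so it is left unasserted here:

import torch
from fastNLP.modules import ConvolutionCharEncoder

enc = ConvolutionCharEncoder(char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5))
x = torch.randn(8, 6, 50)   # [batch_size * sent_length, word_length, char_emb_size]
out = enc(x)                # char-level features per word; shape depends on _convolute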
+15 -13 fastNLP/modules/encoder/conv_maxpool.py

@@ -1,6 +1,6 @@
# python: 3.6
# encoding: utf-8
__all__ = [
"ConvMaxpool"
]
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -27,22 +27,24 @@ class ConvMaxpool(nn.Module):
:param str activation: the convolution output passes through this activation before max-pooling. Supports relu, sigmoid, tanh
:param str initial_method: str.
"""
def __init__(self, in_channels, out_channels, kernel_sizes,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation="relu", initial_method=None):
super(ConvMaxpool, self).__init__()
# convolution
if isinstance(kernel_sizes, (list, tuple, int)):
if isinstance(kernel_sizes, int) and isinstance(out_channels, int):
out_channels = [out_channels]
kernel_sizes = [kernel_sizes]
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)):
assert len(out_channels)==len(kernel_sizes), "The number of out_channels should be equal to the number" \
" of kernel_sizes."
assert len(out_channels) == len(
kernel_sizes), "The number of out_channels should be equal to the number" \
" of kernel_sizes."
else:
raise ValueError("The type of out_channels and kernel_sizes should be the same.")
self.convs = nn.ModuleList([nn.Conv1d(
in_channels=in_channels,
out_channels=oc,
@@ -53,11 +55,11 @@ class ConvMaxpool(nn.Module):
groups=groups,
bias=bias)
for oc, ks in zip(out_channels, kernel_sizes)])
else:
raise Exception(
'Incorrect kernel sizes: should be list, tuple or int')
# activation function
if activation == 'relu':
self.activation = F.relu
@@ -68,9 +70,9 @@ class ConvMaxpool(nn.Module):
else:
raise Exception(
"Undefined activation function: choose from: relu, tanh, sigmoid")
initial_parameter(self, initial_method)
def forward(self, x, mask=None):
"""


@@ -83,9 +85,9 @@ class ConvMaxpool(nn.Module):
# convolution
xs = [self.activation(conv(x)) for conv in self.convs]  # [[N,C,L], ...]
if mask is not None:
mask = mask.unsqueeze(1) # B x 1 x L
mask = mask.unsqueeze(1)  # B x 1 x L
xs = [x.masked_fill_(mask, float('-inf')) for x in xs]
# max-pooling
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
for i in xs]  # [[N, C], ...]
return torch.cat(xs, dim=-1) # [N, C]
return torch.cat(xs, dim=-1)  # [N, C]

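A sketch against the constructor shown here. Whether forward expects [N, in_channels, L] directly or first transposes a [N, L, in_channels] input is not visible in this hunk, so the layout below is an assumption:

import torch
from fastNLP.modules import ConvMaxpool

conv_pool = ConvMaxpool(in_channels=128, out_channels=[30, 40, 50], kernel_sizes=[1, 3, 5])
x = torch.randn(2, 128, 20)   # assumed [N, in_channels, L] layout for the Conv1d stack
out = conv_pool(x)            # [2, 120]: per-kernel max-pooled features, concatenated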
+11 -7 fastNLP/modules/encoder/embedding.py

@@ -1,14 +1,18 @@
__all__ = [
"Embedding"
]
import torch.nn as nn
from ..utils import get_embeddings



class Embedding(nn.Embedding): class Embedding(nn.Embedding):
""" """
别名::class:`fastNLP.modules.Embedding` :class:`fastNLP.modules.encoder.embedding.Embedding` 别名::class:`fastNLP.modules.Embedding` :class:`fastNLP.modules.encoder.embedding.Embedding`


Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度""" Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度"""
def __init__(self, init_embed, padding_idx=None, dropout=0.0, sparse=False, max_norm=None, norm_type=2, def __init__(self, init_embed, padding_idx=None, dropout=0.0, sparse=False, max_norm=None, norm_type=2,
scale_grad_by_freq=False):
scale_grad_by_freq=False):
""" """


:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int), :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int),
@@ -22,14 +26,14 @@ class Embedding(nn.Embedding):
""" """
embed = get_embeddings(init_embed) embed = get_embeddings(init_embed)
num_embeddings, embedding_dim = embed.weight.size() num_embeddings, embedding_dim = embed.weight.size()
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx,
max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse, _weight=embed.weight.data)
max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse, _weight=embed.weight.data)
del embed del embed
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def forward(self, x): def forward(self, x):
""" """
:param torch.LongTensor x: [batch, seq_len] :param torch.LongTensor x: [batch, seq_len]
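A minimal usage sketch of this wrapper (illustrative, not from the diff), assuming the tuple form of init_embed allocates a fresh matrix via get_embeddings, as the docstring describes:

```python
import torch
from fastNLP.modules.encoder.embedding import Embedding

embed = Embedding((5000, 300), dropout=0.1)  # vocab size 5000, dim 300
words = torch.randint(0, 5000, (4, 20))      # [batch, seq_len]
out = embed(words)                           # [batch, seq_len, 300], dropout active in train mode
```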


+ 8
- 2
fastNLP/modules/encoder/lstm.py View File

@@ -1,6 +1,11 @@
"""轻量封装的 Pytorch LSTM 模块.
"""
轻量封装的 Pytorch LSTM 模块.
可在 forward 时传入序列的长度, 自动对padding做合适的处理. 可在 forward 时传入序列的长度, 自动对padding做合适的处理.
""" """
__all__ = [
"LSTM"
]

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
@@ -23,6 +28,7 @@ class LSTM(nn.Module):
:(batch, seq, feature). Default: ``False``
:param bias: if ``False``, the model will not use a bias term. Default: ``True``
"""
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
bidirectional=False, bias=True, initial_method=None):
super(LSTM, self).__init__()
@@ -30,7 +36,7 @@ class LSTM(nn.Module):
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
dropout=dropout, bidirectional=bidirectional)
initial_parameter(self, initial_method)
def forward(self, x, seq_len=None, h0=None, c0=None):
"""




+ 39
- 32
fastNLP/modules/encoder/star_transformer.py View File

@@ -1,9 +1,14 @@
"""Star-Transformer 的encoder部分的 Pytorch 实现
""" """
Star-Transformer 的encoder部分的 Pytorch 实现
"""
__all__ = [
"StarTransformer"
]

import numpy as NP
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
import numpy as NP




class StarTransformer(nn.Module):
@@ -24,10 +29,11 @@ class StarTransformer(nn.Module):
The model adds a position embedding to the input sequence.
If `None`, the position-embedding step is skipped. Default: `None`
"""
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
super(StarTransformer, self).__init__()
self.iters = num_layers
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)])
self.ring_att = nn.ModuleList(
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
@@ -35,12 +41,12 @@ class StarTransformer(nn.Module):
self.star_att = nn.ModuleList(
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
for _ in range(self.iters)])
if max_len is not None:
self.pos_emb = self.pos_emb = nn.Embedding(max_len, hidden_size)
else:
self.pos_emb = None
def forward(self, data, mask):
"""
:param FloatTensor data: [batch, length, hidden] the input sequence
@@ -50,20 +56,21 @@ class StarTransformer(nn.Module):


[batch, hidden] the global relay node; see the paper for details
"""
def norm_func(f, x):
# B, H, L, 1
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
B, L, H = data.size()
mask = (mask == 0)  # flip the mask for masked_fill_
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
- embs = data.permute(0, 2, 1)[:,:,:,None] # B H L 1
+ embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
if self.pos_emb:
- P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device)\
- .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1
+ P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
+ .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None]  # 1 H L 1
embs = embs + P
nodes = embs
relay = embs.mean(2, keepdim=True)
ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
@@ -72,11 +79,11 @@ class StarTransformer(nn.Module):
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))
nodes = nodes.masked_fill_(ex_mask, 0)
nodes = nodes.view(B, H, L).permute(0, 2, 1)
return nodes, relay.view(B, H)
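A minimal sketch of driving this encoder (shapes follow the docstring; the mask marks real tokens with 1 and padding with 0, since forward flips it before masked_fill_):

```python
import torch
from fastNLP.modules.encoder.star_transformer import StarTransformer

enc = StarTransformer(hidden_size=64, num_layers=2, num_head=4, head_dim=16, max_len=100)
data = torch.randn(4, 20, 64)    # [batch, length, hidden]
mask = torch.ones(4, 20).byte()  # 1 = real token, 0 = padding
nodes, relay = enc(data, mask)   # nodes: [4, 20, 64], relay: [4, 64]
```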




@@ -89,37 +96,37 @@ class _MSA1(nn.Module):
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
self.drop = nn.Dropout(dropout)
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
def forward(self, x, ax=None):
# x: B, H, L, 1, ax : B, H, X, L append features
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = x.shape
q, k, v = self.WQ(x), self.WK(x), self.WV(x)  # x: (B,H,L,1)
if ax is not None:
aL = ax.shape[2]
ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
av = self.WV(ax).view(B, nhead, head_dim, aL, L)
q = q.view(B, nhead, head_dim, 1, L)
- k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\
- .view(B, nhead, head_dim, unfold_size, L)
- v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\
- .view(B, nhead, head_dim, unfold_size, L)
+ k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
+ .view(B, nhead, head_dim, unfold_size, L)
+ v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
+ .view(B, nhead, head_dim, unfold_size, L)
if ax is not None:
k = torch.cat([k, ak], 3)
v = torch.cat([v, av], 3)
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3))  # B N L 1 U
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)
ret = self.WO(att)
return ret
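The F.unfold calls above implement the ring attention's local window: with a (3, 1) kernel and half padding, each position gathers its left neighbor, itself, and its right neighbor along the length axis. A toy check (illustrative only):

```python
import torch
import torch.nn.functional as F

B, C, L = 1, 2, 5
k = torch.arange(B * C * L, dtype=torch.float).view(B, C, L, 1)
win = F.unfold(k, (3, 1), padding=(1, 0))  # [B, C * 3, L]
win = win.view(B, C, 3, L)                 # 3 neighbors per position
print(win[0, 0, :, 2])                     # tensor([1., 2., 3.]): positions 1, 2, 3
```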




@@ -131,19 +138,19 @@ class _MSA2(nn.Module):
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
self.drop = nn.Dropout(dropout)
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
def forward(self, x, y, mask=None):
# x: B, H, 1, 1, 1 y: B H L 1
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = y.shape
q, k, v = self.WQ(x), self.WK(y), self.WV(y)
q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h
k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h


+ 10
- 6
fastNLP/modules/encoder/transformer.py View File

@@ -1,3 +1,6 @@
+ __all__ = [
+ "TransformerEncoder"
+ ]
from torch import nn

from ..aggregator.attention import MultiHeadAttention
@@ -19,6 +22,7 @@ class TransformerEncoder(nn.Module):
:param int num_head: the number of heads.
:param float dropout: dropout probability. Default: 0.1
"""
class SubLayer(nn.Module):
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
super(TransformerEncoder.SubLayer, self).__init__()
@@ -27,9 +31,9 @@ class TransformerEncoder(nn.Module):
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
nn.ReLU(),
nn.Linear(inner_size, model_size),
- TimestepDropout(dropout),)
+ TimestepDropout(dropout), )
self.norm2 = nn.LayerNorm(model_size)
def forward(self, input, seq_mask=None, atte_mask_out=None):
"""


@@ -44,11 +48,11 @@ class TransformerEncoder(nn.Module):
output = self.norm2(output + norm_atte)
output *= seq_mask
return output
def __init__(self, num_layers, **kargs):
super(TransformerEncoder, self).__init__()
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
def forward(self, x, seq_mask=None):
"""
:param x: [batch, seq_len, model_size] the input sequence
@@ -60,8 +64,8 @@ class TransformerEncoder(nn.Module):
if seq_mask is None:
atte_mask_out = None
else:
- atte_mask_out = (seq_mask < 1)[:,None,:]
- seq_mask = seq_mask[:,:,None]
+ atte_mask_out = (seq_mask < 1)[:, None, :]
+ seq_mask = seq_mask[:, :, None]
for layer in self.layers:
output = layer(output, seq_mask, atte_mask_out)
return output
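A minimal usage sketch (the keyword names follow SubLayer's __init__ above, since all kwargs besides num_layers are forwarded to it):

```python
import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder

enc = TransformerEncoder(num_layers=2, model_size=128, inner_size=256,
                         key_size=16, value_size=16, num_head=8, dropout=0.1)
x = torch.randn(4, 20, 128)          # [batch, seq_len, model_size]
seq_mask = torch.ones(4, 20).long()  # 1 = real token, 0 = padding
out = enc(x, seq_mask=seq_mask)      # [batch, seq_len, model_size]
```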

+ 71
- 43
fastNLP/modules/encoder/variational_rnn.py View File

@@ -1,9 +1,15 @@
"""Variational RNN 的 Pytorch 实现
""" """
Variational RNN 的 Pytorch 实现
"""
__all__ = [
"VarRNN",
"VarLSTM",
"VarGRU"
]

import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
from ..utils import initial_parameter


try:
from torch import flip
@@ -11,21 +17,25 @@ except ImportError:
def flip(x, dims):
indices = [slice(None)] * x.dim()
for dim in dims:
- indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
+ indices[dim] = torch.arange(
+ x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
return x[tuple(indices)]


- from ..utils import initial_parameter



class VarRnnCellWrapper(nn.Module):
- """Wrapper for normal RNN Cells, make it support variational dropout
+ """
+ Wrapper for normal RNN Cells, make it support variational dropout
"""
def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
self.hidden_size = hidden_size
self.input_p = input_p
self.hidden_p = hidden_p
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
"""
:param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -37,11 +47,13 @@ class VarRnnCellWrapper(nn.Module):
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""
def get_hi(hi, h0, size):
h0_size = size - hi.size(0)
if h0_size > 0:
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]
is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
@@ -52,7 +64,7 @@ class VarRnnCellWrapper(nn.Module):
else:
batch_iter = batch_sizes
idx = 0
if is_lstm:
hn = (hidden[0].clone(), hidden[1].clone())
else:
@@ -60,15 +72,16 @@ class VarRnnCellWrapper(nn.Module):
hi = hidden
for size in batch_iter:
if is_reversed:
- input_i = input[idx-size: idx] * mask_x[:size]
+ input_i = input[idx - size: idx] * mask_x[:size]
idx -= size
else:
- input_i = input[idx: idx+size] * mask_x[:size]
+ input_i = input[idx: idx + size] * mask_x[:size]
idx += size
mask_hi = mask_h[:size]
if is_lstm:
hx, cx = hi
- hi = (get_hi(hx, hidden[0], size) * mask_hi, get_hi(cx, hidden[1], size))
+ hi = (get_hi(hx, hidden[0], size) *
+ mask_hi, get_hi(cx, hidden[1], size))
hi = cell(input_i, hi)
hn[0][:size] = hi[0]
hn[1][:size] = hi[1]
@@ -78,7 +91,7 @@ class VarRnnCellWrapper(nn.Module):
hi = cell(input_i, hi)
hn[:size] = hi
output.append(hi)
if is_reversed:
output = list(reversed(output))
output = torch.cat(output, dim=0)
@@ -86,7 +99,9 @@ class VarRnnCellWrapper(nn.Module):




class VarRNNBase(nn.Module):
- """An implementation of Variational Dropout RNN.
+ """
+ An implementation of Variational Dropout RNN.

Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
https://arxiv.org/abs/1512.05287`.


@@ -102,7 +117,7 @@ class VarRNNBase(nn.Module):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
"""
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -122,18 +137,20 @@ class VarRNNBase(nn.Module):
for direction in range(self.num_directions):
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
cell = Cell(input_size, self.hidden_size, bias)
- self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
+ self._all_cells.append(VarRnnCellWrapper(
+ cell, self.hidden_size, input_dropout, hidden_dropout))
initial_parameter(self)
self.is_lstm = (self.mode == "LSTM")
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
is_lstm = self.is_lstm
idx = self.num_directions * n_layer + n_direction
cell = self._all_cells[idx]
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
- output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
+ output_x, hidden_x = cell(
+ input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
return output_x, hidden_x
def forward(self, x, hx=None):
"""


@@ -147,31 +164,38 @@ class VarRNNBase(nn.Module):
if not is_packed:
seq_len = x.size(1) if self.batch_first else x.size(0)
max_batch_size = x.size(0) if self.batch_first else x.size(1)
- seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
- input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
+ seq_lens = torch.LongTensor(
+ [seq_len for _ in range(max_batch_size)])
+ x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
else:
- max_batch_size = int(input.batch_sizes[0])
- input, batch_sizes = input.data, input.batch_sizes
+ max_batch_size = int(x.batch_sizes[0])
+ x, batch_sizes = x.data, x.batch_sizes
if hx is None:
hx = x.new_zeros(self.num_layers * self.num_directions,
max_batch_size, self.hidden_size, requires_grad=True)
if is_lstm:
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
mask_x = x.new_ones((max_batch_size, self.input_size))
- mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions))
+ mask_out = x.new_ones(
+ (max_batch_size, self.hidden_size * self.num_directions))
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size))
- nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
- nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
-
- hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
+ nn.functional.dropout(mask_x, p=self.input_dropout,
+ training=self.training, inplace=True)
+ nn.functional.dropout(mask_out, p=self.hidden_dropout,
+ training=self.training, inplace=True)
+ hidden = x.new_zeros(
+ (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
if is_lstm:
- cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
+ cellstate = x.new_zeros(
+ (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
for layer in range(self.num_layers):
output_list = []
input_seq = PackedSequence(x, batch_sizes)
- mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
+ mask_h = nn.functional.dropout(
+ mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
for direction in range(self.num_directions):
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx,
mask_x if layer == 0 else mask_out, mask_h)
@@ -183,18 +207,19 @@ class VarRNNBase(nn.Module):
else:
hidden[idx] = hidden_x
x = torch.cat(output_list, dim=-1)
if is_lstm:
hidden = (hidden, cellstate)
if is_packed:
output = PackedSequence(x, batch_sizes)
else:
x = PackedSequence(x, batch_sizes)
output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
return output, hidden



class VarLSTM(VarRNNBase):
"""
Alias: :class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.variational_rnn.VarLSTM`
@@ -211,10 +236,11 @@ class VarLSTM(VarRNNBase):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False``
"""
def __init__(self, *args, **kwargs):
- super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
+ super(VarLSTM, self).__init__(
+ mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarLSTM, self).forward(x, hx)


@@ -235,13 +261,15 @@ class VarRNN(VarRNNBase):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
"""
def __init__(self, *args, **kwargs):
- super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
+ super(VarRNN, self).__init__(
+ mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarRNN, self).forward(x, hx)



class VarGRU(VarRNNBase):
"""
Alias: :class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.variational_rnn.VarGRU`
@@ -258,10 +286,10 @@ class VarGRU(VarRNNBase):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False``
"""
def __init__(self, *args, **kwargs):
- super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
+ super(VarGRU, self).__init__(
+ mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarGRU, self).forward(x, hx)
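A minimal sketch of the variational-dropout LSTM defined above: one dropout mask per sequence is sampled and reused across all timesteps (the mask_x/mask_h logic in VarRNNBase.forward), per Gal & Ghahramani (2016).

```python
import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM

lstm = VarLSTM(input_size=300, hidden_size=100, num_layers=2, batch_first=True,
               input_dropout=0.3, hidden_dropout=0.3, bidirectional=True)
x = torch.randn(4, 20, 300)   # [batch, seq_len, input_size]
output, (h_n, c_n) = lstm(x)  # output: [4, 20, 200]
```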


+ 1
- 1
fastNLP/modules/utils.py View File

@@ -1,5 +1,5 @@
from functools import reduce
- from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn


+ 3
- 3
reproduction/Chinese_word_segmentation/models/cws_model.py View File

@@ -3,7 +3,7 @@ import torch
from torch import nn

from fastNLP.models.base_model import BaseModel
- from fastNLP.modules.decoder.MLP import MLP
+ from fastNLP.modules.decoder.mlp import MLP
from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask




@@ -120,8 +120,8 @@ class CWSBiLSTMSegApp(BaseModel):
return {'pred_tags': pred_tags}




- from fastNLP.modules.decoder.CRF import ConditionalRandomField
- from fastNLP.modules.decoder.CRF import allowed_transitions
+ from fastNLP.modules.decoder.crf import ConditionalRandomField
+ from fastNLP.modules.decoder.crf import allowed_transitions


class CWSBiLSTMCRF(BaseModel):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,


+ 2
- 2
reproduction/Chinese_word_segmentation/models/cws_transformer.py View File

@@ -10,8 +10,8 @@ from torch import nn
import torch
# from fastNLP.modules.encoder.transformer import TransformerEncoder
from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder
- from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask
- from fastNLP.modules.decoder.CRF import allowed_transitions
+ from fastNLP.modules.decoder.crf import ConditionalRandomField,seq_len_to_byte_mask
+ from fastNLP.modules.decoder.crf import allowed_transitions


class TransformerCWS(nn.Module):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,


+ 1
- 1
reproduction/LSTM+self_attention_sentiment_analysis/main.py View File

@@ -7,7 +7,7 @@ from fastNLP.io.config_io import ConfigSection
from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention
- from fastNLP.modules.decoder.MLP import MLP
+ from fastNLP.modules.decoder.mlp import MLP
from fastNLP.modules.encoder.embedding import Embedding as Embedding
from fastNLP.modules.encoder.lstm import LSTM




+ 5
- 5
test/modules/decoder/test_CRF.py View File

@@ -5,7 +5,7 @@ import unittest
class TestCRF(unittest.TestCase):
def test_case1(self):
# check that allowed_transitions() can be used correctly
- from fastNLP.modules.decoder.CRF import allowed_transitions
+ from fastNLP.modules.decoder.crf import allowed_transitions

id2label = {0: 'B', 1: 'I', 2:'O'}
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
@@ -43,7 +43,7 @@ class TestCRF(unittest.TestCase):
# test that the CRF avoids decoding illegal transitions; verified against allennlp.
pass
# import torch
- # from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask
+ # from fastNLP.modules.decoder.crf import seq_len_to_byte_mask
#
# labels = ['O']
# for label in ['X', 'Y']:
@@ -63,7 +63,7 @@ class TestCRF(unittest.TestCase):
# mask = seq_len_to_byte_mask(seq_lens)
# allen_res = allen_CRF.viterbi_tags(logits, mask)
#
- # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
+ # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True)
@@ -91,7 +91,7 @@ class TestCRF(unittest.TestCase):
# mask = seq_len_to_byte_mask(seq_lens)
# allen_res = allen_CRF.viterbi_tags(logits, mask)
#
- # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
+ # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# encoding_type='BMES'))
# fast_CRF.trans_m = trans_m
@@ -104,7 +104,7 @@ class TestCRF(unittest.TestCase):
def test_case3(self):
# test that the CRF loss never becomes negative
import torch
- from fastNLP.modules.decoder.CRF import ConditionalRandomField
+ from fastNLP.modules.decoder.crf import ConditionalRandomField
from fastNLP.core.utils import seq_len_to_mask
from torch import optim
from torch import nn
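For context, the renamed path in use, mirroring the commented test code above (a sketch; num_tags must match the label set that allowed_transitions sees):

```python
from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions

id2label = {0: 'B', 1: 'I', 2: 'O'}
crf = ConditionalRandomField(num_tags=3,
                             allowed_transitions=allowed_transitions(id2label))
```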

