@@ -1,7 +0,0 @@
-fastNLP.models.base\_model
-==========================
-
-.. automodule:: fastNLP.models.base_model
-    :members:
-    :undoc-members:
-    :show-inheritance:
@@ -1,7 +0,0 @@
-fastNLP.models.bert
-===================
-
-.. automodule:: fastNLP.models.bert
-    :members:
-    :undoc-members:
-    :show-inheritance:
@@ -1,7 +0,0 @@
-fastNLP.models.enas\_controller
-===============================
-
-.. automodule:: fastNLP.models.enas_controller
-    :members:
-    :undoc-members:
-    :show-inheritance:
@@ -1,7 +0,0 @@
-fastNLP.models.enas\_model
-==========================
-
-.. automodule:: fastNLP.models.enas_model
-    :members:
-    :undoc-members:
-    :show-inheritance:
@@ -1,7 +0,0 @@
-fastNLP.models.enas\_trainer
-============================
-
-.. automodule:: fastNLP.models.enas_trainer
-    :members:
-    :undoc-members:
-    :show-inheritance:
@@ -1,7 +0,0 @@
-fastNLP.models.enas\_utils
-==========================
-
-.. automodule:: fastNLP.models.enas_utils
-    :members:
-    :undoc-members:
-    :show-inheritance:
@@ -12,14 +12,8 @@ fastNLP.models
 .. toctree::
    :titlesonly:

-   fastNLP.models.base_model
-   fastNLP.models.bert
    fastNLP.models.biaffine_parser
    fastNLP.models.cnn_text_classification
-   fastNLP.models.enas_controller
-   fastNLP.models.enas_model
-   fastNLP.models.enas_trainer
-   fastNLP.models.enas_utils
    fastNLP.models.sequence_labeling
    fastNLP.models.snli
    fastNLP.models.star_transformer
@@ -1,7 +1,7 @@ | |||||
fastNLP.modules.decoder.CRF | fastNLP.modules.decoder.CRF | ||||
=========================== | =========================== | ||||
.. automodule:: fastNLP.modules.decoder.CRF | |||||
.. automodule:: fastNLP.modules.decoder.crf | |||||
:members: | :members: | ||||
:undoc-members: | :undoc-members: | ||||
:show-inheritance: | :show-inheritance: |
@@ -1,7 +1,7 @@
 fastNLP.modules.decoder.MLP
 ===========================

-.. automodule:: fastNLP.modules.decoder.MLP
+.. automodule:: fastNLP.modules.decoder.mlp
    :members:
    :undoc-members:
    :show-inheritance:
@@ -12,7 +12,7 @@ fastNLP.modules.decoder
 .. toctree::
    :titlesonly:

-   fastNLP.modules.decoder.CRF
-   fastNLP.modules.decoder.MLP
+   fastNLP.modules.decoder.crf
+   fastNLP.modules.decoder.mlp
    fastNLP.modules.decoder.utils
@@ -52,8 +52,8 @@ __all__ = [
     "cache_results"
 ]

-__version__ = '0.4.0'
-
 from .core import *
 from . import models
 from . import modules
+
+__version__ = '0.4.0'
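Moving `__version__` below the star-import does not change the public attribute; a quick sanity check (assuming fastNLP 0.4.0 is installed):

```python
# Minimal check that the version string survives the reordering above.
import fastNLP

print(fastNLP.__version__)  # expected: '0.4.0'
```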
@@ -2,14 +2,18 @@
 The batch module implements the Batch class required by fastNLP.

 """
-__all__ = ["Batch"]
-
-import numpy as np
-import torch
+__all__ = [
+    "Batch"
+]
+
 import atexit
-from queue import Empty, Full
-from .sampler import RandomSampler, Sampler

+import numpy as np
+import torch
 import torch.multiprocessing as mp

+from queue import Empty, Full
+
+from .sampler import RandomSampler

 _python_is_exit = False
@@ -120,7 +124,7 @@ class Batch(object):
         :return list(int) indexes: the list of indices
         """
         return self.cur_batch_indices
-
+
     @staticmethod
     def _run_fetch(batch, q):
         try:
@@ -145,7 +149,7 @@ class Batch(object):
                 q.put(e)
         finally:
             q.join()
-
+
     @staticmethod
     def _run_batch_iter(batch):
         q = mp.JoinableQueue(maxsize=10)
@@ -182,4 +186,3 @@ def _to_tensor(batch, dtype):
     except:
         pass
     return batch
-
@@ -60,16 +60,20 @@ __all__ = [
     "CallbackException",
     "EarlyStopError"
 ]

 import os

 import torch
-from ..io.model_io import ModelSaver, ModelLoader

 try:
     from tensorboardX import SummaryWriter
+
     tensorboardX_flag = True
 except:
     tensorboardX_flag = False

+from ..io.model_io import ModelSaver, ModelLoader
+

 class Callback(object):
     """
@@ -587,7 +591,7 @@ class TensorboardCallback(Callback):
             self._summary_writer = SummaryWriter(path)
         else:
             self._summary_writer = None
-
+
     def on_batch_begin(self, batch_x, batch_y, indices):
         if "model" in self.options and self.graph_added is False:
            # tensorboardX has a major bug here; drawing the model graph is not possible for now
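The `on_batch_begin(self, batch_x, batch_y, indices)` hook shown above is the extension point for user callbacks; an illustrative subclass (the logging body is ours, not fastNLP's):

```python
# Illustrative Callback subclass using the on_batch_begin hook signature shown above.
from fastNLP.core.callback import Callback

class BatchLogger(Callback):
    def on_batch_begin(self, batch_x, batch_y, indices):
        # indices are the DataSet indices of the examples in the current batch
        print("starting a batch of {} examples".format(len(indices)))
```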
@@ -272,11 +272,14 @@
 """
-__all__ = ["DataSet"]
+__all__ = [
+    "DataSet"
+]

 import _pickle as pickle
-import warnings

 import numpy as np
+import warnings

 from .field import AutoPadder
 from .field import FieldArray
@@ -863,4 +866,4 @@ class DataSet(object):
         with open(path, 'rb') as f:
             d = pickle.load(f)
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
-        return d
\ No newline at end of file
+        return d
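The hunk above only fixes a missing newline at end of file; for context, `DataSet.load` is the inverse of `DataSet.save`. A round-trip sketch (the path is hypothetical):

```python
# Round-trip sketch for DataSet.save / DataSet.load as used in the hunk above.
from fastNLP import DataSet, Instance

ds = DataSet()
ds.append(Instance(words=[1, 2, 3], label=0))
ds.save("/tmp/demo.dataset")                   # pickles the DataSet to disk
restored = DataSet.load("/tmp/demo.dataset")   # asserts the loaded object is a DataSet
assert len(restored) == len(ds)
```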
@@ -3,10 +3,16 @@ The field module implements FieldArray and several Padders. FieldArray is the :class:`~fas
 For the underlying design, see :doc:`fastNLP.core.dataset`
 """
+__all__ = [
+    "FieldArray",
+    "Padder",
+    "AutoPadder",
+    "EngChar2DPadder"
+]
+
+from copy import deepcopy

 import numpy as np
-from copy import deepcopy


 class FieldArray(object):
@@ -24,6 +30,7 @@ class FieldArray(object):
     :param bool ignore_type: whether to ignore this field's type. If the field does not need to be converted to
         torch.FloatTensor or torch.LongTensor, this can be set to True. See :class:`~fastNLP.DataSet` for details.
     """
+
     def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
         self.name = name
         if isinstance(content, list):
@@ -41,7 +48,7 @@ class FieldArray(object):
             raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content)))
         if len(content) == 0:
             raise RuntimeError("Cannot initialize FieldArray with empty list.")
-
+
         self.content = content  # a 1-d, 2-d, or 3-d list; the shape may be ragged
         self.content_dim = None  # the number of dimensions of the content list
         if padder is None:
@@ -51,27 +58,27 @@ class FieldArray(object):
             padder = deepcopy(padder)
         self.set_padder(padder)
         self.ignore_type = ignore_type
-
+
         self.BASIC_TYPES = (int, float, str)  # the basic Python types accepted in content (np.array is not listed here)
         self.pytype = None
         self.dtype = None
         self._is_input = None
         self._is_target = None
-
+
         if is_input is not None or is_target is not None:
             self.is_input = is_input
             self.is_target = is_target
-
+
     def _set_dtype(self):
         if self.ignore_type is False:
             self.pytype = self._type_detection(self.content)
             self.dtype = self._map_to_np_type(self.pytype)
-
+
     @property
     def is_input(self):
         return self._is_input
-
+
     @is_input.setter
     def is_input(self, value):
         """
@@ -80,11 +87,11 @@ class FieldArray(object):
         if value is True:
             self._set_dtype()
         self._is_input = value
-
+
     @property
     def is_target(self):
         return self._is_target
-
+
     @is_target.setter
     def is_target(self, value):
         """
@@ -93,7 +100,7 @@ class FieldArray(object):
         if value is True:
             self._set_dtype()
         self._is_target = value
-
+
     def _type_detection(self, content):
         """
         Called when this field is set as is_input or is_target
@@ -101,9 +108,9 @@ class FieldArray(object):
         """
         if len(content) == 0:
             raise RuntimeError("Empty list in Field {}.".format(self.name))
-
+
         type_set = set([type(item) for item in content])
-
+
         if list in type_set:
             if len(type_set) > 1:
                 # list and non-list items are mixed together
@@ -139,7 +146,7 @@ class FieldArray(object):
                     self.name, self.BASIC_TYPES, content_type))
             self.content_dim = 1
             return self._basic_type_detection(type_set)
-
+
     def _basic_type_detection(self, type_set):
         """
         :param type_set: a set of Python types
@@ -158,7 +165,7 @@ class FieldArray(object):
         else:
             # str, int, and float are mixed together
             raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
-
+
     def _1d_list_check(self, val):
         """Raise an error if val is not a 1-D list
         """
@@ -168,7 +175,7 @@ class FieldArray(object):
             self._basic_type_detection(type_set)
         # otherwise: _basic_type_detection will raise error
         return True
-
+
     def _2d_list_check(self, val):
         """Raise an error if val is not a 2-D list
         """
@@ -181,15 +188,15 @@ class FieldArray(object):
                 inner_type_set.add(type(obj))
         self._basic_type_detection(inner_type_set)
         return True
-
+
     @staticmethod
     def _map_to_np_type(basic_type):
         type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray}
         return type_mapping[basic_type]
-
+
     def __repr__(self):
         return "FieldArray {}: {}".format(self.name, self.content.__repr__())
-
+
     def append(self, val):
         """Append val to the end of this field. If the field has already been set as input or target, the type of val
         is checked against the existing content before appending.
@@ -208,7 +215,7 @@ class FieldArray(object):
             else:
                 raise RuntimeError(
                     "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
-
+
         if self.is_input is True or self.is_target is True:
             if type(val) == list:
                 if len(val) == 0:
@@ -231,14 +238,14 @@ class FieldArray(object):
                 raise RuntimeError(
                     "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
         self.content.append(val)
-
+
     def __getitem__(self, indices):
         return self.get(indices, pad=False)
-
+
     def __setitem__(self, idx, val):
         assert isinstance(idx, int)
         self.content[idx] = val
-
+
     def get(self, indices, pad=True):
         """
         Return the content at the given indices
@@ -251,13 +258,13 @@ class FieldArray(object):
             return self.content[indices]
         if self.is_input is False and self.is_target is False:
             raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name))
-
+
         contents = [self.content[i] for i in indices]
         if self.padder is None or pad is False:
             return np.array(contents)
         else:
             return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype)
-
+
     def set_padder(self, padder):
         """
         Set the padder; it is used whenever this field is padded. If None, no padding is performed.
@@ -269,7 +276,7 @@ class FieldArray(object):
             self.padder = deepcopy(padder)
         else:
             self.padder = None
-
+
     def set_pad_val(self, pad_val):
         """
         Change the padder's pad_val.
@@ -279,8 +286,7 @@ class FieldArray(object):
         if self.padder is not None:
             self.padder.set_pad_val(pad_val)
         return self
-
-
+
     def __len__(self):
         """
         Returns the size of FieldArray.
@@ -288,7 +294,7 @@ class FieldArray(object):
         :return int length:
         """
         return len(self.content)
-
+
     def to(self, other):
         """
        Copy other's attributes to this FieldArray (other must be a FieldArray).
@@ -298,14 +304,15 @@ class FieldArray(object):
         :return: :class:`~fastNLP.FieldArray`
         """
         assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))
-
+
         self.is_input = other.is_input
         self.is_target = other.is_target
         self.padder = other.padder
         self.ignore_type = other.ignore_type
-
+
         return self
+

 def _is_iterable(content):
     try:
         _ = (e for e in content)
@@ -331,13 +338,13 @@ class Padder:
         :return: np.array([padded_element])
     """
-
+
     def __init__(self, pad_val=0, **kwargs):
         self.pad_val = pad_val
-
+
     def set_pad_val(self, pad_val):
         self.pad_val = pad_val
-
+
     def __call__(self, contents, field_name, field_ele_dtype):
         """
         The input is the list contents of one batch. Suppose we have the following DataSet.
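Per the docstring above, a padder receives one batch's raw list contents plus the field name and element dtype, and returns a padded `np.array`. A minimal custom padder sketch following that contract (`FixedLenPadder` is our illustration, not part of fastNLP):

```python
# Minimal custom Padder following the __call__ contract documented above.
import numpy as np
from fastNLP.core.field import Padder

class FixedLenPadder(Padder):
    """Pads or truncates every 1-d list to a fixed length (illustrative only)."""

    def __init__(self, pad_val=0, length=5):
        super().__init__(pad_val=pad_val)
        self.length = length

    def __call__(self, contents, field_name, field_ele_dtype):
        out = np.full((len(contents), self.length), self.pad_val, dtype=field_ele_dtype)
        for i, row in enumerate(contents):
            row = row[:self.length]          # truncate long rows
            out[i, :len(row)] = row          # left-align short rows, rest stays pad_val
        return out
```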
@@ -396,13 +403,13 @@ class AutoPadder(Padder):
     2.2 If the field content is a List, the Lists within a batch are padded to the same length. If the inner Lists also
         need padding, use another padder; i.e. a field shaped like [1, 2, 3, ...] can be padded, while [[1,2], [3,4, ...]] cannot.
     """
-
+
     def __init__(self, pad_val=0):
         """
         :param pad_val: int, the index used at padded positions
         """
         super().__init__(pad_val=pad_val)
-
+
     def _is_two_dimension(self, contents):
         """
         Check whether contents has exactly two dimensions. [[1,2], [3]] has two dimensions; [[[1,2], [3, 4, 5]], [[4,5]]] has three.
@@ -416,7 +423,7 @@ class AutoPadder(Padder):
                     return False
             return True
         return False
-
+
     def __call__(self, contents, field_name, field_ele_dtype):
         if not _is_iterable(contents[0]):
@@ -458,6 +465,7 @@ class EngChar2DPadder(Padder):
         dataset.set_padder('chars', padder)  # the 'chars' field is now padded with EngChar2DPadder
     """
+
     def __init__(self, pad_val=0, pad_length=0):
         """
         :param pad_val: int, the index used at padded positions
@@ -465,9 +473,9 @@ class EngChar2DPadder(Padder):
             are all padded or truncated to this length.
         """
         super().__init__(pad_val=pad_val)
-
+
         self.pad_length = pad_length
-
+
     def _exactly_three_dims(self, contents, field_name):
         """
         Check that the given contents has exactly three dimensions, raising an error otherwise. In principle the first
         dimension is batch, the second is word, and the third is character.
@@ -486,10 +494,10 @@ class EngChar2DPadder(Padder):
             value = value[0]
         except:
             raise ValueError("Field:{} only has two dimensions.".format(field_name))
-
+
         if _is_iterable(value):
             raise ValueError("Field:{} has more than 3 dimension.".format(field_name))
-
+
     def __call__(self, contents, field_name, field_ele_dtype):
         """
         The expected input is similar to
@@ -516,12 +524,12 @@ class EngChar2DPadder(Padder):
         max_sent_length = max(len(word_lst) for word_lst in contents)
         batch_size = len(contents)
         dtype = type(contents[0][0][0])
-
+
         padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
-                                dtype=dtype)
+                               dtype=dtype)
         for b_idx, word_lst in enumerate(contents):
             for c_idx, char_lst in enumerate(word_lst):
                 chars = char_lst[:max_char_length]
                 padded_array[b_idx, c_idx, :len(chars)] = chars
-
+
         return padded_array
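Putting the three-dimension check and the fill loop above together, a usage sketch (shapes per the docstring; `batch` is made-up data):

```python
# EngChar2DPadder pads a batch of [word][char] id lists into a
# (batch, max_sent_len, max_char_len) array, as implemented above.
import numpy as np
from fastNLP.core.field import EngChar2DPadder

padder = EngChar2DPadder(pad_val=0)
# two sentences; each word is a list of character ids
batch = [[[1, 2], [3]], [[4, 5, 6]]]
padded = padder(batch, field_name='chars', field_ele_dtype=np.int64)
print(padded.shape)  # (2, 2, 3): 2 sentences, max 2 words, max 3 chars
```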
@@ -3,7 +3,9 @@ The instance module implements the Instance class, which corresponds to a sample in fastNLP. A sample can
 For an easy-to-follow example, see the table in :doc:`fastNLP.core.dataset`
 """
-__all__ = ["Instance"]
+__all__ = [
+    "Instance"
+]


 class Instance(object):
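For context, an Instance is a single sample, i.e. a set of named fields; construction is keyword-based:

```python
# An Instance is one sample: a collection of named fields (per the module docstring).
from fastNLP import DataSet, Instance

ins = Instance(sentence="this is a demo", label=1)
ds = DataSet()
ds.append(ins)  # the DataSet grows a field per keyword
```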
@@ -2,7 +2,18 @@
 The losses module defines the loss functions fastNLP needs; they are generally used as a parameter of :class:`~fastNLP.Trainer`.
 """
-__all__ = ["LossBase", "L1Loss", "LossFunc", "LossInForward", "BCELoss", "CrossEntropyLoss", "NLLLoss"]
+__all__ = [
+    "LossBase",
+    "LossFunc",
+    "LossInForward",
+    "CrossEntropyLoss",
+    "BCELoss",
+    "L1Loss",
+    "NLLLoss"
+]

 import inspect
 from collections import defaultdict
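These losses are typically handed to the Trainer with an explicit mapping from the model's forward-output keys and the target fields to the loss's own argument names. A hedged sketch (`output` and `label` are hypothetical field names):

```python
# Hedged usage sketch: the loss maps the model's output dict key and the target
# field onto its own `pred`/`target` arguments.
from fastNLP import CrossEntropyLoss

loss = CrossEntropyLoss(pred='output', target='label')
# later: Trainer(train_data=..., model=..., loss=loss, ...)
```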
@@ -2,6 +2,13 @@
 The metrics module implements the common evaluation metrics fastNLP needs; they are generally used as a parameter of :class:`~fastNLP.Trainer`.
 """
+__all__ = [
+    "MetricBase",
+    "AccuracyMetric",
+    "SpanFPreRecMetric",
+    "SQuADMetric"
+]

 import inspect
 from collections import defaultdict
@@ -106,16 +113,17 @@ class MetricBase(object):
     self.get_metric aggregates the current statistics and returns the evaluation result; the return value must be a
     dict whose keys are metric names and whose values are the metric values
     """
+
     def __init__(self):
         self.param_map = {}  # key is param in function, value is input param.
         self._checked = False
-
+
     def evaluate(self, *args, **kwargs):
         raise NotImplementedError
-
+
     def get_metric(self, reset=True):
         raise NotImplementedError
-
+
     def _init_param_map(self, key_map=None, **kwargs):
         """Check key_map and the other parameter mappings, and add them to self.param_map
@@ -148,7 +156,7 @@ class MetricBase(object):
         for value, key_set in value_counter.items():
             if len(key_set) > 1:
                 raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
-
+
         # check consistency between signature and param_map
         func_spect = inspect.getfullargspec(self.evaluate)
         func_args = [arg for arg in func_spect.args if arg != 'self']
@@ -157,7 +165,7 @@ class MetricBase(object):
                 raise NameError(
                     f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
                     f"initialization parameters, or change its signature.")
-
+
     def _fast_param_map(self, pred_dict, target_dict):
         """Only used as an inner function, for when pred_dict and target_dict are unambiguous
             (e.g. each has a single element), so users need not pass a key_map.
@@ -172,7 +180,7 @@ class MetricBase(object):
             fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param
-
+
     def __call__(self, pred_dict, target_dict):
         """
         This method calls self.evaluate.
@@ -187,12 +195,12 @@ class MetricBase(object):
         :param target_dict: the dict of key-value pairs from DataSet.batch_y (i.e. the content of fields with is_target=True)
         :return:
         """
-
+
         fast_param = self._fast_param_map(pred_dict, target_dict)
         if fast_param:
             self.evaluate(**fast_param)
             return
-
+
         if not self._checked:
             if not callable(self.evaluate):
                 raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
@@ -202,14 +210,14 @@ class MetricBase(object):
             for func_arg, input_arg in self.param_map.items():
                 if func_arg not in func_args:
                     raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")
-
+
             # 2. only part of the param_map is passed; the rest is not
             for arg in func_args:
                 if arg not in self.param_map:
                     self.param_map[arg] = arg  # This param does not need mapping.
             self._evaluate_args = func_args
             self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
-
+
         # need to wrap inputs in dict.
         mapped_pred_dict = {}
         mapped_target_dict = {}
@@ -229,7 +237,7 @@ class MetricBase(object):
                     not_duplicate_flag += 1
                 if not_duplicate_flag == 3:
                     duplicated.append(input_arg)
-
+
         # missing
         if not self._checked:
             check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
@@ -240,23 +248,23 @@ class MetricBase(object):
             for idx, func_arg in enumerate(missing):
                 # Don't delete `` in this information, nor add ``
                 replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
-                    f"in `{self.__class__.__name__}`)"
+                                        f"in `{self.__class__.__name__}`)"
-
+
             check_res = _CheckRes(missing=replaced_missing,
                                   unused=check_res.unused,
                                   duplicated=duplicated,
                                   required=check_res.required,
                                   all_needed=check_res.all_needed,
                                   varargs=check_res.varargs)
-
+
             if check_res.missing or check_res.duplicated:
                 raise _CheckError(check_res=check_res,
                                   func_signature=_get_func_signature(self.evaluate))
-
+
         refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
         self.evaluate(**refined_args)
         self._checked = True
-
+
         return
@@ -271,15 +279,16 @@ class AccuracyMetric(MetricBase):
     :param target: the mapping for `target` in the parameter map; None means the mapping `target` -> `target`
     :param seq_len: the mapping for `seq_len` in the parameter map; None means the mapping `seq_len` -> `seq_len`
     """
+
     def __init__(self, pred=None, target=None, seq_len=None):
         super().__init__()
-
+
         self._init_param_map(pred=pred, target=target, seq_len=seq_len)
-
+
         self.total = 0
         self.acc_count = 0
-
+
     def evaluate(self, pred, target, seq_len=None):
         """
         The evaluate function accumulates the metric statistics over one batch of predictions
@@ -299,16 +308,16 @@ class AccuracyMetric(MetricBase):
         if not isinstance(target, torch.Tensor):
             raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")
-
+
         if seq_len is not None and not isinstance(seq_len, torch.Tensor):
             raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_len)}.")
-
+
         if seq_len is not None:
             masks = seq_len_to_mask(seq_len=seq_len)
         else:
             masks = None
-
+
         if pred.size() == target.size():
             pass
         elif len(pred.size()) == len(target.size()) + 1:
@@ -317,7 +326,7 @@ class AccuracyMetric(MetricBase):
             raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")
-
+
         target = target.to(pred)
         if masks is not None:
             self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
@@ -325,7 +334,7 @@ class AccuracyMetric(MetricBase):
         else:
             self.acc_count += torch.sum(torch.eq(pred, target)).item()
             self.total += np.prod(list(pred.size()))
-
+
     def get_metric(self, reset=True):
         """
         The get_metric function computes the final evaluation result from the statistics accumulated by evaluate.
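A direct-call sketch exercising the evaluate/get_metric pair diffed above (values chosen so the expected accuracy is easy to verify by hand):

```python
# Direct-call sketch for AccuracyMetric: __call__ maps the two dicts onto evaluate.
import torch
from fastNLP import AccuracyMetric

metric = AccuracyMetric()
pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # (batch=2, n_classes=2) -> argmax over last dim
target = torch.tensor([1, 1])
metric({'pred': pred}, {'target': target})     # argmax gives [1, 0]; one of two is correct
print(metric.get_metric())                     # expected: {'acc': 0.5}
```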
@@ -350,7 +359,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
     :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
     """
     ignore_labels = set(ignore_labels) if ignore_labels else set()
-
+
     spans = []
     prev_bmes_tag = None
     for idx, tag in enumerate(tags):
@@ -358,14 +367,14 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
         bmes_tag, label = tag[:1], tag[2:]
         if bmes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
+        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
             spans[-1][1][1] = idx
         else:
             spans.append((label, [idx, idx]))
         prev_bmes_tag = bmes_tag
-    return [(span[0], (span[1][0], span[1][1]+1))
-            for span in spans
-            if span[0] not in ignore_labels
+    return [(span[0], (span[1][0], span[1][1] + 1))
+            for span in spans
+            if span[0] not in ignore_labels
             ]
@@ -379,7 +388,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
     :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
     """
     ignore_labels = set(ignore_labels) if ignore_labels else set()
-
+
     spans = []
     prev_bmes_tag = None
     for idx, tag in enumerate(tags):
@@ -387,16 +396,16 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
         bmes_tag, label = tag[:1], tag[2:]
         if bmes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
+        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
             spans[-1][1][1] = idx
         elif bmes_tag == 'o':
             pass
         else:
             spans.append((label, [idx, idx]))
         prev_bmes_tag = bmes_tag
-    return [(span[0], (span[1][0], span[1][1]+1))
-            for span in spans
-            if span[0] not in ignore_labels
+    return [(span[0], (span[1][0], span[1][1] + 1))
+            for span in spans
+            if span[0] not in ignore_labels
             ]
@@ -410,7 +419,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
     :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
     """
     ignore_labels = set(ignore_labels) if ignore_labels else set()
-
+
     spans = []
     prev_bio_tag = None
     for idx, tag in enumerate(tags):
@@ -418,14 +427,14 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
         bio_tag, label = tag[:1], tag[2:]
         if bio_tag == 'b':
             spans.append((label, [idx, idx]))
-        elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]:
+        elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
             spans[-1][1][1] = idx
-        elif bio_tag == 'o': # o tag does not count
+        elif bio_tag == 'o':  # o tag does not count
             pass
         else:
             spans.append((label, [idx, idx]))
         prev_bio_tag = bio_tag
-    return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels]
+    return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]


 class SpanFPreRecMetric(MetricBase):
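A worked example for the span decoding above, assuming lower-cased BIO tags as the helper expects:

```python
# Worked example for the BIO span decoding reformatted above.
from fastNLP.core.metrics import _bio_tag_to_spans

spans = _bio_tag_to_spans(['b-per', 'i-per', 'o', 'b-loc'])
# 'b' opens ('per', [0, 0]); 'i' with the same label extends it to [0, 1];
# 'o' is skipped; 'b' opens ('loc', [3, 3]).
# Span ends are right-open (end + 1):
assert spans == [('per', (0, 2)), ('loc', (3, 4))]
```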
@@ -470,16 +479,17 @@ class SpanFPreRecMetric(MetricBase):
     :param float beta: the f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common values are
         beta = 0.5, 1, 2. With 0.5, precision is weighted higher than recall; with 1 they are weighted equally;
         with 2, recall is weighted higher than precision.
     """
+
     def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
-                  only_gross=True, f_type='micro', beta=1):
+                 only_gross=True, f_type='micro', beta=1):
         encoding_type = encoding_type.lower()
+
         if not isinstance(tag_vocab, Vocabulary):
             raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
         if f_type not in ('micro', 'macro'):
             raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
-
+
         self.encoding_type = encoding_type
         if self.encoding_type == 'bmes':
             self.tag_to_span_func = _bmes_tag_to_spans
@@ -489,22 +499,22 @@ class SpanFPreRecMetric(MetricBase):
             self.tag_to_span_func = _bmeso_tag_to_spans
         else:
             raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")
-
+
         self.ignore_labels = ignore_labels
         self.f_type = f_type
         self.beta = beta
-        self.beta_square = self.beta**2
+        self.beta_square = self.beta ** 2
         self.only_gross = only_gross
-
+
         super().__init__()
         self._init_param_map(pred=pred, target=target, seq_len=seq_len)
-
+
         self.tag_vocab = tag_vocab
-
+
         self._true_positives = defaultdict(int)
         self._false_positives = defaultdict(int)
         self._false_negatives = defaultdict(int)
-
+
     def evaluate(self, pred, target, seq_len):
         """The evaluate function accumulates the metric statistics over one batch of predictions
@@ -519,11 +529,11 @@ class SpanFPreRecMetric(MetricBase):
         if not isinstance(target, torch.Tensor):
             raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")
-
+
         if not isinstance(seq_len, torch.Tensor):
             raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_len)}.")
-
+
         if pred.size() == target.size() and len(target.size()) == 2:
             pass
         elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
@@ -536,20 +546,20 @@ class SpanFPreRecMetric(MetricBase):
             raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")
-
+
         batch_size = pred.size(0)
         pred = pred.tolist()
         target = target.tolist()
         for i in range(batch_size):
             pred_tags = pred[i][:int(seq_len[i])]
             gold_tags = target[i][:int(seq_len[i])]
-
+
             pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
             gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
-
+
             pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
             gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)
-
+
             for span in pred_spans:
                 if span in gold_spans:
                     self._true_positives[span[0]] += 1
@@ -558,7 +568,7 @@ class SpanFPreRecMetric(MetricBase):
                     self._false_positives[span[0]] += 1
             for span in gold_spans:
                 self._false_negatives[span[0]] += 1
-
+
     def get_metric(self, reset=True):
         """The get_metric function computes the final evaluation result from the statistics accumulated by evaluate."""
         evaluate_result = {}
@@ -577,19 +587,19 @@ class SpanFPreRecMetric(MetricBase):
                 f_sum += f
                 pre_sum += pre
                 rec_sum += rec
-                if not self.only_gross and tag!='':  # tag != '' guards against the no-tag case
+                if not self.only_gross and tag != '':  # tag != '' guards against the no-tag case
                     f_key = 'f-{}'.format(tag)
                     pre_key = 'pre-{}'.format(tag)
                     rec_key = 'rec-{}'.format(tag)
                     evaluate_result[f_key] = f
                     evaluate_result[pre_key] = pre
                     evaluate_result[rec_key] = rec
-
+
             if self.f_type == 'macro':
-                evaluate_result['f'] = f_sum/len(tags)
-                evaluate_result['pre'] = pre_sum/len(tags)
-                evaluate_result['rec'] = rec_sum/len(tags)
+                evaluate_result['f'] = f_sum / len(tags)
+                evaluate_result['pre'] = pre_sum / len(tags)
+                evaluate_result['rec'] = rec_sum / len(tags)
+
         if self.f_type == 'micro':
             f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
                                                   sum(self._false_negatives.values()),
@@ -597,17 +607,17 @@ class SpanFPreRecMetric(MetricBase):
             evaluate_result['f'] = f
             evaluate_result['pre'] = pre
             evaluate_result['rec'] = rec
-
+
         if reset:
             self._true_positives = defaultdict(int)
             self._false_positives = defaultdict(int)
             self._false_negatives = defaultdict(int)
-
+
         for key, value in evaluate_result.items():
             evaluate_result[key] = round(value, 6)
-
+
         return evaluate_result
-
+
     def _compute_f_pre_rec(self, tp, fn, fp):
         """
@@ -619,11 +629,10 @@ class SpanFPreRecMetric(MetricBase):
         pre = tp / (fp + tp + 1e-13)
         rec = tp / (fn + tp + 1e-13)
         f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)
-
+
         return f, pre, rec
-

 def _prepare_metrics(metrics):
     """
@@ -705,33 +714,33 @@ class SQuADMetric(MetricBase):
     :param bool print_predict_stat: if True, print statistics on whether the predicted and gold answers are empty;
         if False, print nothing
     """
+
     def __init__(self, pred1=None, pred2=None, target1=None, target2=None,
                  beta=1, right_open=True, print_predict_stat=False):
-
+
         super(SQuADMetric, self).__init__()
-
+
         self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2)
-
+
         self.print_predict_stat = print_predict_stat
-
+
         self.no_ans_correct = 0
         self.no_ans_wrong = 0
-
+
         self.has_ans_correct = 0
         self.has_ans_wrong = 0
-
+
         self.has_ans_f = 0.
-
+
         self.no2no = 0
         self.no2yes = 0
         self.yes2no = 0
         self.yes2yes = 0
-
+
         self.f_beta = beta
         self.right_open = right_open
-
+
     def evaluate(self, pred1, pred2, target1, target2):
         """The evaluate function accumulates the metric statistics over one batch of predictions
@@ -745,7 +754,7 @@ class SQuADMetric(MetricBase):
         pred_end = pred2
         target_start = target1
         target_end = target2
-
+
         if len(pred_start.size()) == 2:
             start_inference = pred_start.max(dim=-1)[1].cpu().tolist()
         else:
@@ -754,12 +763,12 @@ class SQuADMetric(MetricBase):
             end_inference = pred_end.max(dim=-1)[1].cpu().tolist()
         else:
             end_inference = pred_end.cpu().tolist()
-
+
         start, end = [], []
         max_len = pred_start.size(1)
         t_start = target_start.cpu().tolist()
         t_end = target_end.cpu().tolist()
-
+
         for s, e in zip(start_inference, end_inference):
             start.append(min(s, e))
             end.append(max(s, e))
@@ -779,7 +788,7 @@ class SQuADMetric(MetricBase):
                     self.yes2no += 1
                 else:
                     self.yes2yes += 1
-
+
                 if s == ts and e == te:
                     self.has_ans_correct += 1
                 else:
@@ -787,29 +796,29 @@ class SQuADMetric(MetricBase):
                 a = [0] * s + [1] * (e - s) + [0] * (max_len - e)
                 b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te)
                 a, b = torch.tensor(a), torch.tensor(b)
-
+
                 TP = int(torch.sum(a * b))
                 pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0
                 rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0
-
+
                 if pre + rec > 0:
-                    f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec)
+                    f = (1 + (self.f_beta ** 2)) * pre * rec / ((self.f_beta ** 2) * pre + rec)
                 else:
                     f = 0
-
+
                 self.has_ans_f += f
-
+
     def get_metric(self, reset=True):
         """The get_metric function computes the final evaluation result from the statistics accumulated by evaluate."""
         evaluate_result = {}
-
+
         if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.has_ans_wrong <= 0:
             return evaluate_result
-
+
         evaluate_result['EM'] = 0
         evaluate_result[f'f_{self.f_beta}'] = 0
-
+
         flag = 0
-
+
         if self.no_ans_correct + self.no_ans_wrong > 0:
             evaluate_result[f'noAns-f_{self.f_beta}'] = \
                 round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3)
@@ -818,7 +827,7 @@ class SQuADMetric(MetricBase):
             evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}']
             evaluate_result['EM'] += evaluate_result['noAns-EM']
             flag += 1
-
+
         if self.has_ans_correct + self.has_ans_wrong > 0:
             evaluate_result[f'hasAns-f_{self.f_beta}'] = \
                 round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3)
@@ -827,32 +836,31 @@ class SQuADMetric(MetricBase):
             evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}']
             evaluate_result['EM'] += evaluate_result['hasAns-EM']
             flag += 1
-
+
         if self.print_predict_stat:
             evaluate_result['no2no'] = self.no2no
             evaluate_result['no2yes'] = self.no2yes
             evaluate_result['yes2no'] = self.yes2no
             evaluate_result['yes2yes'] = self.yes2yes
-
+
         if flag <= 0:
             return evaluate_result
-
+
         evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3)
         evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3)
-
+
         if reset:
             self.no_ans_correct = 0
             self.no_ans_wrong = 0
-
+
             self.has_ans_correct = 0
             self.has_ans_wrong = 0
-
+
             self.has_ans_f = 0.
-
+
             self.no2no = 0
             self.no2yes = 0
             self.yes2no = 0
             self.yes2yes = 0
-
-
+
         return evaluate_result
@@ -2,6 +2,12 @@
 The optimizer module defines the optimizers fastNLP needs; they are generally used as a parameter of :class:`~fastNLP.Trainer`.
 """
+__all__ = [
+    "Optimizer",
+    "SGD",
+    "Adam"
+]

 import torch
@@ -12,15 +18,16 @@ class Optimizer(object):
     :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
     :param kwargs: additional parameters.
     """
+
     def __init__(self, model_params, **kwargs):
         if model_params is not None and not hasattr(model_params, "__next__"):
             raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
         self.model_params = model_params
         self.settings = kwargs
-
+
     def construct_from_pytorch(self, model_params):
         raise NotImplementedError
-
+
     def _get_require_grads_param(self, params):
         """
         Remove the parameters in params that do not require gradients
@@ -29,6 +36,7 @@ class Optimizer(object): | |||||
""" | """ | ||||
return [param for param in params if param.requires_grad] | return [param for param in params if param.requires_grad] | ||||
class SGD(Optimizer): | class SGD(Optimizer): | ||||
""" | """ | ||||
别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` | 别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` | ||||
@@ -37,12 +45,12 @@ class SGD(Optimizer): | |||||
:param float momentum: momentum. Default: 0 | :param float momentum: momentum. Default: 0 | ||||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | ||||
""" | """ | ||||
def __init__(self, lr=0.001, momentum=0, model_params=None): | def __init__(self, lr=0.001, momentum=0, model_params=None): | ||||
if not isinstance(lr, float): | if not isinstance(lr, float): | ||||
raise TypeError("learning rate has to be float.") | raise TypeError("learning rate has to be float.") | ||||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | ||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
if self.model_params is None: | if self.model_params is None: | ||||
# careful! generator cannot be assigned. | # careful! generator cannot be assigned. | ||||
@@ -59,13 +67,13 @@ class Adam(Optimizer): | |||||
:param float weight_decay: | :param float weight_decay: | ||||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | ||||
""" | """ | ||||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | ||||
if not isinstance(lr, float): | if not isinstance(lr, float): | ||||
raise TypeError("learning rate has to be float.") | raise TypeError("learning rate has to be float.") | ||||
super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, | super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, | ||||
weight_decay=weight_decay) | weight_decay=weight_decay) | ||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
if self.model_params is None: | if self.model_params is None: | ||||
# careful! generator cannot be assigned. | # careful! generator cannot be assigned. | ||||
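These wrappers store the hyper-parameters and defer building the underlying `torch.optim` object until `construct_from_pytorch` receives the model parameters. A minimal sketch, assuming `construct_from_pytorch` returns the built torch optimizer (which is how the Trainer consumes it):

```python
import torch.nn as nn
from fastNLP import Adam

model = nn.Linear(10, 2)
# The wrapper only records lr/weight_decay; no torch optimizer exists yet.
adam = Adam(lr=1e-3, weight_decay=1e-5)
# Binding the parameters builds the real torch.optim.Adam.
torch_optimizer = adam.construct_from_pytorch(model.parameters())
print(type(torch_optimizer))  # expected: <class 'torch.optim.adam.Adam'>
```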
@@ -1,3 +1,7 @@
+"""
+.. todo::
+    检查这个类是否需要
+"""
 from collections import defaultdict
 import torch
@@ -9,7 +13,8 @@ from .utils import _build_args
 class Predictor(object):
-    """An interface for predicting outputs based on trained models.
+    """
+    An interface for predicting outputs based on trained models.
     It does not care about evaluations of the model, which is different from Tester.
     This is a high-level model wrapper to be called by FastNLP.
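A rough usage sketch for this wrapper. The constructor and `predict` signatures are not visible in this hunk, so the call shapes below are assumptions to check against the class body:

```python
from fastNLP.core.predictor import Predictor

# `model` (a trained fastNLP model) and `test_data` (a fastNLP DataSet)
# are placeholders for objects built elsewhere.
predictor = Predictor(model)                   # assumed: Predictor(network)
batch_outputs = predictor.predict(test_data)   # assumed: predict(DataSet)
```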
@@ -1,9 +1,13 @@
 """
 sampler 子类实现了 fastNLP 所需的各种采样器。
 """
-__all__ = ["Sampler", "BucketSampler", "SequentialSampler", "RandomSampler"]
+__all__ = [
+    "Sampler",
+    "BucketSampler",
+    "SequentialSampler",
+    "RandomSampler"
+]
 from itertools import chain
 import numpy as np
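These samplers control the iteration order used by `Batch`/`Trainer`. A hedged sketch; the `BucketSampler` keyword names below match the 0.4-era API as I recall it and should be verified against the class definitions:

```python
from fastNLP import BucketSampler, RandomSampler

# Bucket instances of similar length together to cut padding waste.
sampler = BucketSampler(num_buckets=10, batch_size=32,
                        seq_len_field_name='seq_len')  # assumed keywords
shuffled = RandomSampler()  # a plain random permutation, the Trainer default
```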
@@ -35,7 +35,7 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation
 import warnings
 import torch
-from torch import nn
+import torch.nn as nn
 from .batch import Batch
 from .dataset import DataSet
@@ -49,6 +49,10 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
+__all__ = [
+    "Tester"
+]
 class Tester(object):
     """
@@ -77,29 +81,29 @@ class Tester(object):
         如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
     :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
     """
     def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
         super(Tester, self).__init__()
         if not isinstance(data, DataSet):
             raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
         if not isinstance(model, nn.Module):
             raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
         self.metrics = _prepare_metrics(metrics)
         self.data = data
         self._model = _move_model_to_device(model, device=device)
         self.batch_size = batch_size
         self.verbose = verbose
         # 如果是DataParallel将没有办法使用predict方法
         if isinstance(self._model, nn.DataParallel):
             if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'):
                 warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function,"
                               " while DataParallel has no predict() function.")
                 self._model = self._model.module
         # check predict
         if hasattr(self._model, 'predict'):
             self._predict_func = self._model.predict
@@ -109,7 +113,7 @@ class Tester(object):
                                 f"for evaluation, not `{type(self._predict_func)}`.")
         else:
             self._predict_func = self._model.forward
     def test(self):
         """开始进行验证,并返回验证结果。
@@ -144,12 +148,12 @@ class Tester(object):
                     _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                          check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                          dataset=self.data, check_level=0)
         if self.verbose >= 1:
             print("[tester] \n{}".format(self._format_eval_results(eval_results)))
         self._mode(network, is_test=False)
         return eval_results
     def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -161,13 +165,13 @@ class Tester(object):
             model.eval()
         else:
             model.train()
     def _data_forward(self, func, x):
         """A forward pass of the model. """
         x = _build_args(func, **x)
         y = func(**x)
         return y
     def _format_eval_results(self, results):
         """Override this method to support more print formats.
@@ -295,15 +295,17 @@ Example2.3
 fastNLP已经自带了很多callback函数供使用,可以参考 :doc:`fastNLP.core.callback` 。
 """
+__all__ = [
+    "Trainer"
+]
 import os
 import time
-from datetime import datetime
-from datetime import timedelta
+from datetime import datetime, timedelta
 import numpy as np
 import torch
-from torch import nn
+import torch.nn as nn
 try:
     from tqdm.auto import tqdm
@@ -315,6 +317,7 @@ from .callback import CallbackManager, CallbackException
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .metrics import _prepare_metrics
+from .optimizer import Optimizer
 from .sampler import Sampler
 from .sampler import RandomSampler
 from .sampler import SequentialSampler
@@ -326,7 +329,6 @@ from .utils import _check_loss_evaluate
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
-from .optimizer import Optimizer
 from .utils import _move_model_to_device
@@ -464,7 +466,7 @@ class Trainer(object):
                                len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         self.model = _move_model_to_device(self.model, device=device)
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         elif isinstance(optimizer, Optimizer):
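The `isinstance` dispatch above means `optimizer` may be either a raw `torch.optim.Optimizer` or one of the fastNLP wrappers from `optimizer.py`. A hedged sketch with the other required `Trainer` arguments reduced to placeholders:

```python
import torch
from fastNLP import Trainer, Adam

# `train_data` (DataSet) and `model` (nn.Module) are placeholders built elsewhere.
trainer = Trainer(train_data=train_data, model=model,
                  optimizer=Adam(lr=1e-3))                   # fastNLP wrapper ...
# optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)    # ... or raw torch
```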
@@ -1,20 +1,25 @@
 """
 utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。
 """
-__all__ = ["cache_results", "seq_len_to_mask"]
+__all__ = [
+    "cache_results",
+    "seq_len_to_mask"
+]
 import _pickle
 import inspect
 import os
 import warnings
-from collections import Counter
-from collections import namedtuple
+from collections import Counter, namedtuple
 import numpy as np
 import torch
-from torch import nn
+import torch.nn as nn
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
-                                                 'varargs'])
+                                     'varargs'])
 def _prepare_cache_filepath(filepath):
     """
@@ -40,26 +45,28 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
         import time
         import numpy as np
         from fastNLP import cache_results
         @cache_results('cache.pkl')
         def process_data():
             # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时
             time.sleep(1)
-            return np.random.randint(5, size=(10, 20))
+            return np.random.randint(10, size=(5,))
         start_time = time.time()
-        process_data()
+        print("res =",process_data())
         print(time.time() - start_time)
         start_time = time.time()
-        process_data()
+        print("res =",process_data())
         print(time.time() - start_time)
-        # 输出内容如下
-        # Save cache to cache.pkl.
-        # 1.0015439987182617
-        # Read cache from cache.pkl.
-        # 0.00013065338134765625
+        # 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间
+        # Save cache to cache.pkl.
+        # res = [5 4 9 1 8]
+        # 1.0042750835418701
+        # Read cache from cache.pkl.
+        # res = [5 4 9 1 8]
+        # 0.0040721893310546875
     可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理
@@ -83,11 +90,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
     :param int _verbose: 是否打印cache的信息。
     :return:
     """
     def wrapper_(func):
         signature = inspect.signature(func)
         for key, _ in signature.parameters.items():
             if key in ('_cache_fp', '_refresh', '_verbose'):
                 raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
         def wrapper(*args, **kwargs):
             if '_cache_fp' in kwargs:
                 cache_filepath = kwargs.pop('_cache_fp')
@@ -95,7 +104,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
             else:
                 cache_filepath = _cache_fp
             if '_refresh' in kwargs:
-                    refresh = kwargs.pop('_refresh')
+                refresh = kwargs.pop('_refresh')
                 assert isinstance(refresh, bool), "_refresh can only be bool."
             else:
                 refresh = _refresh
@@ -105,16 +114,16 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
             else:
                 verbose = _verbose
             refresh_flag = True
             if cache_filepath is not None and refresh is False:
                 # load data
                 if os.path.exists(cache_filepath):
                     with open(cache_filepath, 'rb') as f:
                         results = _pickle.load(f)
-                    if verbose==1:
+                    if verbose == 1:
                         print("Read cache from {}.".format(cache_filepath))
                     refresh_flag = False
             if refresh_flag:
                 results = func(*args, **kwargs)
                 if cache_filepath is not None:
@@ -124,11 +133,14 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
                     with open(cache_filepath, 'wb') as f:
                         _pickle.dump(results, f)
                     print("Save cache to {}.".format(cache_filepath))
             return results
         return wrapper
     return wrapper_
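Because `wrapper` pops `_cache_fp`, `_refresh` and `_verbose` out of `kwargs`, all three can also be overridden per call rather than only at decoration time. A small sketch grounded in the wrapper above:

```python
from fastNLP import cache_results

@cache_results('vocab_cache.pkl', _verbose=1)
def build_vocab():
    return {'<pad>': 0, '<unk>': 1}  # stand-in for expensive preprocessing

vocab = build_vocab()                       # first call: computes and saves the cache
vocab = build_vocab(_refresh=True)          # force recomputation, rewrite the cache
vocab = build_vocab(_cache_fp='other.pkl')  # redirect the cache file for this call
```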
 # def save_pickle(obj, pickle_path, file_name):
 #     """Save an object into a pickle file.
 #
@@ -196,7 +208,7 @@ def _move_model_to_device(model, device):
     """
     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
     if device is None:
         if isinstance(model, torch.nn.DataParallel):
             model.cuda()
@@ -205,34 +217,35 @@ def _move_model_to_device(model, device):
     if not torch.cuda.is_available() and (
             device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
         raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")
     if isinstance(model, torch.nn.DataParallel):
         raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
     if isinstance(device, int):
-        assert device>-1, "device can only be non-negative integer"
-        assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
-                                                                                                  device)
+        assert device > -1, "device can only be non-negative integer"
+        assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format(
+            torch.cuda.device_count(),
+            device)
         device = torch.device('cuda:{}'.format(device))
     elif isinstance(device, str):
         device = torch.device(device)
         if device.type == 'cuda' and device.index is not None:
-            assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
-                torch.cuda.device_count(),
-                device)
+            assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
+                torch.cuda.device_count(),
+                device)
     elif isinstance(device, torch.device):
         if device.type == 'cuda' and device.index is not None:
-            assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
-                torch.cuda.device_count(),
-                device)
+            assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
+                torch.cuda.device_count(),
+                device)
     elif isinstance(device, list):
         types = set([type(d) for d in device])
-        assert len(types)==1, "Mixed type in device, only `int` allowed."
+        assert len(types) == 1, "Mixed type in device, only `int` allowed."
         assert list(types)[0] == int, "Only int supported for multiple devices."
-        assert len(set(device))==len(device), "Duplicated device id found in device."
+        assert len(set(device)) == len(device), "Duplicated device id found in device."
         for d in device:
-            assert d>-1, "Only non-negative device id allowed."
-        if len(device)>1:
+            assert d > -1, "Only non-negative device id allowed."
+        if len(device) > 1:
             output_device = device[0]
             model = nn.DataParallel(model, device_ids=device, output_device=output_device)
         device = torch.device(device[0])
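The branches above define every accepted form of `device`. A quick reference sketch (CPU-only values shown so it runs without a GPU):

```python
import torch

# Forms accepted by _move_model_to_device, per the type checks above:
examples = [
    None,                  # keep the model where it is (or .cuda() for DataParallel)
    'cpu',                 # a device string
    torch.device('cpu'),   # a torch.device object
    # 0,                   # a single GPU index (requires CUDA)
    # [0, 1],              # several GPU indices -> wrapped in nn.DataParallel
]
```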
@@ -250,9 +263,9 @@ def _get_model_device(model):
     :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。
     """
     assert isinstance(model, nn.Module)
     parameters = list(model.parameters())
-    if len(parameters)==0:
+    if len(parameters) == 0:
         return None
     else:
         return parameters[0].device
@@ -407,7 +420,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
     if not isinstance(device, torch.device):
         raise TypeError(f"device must be `torch.device`, got `{type(device)}`")
     for arg in args:
         if isinstance(arg, dict):
             for key, value in arg.items():
@@ -422,10 +435,10 @@ class _CheckError(Exception):
     _CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
     def __init__(self, check_res: _CheckRes, func_signature: str):
         errs = [f'Problems occurred when calling `{func_signature}`']
         if check_res.varargs:
             errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
         if check_res.missing:
@@ -434,9 +447,9 @@ class _CheckError(Exception):
             errs.append(f"\tduplicated param: {check_res.duplicated}")
         if check_res.unused:
             errs.append(f"\tunused param: {check_res.unused}")
         Exception.__init__(self, '\n'.join(errs))
         self.check_res = check_res
         self.func_signature = func_signature
@@ -456,7 +469,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     # if check_res.varargs:
     #     errs.append(f"\tvarargs: *{check_res.varargs}")
     #     suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
     if check_res.unused:
         for _unused in check_res.unused:
             if _unused in target_dict:
@@ -466,8 +479,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     if _unused_field:
         unuseds.append(f"\tunused field: {_unused_field}")
     if _unused_param:
-        unuseds.append(f"\tunused param: {_unused_param}")  # output from predict or forward
+        unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward
     module_name = func_signature.split('.')[0]
     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")
@@ -488,14 +501,14 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
                 mapped_missing.append(_miss)
             else:
                 unmapped_missing.append(_miss)
         for _miss in mapped_missing + unmapped_missing:
             if _miss in dataset:
                 suggestions.append(f"Set `{_miss}` as target.")
             else:
                 _tmp = ''
                 if check_res.unused:
-                    _tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
+                    _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
                 if _tmp:
                     _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
                 else:
@@ -513,25 +526,25 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     #     else:
     #         _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.'
     #     suggestions.append(_tmp)
     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
         suggestions.append(f"Delete {check_res.duplicated} in the output of "
                            f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.extend(unuseds)
     elif check_level == STRICT_CHECK_LEVEL:
         errs.extend(unuseds)
     if len(errs) > 0:
         errs.insert(0, f'Problems occurred when calling {func_signature}')
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
-                if idx>0:
+                if idx > 0:
                     sugg_str += '\t\t\t'
-                sugg_str += f'({idx+1}). {sugg}\n'
+                sugg_str += f'({idx + 1}). {sugg}\n'
             sugg_str = sugg_str[:-1]
         else:
             sugg_str += suggestions[0]
@@ -546,14 +559,15 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         _unused_warn = f'{check_res.unused} is not used by {module_name}.'
         warnings.warn(message=_unused_warn)
 def _check_forward_error(forward_func, batch_x, dataset, check_level):
     check_res = _check_arg_dict_list(forward_func, batch_x)
     func_signature = _get_func_signature(forward_func)
     errs = []
     suggestions = []
     _unused = []
     # if check_res.varargs:
     #     errs.append(f"\tvarargs: {check_res.varargs}")
     #     suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
@@ -574,20 +588,20 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
             # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
             #        f"rename the field in `unused field:`."
             suggestions.append(_tmp)
     if check_res.unused:
         _unused = [f"\tunused field: {check_res.unused}"]
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.extend(_unused)
     elif check_level == STRICT_CHECK_LEVEL:
         errs.extend(_unused)
     if len(errs) > 0:
         errs.insert(0, f'Problems occurred when calling {func_signature}')
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
-                sugg_str += f'({idx+1}). {sugg}'
+                sugg_str += f'({idx + 1}). {sugg}'
         else:
             sugg_str += suggestions[0]
         err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
@@ -622,8 +636,8 @@ def seq_len_to_mask(seq_len):
         assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
         max_len = int(seq_len.max())
         broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
-        mask = broad_cast_seq_len<seq_len.reshape(-1, 1)
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
     elif isinstance(seq_len, torch.Tensor):
         assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim()}."
         batch_size = seq_len.size(0)
@@ -632,7 +646,7 @@ def seq_len_to_mask(seq_len):
         mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
     else:
         raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
     return mask
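Both branches implement the same broadcast comparison (position index < length) for numpy and torch inputs. A quick check of the torch path:

```python
import torch
from fastNLP.core.utils import seq_len_to_mask

seq_len = torch.tensor([1, 3, 2])
mask = seq_len_to_mask(seq_len)
print(mask.shape)  # torch.Size([3, 3]); row i holds seq_len[i] leading ones
```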
@@ -640,24 +654,24 @@ class _pseudo_tqdm:
     """
     当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
     """
     def __init__(self, **kwargs):
         pass
     def write(self, info):
         print(info)
     def set_postfix_str(self, info):
         print(info)
     def __getattr__(self, item):
         def pass_func(*args, **kwargs):
             pass
         return pass_func
     def __enter__(self):
         return self
     def __exit__(self, exc_type, exc_val, exc_tb):
         del self
@@ -1,5 +1,10 @@
+__all__ = [
+    "Vocabulary"
+]
 from functools import wraps
 from collections import Counter
 from .dataset import DataSet
@@ -318,6 +323,17 @@ class Vocabulary(object):
         """
         return self.idx2word[idx]
+    def clear(self):
+        """
+        删除Vocabulary中的词表数据。相当于重新初始化一下。
+        :return:
+        """
+        self.word_count.clear()
+        self.word2idx = None
+        self.idx2word = None
+        self.rebuild = True
     def __getstate__(self):
         """Use to prepare data for pickle.
@@ -24,7 +24,8 @@ __all__ = [
     'ModelLoader',
     'ModelSaver',
 ]
 from .embed_loader import EmbedLoader
 from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
     PeopleDailyCorpusLoader, Conll2003Loader
-from .model_io import ModelLoader as ModelLoader, ModelSaver as ModelSaver
+from .model_io import ModelLoader, ModelSaver
@@ -1,3 +1,7 @@
+__all__ = [
+    "BaseLoader"
+]
 import _pickle as pickle
 import os
@@ -7,9 +11,10 @@ class BaseLoader(object):
     各个 Loader 的基类,提供了 API 的参考。
     """
     def __init__(self):
         super(BaseLoader, self).__init__()
     @staticmethod
     def load_lines(data_path):
         """
@@ -20,7 +25,7 @@ class BaseLoader(object):
         with open(data_path, "r", encoding="utf-8") as f:
             text = f.readlines()
         return [line.strip() for line in text]
     @classmethod
     def load(cls, data_path):
         """
@@ -31,7 +36,7 @@ class BaseLoader(object):
         with open(data_path, "r", encoding="utf-8") as f:
             text = f.readlines()
         return [[word for word in sent.strip()] for sent in text]
     @classmethod
     def load_with_cache(cls, data_path, cache_path):
         """缓存版的load
@@ -48,16 +53,18 @@ class BaseLoader(object):
 class DataLoaderRegister:
     _readers = {}
     @classmethod
     def set_reader(cls, reader_cls, read_fn_name):
         # def wrapper(reader_cls):
         if read_fn_name in cls._readers:
-            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
+            raise KeyError(
+                'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls,
+                                                                       read_fn_name))
         if hasattr(reader_cls, 'load'):
             cls._readers[read_fn_name] = reader_cls().load
         return reader_cls
     @classmethod
     def get_reader(cls, read_fn_name):
         if read_fn_name in cls._readers:
@@ -1,8 +1,14 @@
 """
 用于读入和处理和保存 config 文件
+.. todo::
+    这个模块中的类可能被抛弃?
 """
-__all__ = ["ConfigLoader","ConfigSection","ConfigSaver"]
+__all__ = [
+    "ConfigLoader",
+    "ConfigSection",
+    "ConfigSaver"
+]
 import configparser
 import json
 import os
@@ -19,15 +25,16 @@ class ConfigLoader(BaseLoader):
     :param str data_path: 配置文件的路径
     """
     def __init__(self, data_path=None):
         super(ConfigLoader, self).__init__()
         if data_path is not None:
             self.config = self.parse(super(ConfigLoader, self).load(data_path))
     @staticmethod
     def parse(string):
         raise NotImplementedError
     @staticmethod
     def load_config(file_path, sections):
         """
@@ -81,10 +88,10 @@ class ConfigSection(object):
     ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用
     """
     def __init__(self):
         super(ConfigSection, self).__init__()
     def __getitem__(self, key):
         """
         :param key: str, the name of the attribute
@@ -97,7 +104,7 @@ class ConfigSection(object):
         if key in self.__dict__.keys():
             return getattr(self, key)
         raise AttributeError("do NOT have attribute %s" % key)
     def __setitem__(self, key, value):
         """
         :param key: str, the name of the attribute
@@ -112,14 +119,14 @@ class ConfigSection(object):
             raise AttributeError("attr %s expects %s but got %s" %
                                  (key, str(type(getattr(self, key))), str(type(value))))
         setattr(self, key, value)
     def __contains__(self, item):
         """
         :param item: The key of item.
         :return: True if the key in self.__dict__.keys() else False.
         """
         return item in self.__dict__.keys()
     def __eq__(self, other):
         """Overwrite the == operator
@@ -131,15 +138,15 @@ class ConfigSection(object):
                 return False
             if getattr(self, k) != getattr(other, k):
                 return False
         for k in other.__dict__.keys():
             if k not in self.__dict__.keys():
                 return False
             if getattr(self, k) != getattr(other, k):
                 return False
         return True
     def __ne__(self, other):
         """Overwrite the != operator
@@ -147,7 +154,7 @@ class ConfigSection(object):
         :return:
         """
         return not self.__eq__(other)
     @property
     def data(self):
         return self.__dict__
@@ -162,11 +169,12 @@ class ConfigSaver(object):
     :param str file_path: 配置文件的路径
     """
     def __init__(self, file_path):
         self.file_path = file_path
         if not os.path.exists(self.file_path):
             raise FileNotFoundError("file {} NOT found!".format(self.file_path))
     def _get_section(self, sect_name):
         """
         This is the function to get the section with the section name.
@@ -177,7 +185,7 @@ class ConfigSaver(object):
         sect = ConfigSection()
         ConfigLoader().load_config(self.file_path, {sect_name: sect})
         return sect
     def _read_section(self):
         """
         This is the function to read sections from the config file.
@@ -187,16 +195,16 @@ class ConfigSaver(object):
             sect_key_list: A list of names in sect_list.
         """
         sect_name = None
         sect_list = {}
         sect_key_list = []
         single_section = {}
         single_section_key = []
         with open(self.file_path, 'r') as f:
             lines = f.readlines()
         for line in lines:
             if line.startswith('[') and line.endswith(']\n'):
                 if sect_name is None:
@@ -208,29 +216,29 @@ class ConfigSaver(object):
                     sect_key_list.append(sect_name)
                 sect_name = line[1: -2]
                 continue
             if line.startswith('#'):
                 single_section[line] = '#'
                 single_section_key.append(line)
                 continue
             if line.startswith('\n'):
                 single_section_key.append('\n')
                 continue
             if '=' not in line:
                 raise RuntimeError("can NOT load config file {}".format(self.file_path))
             key = line.split('=', maxsplit=1)[0].strip()
             value = line.split('=', maxsplit=1)[1].strip() + '\n'
             single_section[key] = value
             single_section_key.append(key)
         if sect_name is not None:
             sect_list[sect_name] = single_section, single_section_key
             sect_key_list.append(sect_name)
         return sect_list, sect_key_list
     def _write_section(self, sect_list, sect_key_list):
         """
         This is the function to write config file with section list and name list.
@@ -252,7 +260,7 @@ class ConfigSaver(object):
                 continue
             f.write(key + ' = ' + single_section[key])
         f.write('\n')
     def save_config_file(self, section_name, section):
         """
         这个方法可以用来修改并保存配置文件中单独的一个 section
@@ -284,11 +292,11 @@ class ConfigSaver(object):
                     break
         if not change_file:
             return
         sect_list, sect_key_list = self._read_section()
         if section_name not in sect_key_list:
             raise AttributeError()
         sect, sect_key = sect_list[section_name]
         for k in section.__dict__.keys():
             if k not in sect_key:
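Per the `ConfigSection` docstring above, the intended pattern is to pre-create a section object and let `ConfigLoader.load_config` fill it in. A hedged sketch, assuming a `train.cfg` file with a `[train]` section containing an `epochs` key:

```python
from fastNLP.io.config_io import ConfigLoader, ConfigSection, ConfigSaver

train_args = ConfigSection()
ConfigLoader().load_config('train.cfg', {'train': train_args})  # fills train_args in place
print(train_args['epochs'])       # keys are reachable as items and attributes

train_args['epochs'] = 20         # __setitem__ type-checks against the existing value
ConfigSaver('train.cfg').save_config_file('train', train_args)  # write the section back
```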
@@ -20,6 +20,7 @@ __all__ = [
     'PeopleDailyCorpusLoader',
     'Conll2003Loader',
 ]
 from nltk.tree import Tree
 from ..core.dataset import DataSet
@@ -1,11 +1,15 @@
+__all__ = [
+    "EmbedLoader"
+]
 import os
+import warnings
 import numpy as np
 from ..core.vocabulary import Vocabulary
 from .base_loader import BaseLoader
-import warnings
 class EmbedLoader(BaseLoader):
     """
@@ -13,10 +17,10 @@ class EmbedLoader(BaseLoader):
     用于读取预训练的embedding, 读取结果可直接载入为模型参数。
     """
     def __init__(self):
         super(EmbedLoader, self).__init__()
     @staticmethod
     def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
         """
@@ -40,11 +44,11 @@ class EmbedLoader(BaseLoader):
             line = f.readline().strip()
             parts = line.split()
             start_idx = 0
-            if len(parts)==2:
+            if len(parts) == 2:
                 dim = int(parts[1])
                 start_idx += 1
             else:
-                dim = len(parts)-1
+                dim = len(parts) - 1
                 f.seek(0)
             matrix = np.random.randn(len(vocab), dim).astype(dtype)
             for idx, line in enumerate(f, start_idx):
@@ -63,21 +67,21 @@ class EmbedLoader(BaseLoader):
             total_hits = sum(hit_flags)
             print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
             found_vectors = matrix[hit_flags]
-            if len(found_vectors)!=0:
+            if len(found_vectors) != 0:
                 mean = np.mean(found_vectors, axis=0, keepdims=True)
                 std = np.std(found_vectors, axis=0, keepdims=True)
                 unfound_vec_num = len(vocab) - total_hits
-                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype)*std + mean
-                matrix[hit_flags==False] = r_vecs
+                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
+                matrix[hit_flags == False] = r_vecs
             if normalize:
                 matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)
             return matrix
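A sketch of `load_with_vocab`, assuming a whitespace-separated GloVe/word2vec-style text file; as the code above shows, words missing from the file receive random vectors drawn around the statistics of the found vectors:

```python
from fastNLP import Vocabulary
from fastNLP.io import EmbedLoader

vocab = Vocabulary()
vocab.add_word_lst(['the', 'quick', 'fox'])
vocab.build_vocab()

# 'glove.txt' is a placeholder path to a "word v1 v2 ..." text file.
matrix = EmbedLoader.load_with_vocab('glove.txt', vocab, normalize=True)
print(matrix.shape)  # (len(vocab), dim) numpy array; row i is word i's vector
```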
     @staticmethod
     def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
-                            error='ignore'):
+                           error='ignore'):
         """
         从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。
@@ -96,35 +100,35 @@ class EmbedLoader(BaseLoader):
         vec_dict = {}
         found_unknown = False
         found_pad = False
         with open(embed_filepath, 'r', encoding='utf-8') as f:
             line = f.readline()
             start = 1
             dim = -1
-            if len(line.strip().split())!=2:
+            if len(line.strip().split()) != 2:
                 f.seek(0)
                 start = 0
             for idx, line in enumerate(f, start=start):
                 try:
                     parts = line.strip().split()
                     word = parts[0]
-                    if dim==-1:
-                        dim = len(parts)-1
+                    if dim == -1:
+                        dim = len(parts) - 1
                     vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
                     vec_dict[word] = vec
                     vocab.add_word(word)
-                    if unknown is not None and unknown==word:
+                    if unknown is not None and unknown == word:
                         found_unknown = True
-                    if found_pad is not None and padding==word:
+                    if padding is not None and padding == word:
                         found_pad = True
                 except Exception as e:
-                    if error=='ignore':
+                    if error == 'ignore':
                         warnings.warn("Error occurred at the {} line.".format(idx))
                         pass
                     else:
                         print("Error occurred at the {} line.".format(idx))
                         raise e
-        if dim==-1:
+        if dim == -1:
             raise RuntimeError("{} is an empty file.".format(embed_filepath))
         matrix = np.random.randn(len(vocab), dim).astype(dtype)
         if (unknown is not None and not found_unknown) or (padding is not None and not found_pad):
@@ -133,19 +137,19 @@ class EmbedLoader(BaseLoader):
             start_idx += 1
         if unknown is not None:
             start_idx += 1
         mean = np.mean(matrix[start_idx:], axis=0, keepdims=True)
         std = np.std(matrix[start_idx:], axis=0, keepdims=True)
         if (unknown is not None and not found_unknown):
-            matrix[start_idx-1] = np.random.randn(1, dim).astype(dtype)*std + mean
+            matrix[start_idx - 1] = np.random.randn(1, dim).astype(dtype) * std + mean
         if (padding is not None and not found_pad):
-            matrix[0] = np.random.randn(1, dim).astype(dtype)*std + mean
+            matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean
         for key, vec in vec_dict.items():
             index = vocab.to_index(key)
             matrix[index] = vec
         if normalize:
             matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)
         return matrix, vocab
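`load_without_vocab` builds the `Vocabulary` from the embedding file itself and, as the code above shows, reserves randomly initialized rows for `<pad>`/`<unk>` when they are absent from the file. A short sketch:

```python
from fastNLP.io import EmbedLoader

# 'glove.txt' is a placeholder "word v1 v2 ..." file, as above.
matrix, vocab = EmbedLoader.load_without_vocab('glove.txt',
                                               padding='<pad>', unknown='<unk>')
print(matrix.shape, len(vocab))  # one row per word, including reserved entries
```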
@@ -1,6 +1,11 @@
 """
 用于载入和保存模型
 """
+__all__ = [
+    "ModelLoader",
+    "ModelSaver"
+]
 import torch
 from .base_loader import BaseLoader
@@ -12,10 +17,10 @@ class ModelLoader(BaseLoader):
     用于读取模型
     """
     def __init__(self):
         super(ModelLoader, self).__init__()
     @staticmethod
     def load_pytorch(empty_model, model_path):
         """
@@ -25,7 +30,7 @@ class ModelLoader(BaseLoader):
         :param str model_path: 模型保存的路径
         """
         empty_model.load_state_dict(torch.load(model_path))
     @staticmethod
     def load_pytorch_model(model_path):
         """
@@ -48,14 +53,14 @@ class ModelSaver(object):
         saver.save_pytorch(model)
     """
     def __init__(self, save_path):
         """
         :param save_path: 模型保存的路径
         """
         self.save_path = save_path
     def save_pytorch(self, model, param_only=True):
         """
         把 PyTorch 模型存入 ".pkl" 文件
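The saver/loader pair above gives a simple round trip: with `param_only=True` only the `state_dict` is written, so loading requires an identically structured empty model:

```python
import torch.nn as nn
from fastNLP.io import ModelSaver, ModelLoader

model = nn.Linear(4, 2)
ModelSaver("./model_ckpt.pkl").save_pytorch(model, param_only=True)

empty_model = nn.Linear(4, 2)  # same architecture, fresh parameters
ModelLoader.load_pytorch(empty_model, "./model_ckpt.pkl")
```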
@@ -7,7 +7,23 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models
 """
-__all__ = ["CNNText", "SeqLabeling", "ESIM", "STSeqLabel", "AdvSeqLabel", "STNLICls", "STSeqCls"]
+__all__ = [
+    "CNNText",
+    "SeqLabeling",
+    "AdvSeqLabel",
+    "ESIM",
+    "StarTransEnc",
+    "STSeqLabel",
+    "STNLICls",
+    "STSeqCls",
+    "BiaffineParser",
+    "GraphParser"
+]
 from .base_model import BaseModel
 from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
     BertForTokenClassification
@@ -15,4 +31,4 @@ from .biaffine_parser import BiaffineParser, GraphParser
 from .cnn_text_classification import CNNText
 from .sequence_labeling import SeqLabeling, AdvSeqLabel
 from .snli import ESIM
-from .star_transformer import STSeqCls, STNLICls, STSeqLabel
+from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
@@ -1,18 +1,18 @@
 import torch

-from ..modules.decoder.MLP import MLP
+from ..modules.decoder.mlp import MLP


 class BaseModel(torch.nn.Module):
     """Base PyTorch model for all models.
     """

     def __init__(self):
         super(BaseModel, self).__init__()

     def fit(self, train_data, dev_data=None, **train_args):
         pass

     def predict(self, *args, **kwargs):
         raise NotImplementedError
@@ -21,9 +21,9 @@ class NaiveClassifier(BaseModel):
     def __init__(self, in_feature_dim, out_feature_dim):
         super(NaiveClassifier, self).__init__()
         self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])

     def forward(self, x):
         return {"predict": torch.sigmoid(self.mlp(x))}

     def predict(self, x):
         return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
@@ -1,11 +1,17 @@
-"""PyTorch implementation of the Biaffine Dependency Parser.
 """
-from collections import defaultdict
+PyTorch implementation of the Biaffine Dependency Parser.
+"""
+__all__ = [
+    "BiaffineParser",
+    "GraphParser"
+]
 import numpy as np
 import torch
-from torch import nn
-from torch.nn import functional as F
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import defaultdict

 from ..core.const import Const as C
 from ..core.losses import LossFunc
@@ -18,6 +24,7 @@ from ..modules.utils import get_embeddings
 from .base_model import BaseModel
 from ..core.utils import seq_len_to_mask


 def _mst(scores):
     """
     with some modification to support parser output for MST decoding
@@ -44,7 +51,7 @@ def _mst(scores):
             scores[roots, new_heads] / root_scores)]
         heads[roots] = new_heads
         heads[new_root] = 0

     edges = defaultdict(set)
     vertices = set((0,))
     for dep, head in enumerate(heads[tokens]):
@@ -73,7 +80,7 @@ def _mst(scores):
             heads[changed_cycle] = new_head
             edges[new_head].add(changed_cycle)
             edges[old_head].remove(changed_cycle)

     return heads
@@ -88,7 +95,7 @@ def _find_cycle(vertices, edges):
     _lowlinks = {}
     _onstack = defaultdict(lambda: False)
     _SCCs = []

     def _strongconnect(v):
         nonlocal _index
         _indices[v] = _index
@@ -96,28 +103,28 @@ def _find_cycle(vertices, edges):
         _index += 1
         _stack.append(v)
         _onstack[v] = True

         for w in edges[v]:
             if w not in _indices:
                 _strongconnect(w)
                 _lowlinks[v] = min(_lowlinks[v], _lowlinks[w])
             elif _onstack[w]:
                 _lowlinks[v] = min(_lowlinks[v], _indices[w])

         if _lowlinks[v] == _indices[v]:
             SCC = set()
             while True:
                 w = _stack.pop()
                 _onstack[w] = False
                 SCC.add(w)
-                if not(w != v):
+                if not (w != v):
                     break
             _SCCs.append(SCC)

     for v in vertices:
         if v not in _indices:
             _strongconnect(v)

     return [SCC for SCC in _SCCs if len(SCC) > 1]
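`_find_cycle` is Tarjan's strongly-connected-components search restricted to components of size greater than one, i.e. actual cycles. A small self-contained check against the definition above (toy graph; the function is module-private, so this assumes it runs in the same file):

from collections import defaultdict

edges = defaultdict(set)
edges[0] = {1}
edges[1] = {2}
edges[2] = {1}                        # 1 -> 2 -> 1 forms a cycle
print(_find_cycle({0, 1, 2}, edges))  # expected: [{1, 2}]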
@@ -125,9 +132,10 @@ class GraphParser(BaseModel):
     """
     Base class for graph-based parsers; supports both greedy decoding and maximum spanning tree (MST) decoding.
     """

     def __init__(self):
         super(GraphParser, self).__init__()

     @staticmethod
     def greedy_decoder(arc_matrix, mask=None):
         """
@@ -146,7 +154,7 @@ class GraphParser(BaseModel):
         if mask is not None:
             heads *= mask.long()
         return heads

     @staticmethod
     def mst_decoder(arc_matrix, mask=None):
         """
@@ -176,6 +184,7 @@ class ArcBiaffine(nn.Module):
     :param hidden_size: dimension of the input features
     :param bias: whether to use a bias term. Default: ``True``
     """

     def __init__(self, hidden_size, bias=True):
         super(ArcBiaffine, self).__init__()
         self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
@@ -185,7 +194,7 @@ class ArcBiaffine(nn.Module):
         else:
             self.register_parameter("bias", None)
         initial_parameter(self)

     def forward(self, head, dep):
         """
@@ -209,11 +218,12 @@ class LabelBilinear(nn.Module):
     :param num_label: number of edge label classes
     :param bias: whether to use a bias term. Default: ``True``
     """

     def __init__(self, in1_features, in2_features, num_label, bias=True):
         super(LabelBilinear, self).__init__()
         self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
         self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)

     def forward(self, x1, x2):
         """
@@ -225,13 +235,13 @@ class LabelBilinear(nn.Module):
         output += self.lin(torch.cat([x1, x2], dim=2))
         return output


 class BiaffineParser(GraphParser):
     """
     Alias: :class:`fastNLP.models.BiaffineParser` :class:`fastNLP.models.biaffine_parser.BiaffineParser`

     Implementation of the Biaffine Dependency Parser.
-    Reference: ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
-    <https://arxiv.org/abs/1611.01734>`_ .
+    Reference: `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ .

     :param init_embed: the word embedding; may be a tuple (num_embeddings, embedding_dim), i.e.
         the vocabulary size and the dimension of each word vector. An nn.Embedding instance can also be passed in,
@@ -248,18 +258,19 @@ class BiaffineParser(GraphParser):
     :param use_greedy_infer: whether to use the greedy algorithm at inference time.
         If ``False``, the more accurate but slower MST algorithm is used. Default: ``False``
     """

     def __init__(self,
                  init_embed,
                  pos_vocab_size,
                  pos_emb_dim,
                  num_label,
                  rnn_layers=1,
                  rnn_hidden_size=200,
                  arc_mlp_size=100,
                  label_mlp_size=100,
                  dropout=0.3,
                  encoder='lstm',
                  use_greedy_infer=False):
         super(BiaffineParser, self).__init__()
         rnn_out_size = 2 * rnn_hidden_size
         word_hid_dim = pos_hid_dim = rnn_hidden_size
@@ -295,20 +306,20 @@ class BiaffineParser(GraphParser):
             if (d_k * n_head) != rnn_out_size:
                 raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size))
             self.position_emb = nn.Embedding(num_embeddings=self.max_len,
-                                             embedding_dim=rnn_out_size,)
+                                             embedding_dim=rnn_out_size, )
             self.encoder = TransformerEncoder(num_layers=rnn_layers,
                                               model_size=rnn_out_size,
                                               inner_size=1024,
                                               key_size=d_k,
                                               value_size=d_v,
                                               num_head=n_head,
-                                              dropout=dropout,)
+                                              dropout=dropout, )
         else:
             raise ValueError('unsupported encoder type: {}'.format(encoder))

         self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                  nn.ELU(),
-                                 TimestepDropout(p=dropout),)
+                                 TimestepDropout(p=dropout), )
         self.arc_mlp_size = arc_mlp_size
         self.label_mlp_size = label_mlp_size
         self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
@@ -316,7 +327,7 @@ class BiaffineParser(GraphParser):
         self.use_greedy_infer = use_greedy_infer
         self.reset_parameters()
         self.dropout = dropout

     def reset_parameters(self):
         for m in self.modules():
             if isinstance(m, nn.Embedding):
@@ -327,7 +338,7 @@ class BiaffineParser(GraphParser):
             else:
                 for p in m.parameters():
                     nn.init.normal_(p, 0, 0.1)

     def forward(self, words1, words2, seq_len, target1=None):
         """Model forward pass.
@@ -337,50 +348,52 @@ class BiaffineParser(GraphParser):
         :param target1: [batch_size, seq_len] gold heads, only effective during training;
             used to train the label classifier. If ``None``, the predicted heads are fed
             to the label classifier. Default: ``None``
-        :return dict: parsing results::
+        :return dict: parsing
+            results::

-            pred1: [batch_size, seq_len, seq_len] arc prediction logits
-            pred2: [batch_size, seq_len, num_label] label prediction logits
-            pred3: [batch_size, seq_len] predicted heads, produced when ``target1=None``
+                pred1: [batch_size, seq_len, seq_len] arc prediction logits
+                pred2: [batch_size, seq_len, num_label] label prediction logits
+                pred3: [batch_size, seq_len] predicted heads, produced when ``target1=None``
         """
         # prepare embeddings
         batch_size, length = words1.shape
         # print('forward {} {}'.format(batch_size, seq_len))

         # get sequence mask
         mask = seq_len_to_mask(seq_len).long()

-        word = self.word_embedding(words1) # [N,L] -> [N,L,C_0]
-        pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1]
+        word = self.word_embedding(words1)  # [N,L] -> [N,L,C_0]
+        pos = self.pos_embedding(words2)  # [N,L] -> [N,L,C_1]

         word, pos = self.word_fc(word), self.pos_fc(pos)
         word, pos = self.word_norm(word), self.pos_norm(pos)
-        x = torch.cat([word, pos], dim=2) # -> [N,L,C]
+        x = torch.cat([word, pos], dim=2)  # -> [N,L,C]

         # encoder, extract features
         if self.encoder_name.endswith('lstm'):
             sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
             x = x[sort_idx]
             x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
-            feat, _ = self.encoder(x) # -> [N,L,C]
+            feat, _ = self.encoder(x)  # -> [N,L,C]
             feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
             _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
             feat = feat[unsort_idx]
         else:
-            seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None,:]
+            seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :]
             x = x + self.position_emb(seq_range)
             feat = self.encoder(x, mask.float())

         # for arc biaffine
         # mlp, reduce dim
         feat = self.mlp(feat)
         arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
-        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
-        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]
+        arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
+        label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:]

         # biaffine arc classifier
-        arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
+        arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

         # use gold or predicted arc to predict label
         if target1 is None or not self.training:
             # use greedy decoding in training
@@ -390,22 +403,22 @@ class BiaffineParser(GraphParser):
                 heads = self.mst_decoder(arc_pred, mask)
             head_pred = heads
         else:
-            assert self.training # must be training mode
+            assert self.training  # must be training mode
             if target1 is None:
                 heads = self.greedy_decoder(arc_pred, mask)
                 head_pred = heads
             else:
                 head_pred = None
                 heads = target1

         batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1)
         label_head = label_head[batch_range, heads].contiguous()
-        label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
+        label_pred = self.label_predictor(label_head, label_dep)  # [N, L, num_label]
         res_dict = {C.OUTPUTS(0): arc_pred, C.OUTPUTS(1): label_pred}
         if head_pred is not None:
             res_dict[C.OUTPUTS(2)] = head_pred
         return res_dict

     @staticmethod
     def loss(pred1, pred2, target1, target2, seq_len):
         """
@@ -418,7 +431,7 @@ class BiaffineParser(GraphParser):
         :param seq_len: [batch_size, seq_len] lengths of the gold targets
         :return loss: scalar
         """

         batch_size, length, _ = pred1.shape
         mask = seq_len_to_mask(seq_len)
         flip_mask = (mask == 0)
@@ -430,24 +443,26 @@ class BiaffineParser(GraphParser):
         child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
         arc_loss = arc_logits[batch_index, child_index, target1]
         label_loss = label_logits[batch_index, child_index, target2]

         byte_mask = flip_mask.byte()
         arc_loss.masked_fill_(byte_mask, 0)
         label_loss.masked_fill_(byte_mask, 0)
         arc_nll = -arc_loss.mean()
         label_nll = -label_loss.mean()
         return arc_nll + label_nll

     def predict(self, words1, words2, seq_len):
         """Model prediction API.

         :param words1: [batch_size, seq_len] input word sequence
         :param words2: [batch_size, seq_len] input POS sequence
         :param seq_len: [batch_size, seq_len] input sequence lengths
-        :return dict: parsing results::
+        :return dict: parsing
+            results::

-            pred1: [batch_size, seq_len] predicted heads
-            pred2: [batch_size, seq_len, num_label] label prediction logits
+                pred1: [batch_size, seq_len] predicted heads
+                pred2: [batch_size, seq_len, num_label] label prediction logits
         """
         res = self(words1, words2, seq_len)
         output = {}
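A hedged usage sketch for the API above (shapes and vocabulary sizes invented; this assumes `C.OUTPUTS(i)` resolves to the string keys 'pred1'/'pred2' that the docstrings describe, and is a sketch rather than a verified run):

import torch

parser = BiaffineParser(init_embed=(1000, 100), pos_vocab_size=30,
                        pos_emb_dim=50, num_label=40)
words = torch.randint(0, 1000, (2, 7))    # word ids
pos = torch.randint(0, 30, (2, 7))        # POS ids
seq_len = torch.tensor([7, 5])

parser.eval()
out = parser.predict(words, pos, seq_len)
print(out['pred1'].shape)                 # [2, 7] predicted heads
print(out['pred2'].shape)                 # [2, 7, 40] label logits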
@@ -470,6 +485,7 @@ class ParserLoss(LossFunc):
     :param seq_len: [batch_size, seq_len] lengths of the gold targets
     :return loss: scalar
     """

     def __init__(self, pred1=None, pred2=None,
                  target1=None, target2=None,
                  seq_len=None):
@@ -497,9 +513,10 @@ class ParserMetric(MetricBase):
     UAS: accuracy of head prediction, labels ignored
     LAS: accuracy of predicting head and label together
     """

     def __init__(self, pred1=None, pred2=None,
                  target1=None, target2=None, seq_len=None):
         super().__init__()
         self._init_param_map(pred1=pred1, pred2=pred2,
                              target1=target1, target2=target2,
@@ -507,13 +524,13 @@ class ParserMetric(MetricBase):
         self.num_arc = 0
         self.num_label = 0
         self.num_sample = 0

     def get_metric(self, reset=True):
-        res = {'UAS': self.num_arc*1.0 / self.num_sample, 'LAS': self.num_label*1.0 / self.num_sample}
+        res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample}
         if reset:
             self.num_sample = self.num_label = self.num_arc = 0
         return res

     def evaluate(self, pred1, pred2, target1, target2, seq_len=None):
         """Evaluate the performance of prediction.
         """
@@ -522,7 +539,7 @@ class ParserMetric(MetricBase):
         else:
             seq_mask = seq_len_to_mask(seq_len.long()).long()
         # mask out <root> tag
-        seq_mask[:,0] = 0
+        seq_mask[:, 0] = 0
         head_pred_correct = (pred1 == target1).long() * seq_mask
         label_pred_correct = (pred2 == target2).long() * head_pred_correct
         self.num_arc += head_pred_correct.sum().item()
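UAS and LAS as accumulated here, computed in isolation on toy tensors (the root position is masked out exactly as in `evaluate`; a label only counts as correct where the head is also correct):

import torch

pred_heads = torch.tensor([[0, 2, 0, 2]])    # predicted head for each token
gold_heads = torch.tensor([[0, 2, 0, 1]])
pred_labels = torch.tensor([[0, 3, 1, 2]])
gold_labels = torch.tensor([[0, 3, 2, 2]])
mask = torch.tensor([[1, 1, 1, 1]])
mask[:, 0] = 0                               # ignore the <root> token

arc_correct = (pred_heads == gold_heads).long() * mask
label_correct = (pred_labels == gold_labels).long() * arc_correct
n = mask.sum().item()
print('UAS:', arc_correct.sum().item() / n)    # 2/3
print('LAS:', label_correct.sum().item() / n)  # 1/3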
@@ -1,10 +1,11 @@
-# python: 3.6
-# encoding: utf-8
+__all__ = [
+    "CNNText"
+]
 import torch
 import torch.nn as nn

 from ..core.const import Const as C
 from ..modules import encoder
@@ -23,7 +24,7 @@ class CNNText(torch.nn.Module):
     :param int padding: amount of zero-padding added before and after each sentence.
     :param float dropout: dropout probability.
     """

     def __init__(self, init_embed,
                  num_classes,
                  kernel_nums=(3, 4, 5),
@@ -31,7 +32,7 @@ class CNNText(torch.nn.Module):
                  padding=0,
                  dropout=0.5):
         super(CNNText, self).__init__()

         # no support for pre-trained embedding currently
         self.embed = encoder.Embedding(init_embed)
         self.conv_pool = encoder.ConvMaxpool(
@@ -41,7 +42,7 @@ class CNNText(torch.nn.Module):
             padding=padding)
         self.dropout = nn.Dropout(dropout)
         self.fc = nn.Linear(sum(kernel_nums), num_classes)

     def forward(self, words, seq_len=None):
         """
@@ -54,7 +55,7 @@ class CNNText(torch.nn.Module):
         x = self.dropout(x)
         x = self.fc(x)  # [N,C] -> [N, N_class]
         return {C.OUTPUT: x}

     def predict(self, words, seq_len=None):
         """
         :param torch.LongTensor words: [batch_size, seq_len], indices of the words in each sentence
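A hedged instantiation sketch for CNNText; the vocabulary size and class count are invented, and the output key 'pred' is an assumption about what `C.OUTPUT` resolves to. Note that `sum(kernel_nums)` fixes the classifier input, so one CNN feature map is produced per kernel:

import torch

model = CNNText(init_embed=(5000, 128), num_classes=3,
                kernel_nums=(3, 4, 5), dropout=0.5)
words = torch.randint(0, 5000, (8, 20))   # a batch of 8 sentences, 20 tokens each
out = model(words)                        # {'pred': logits of shape [8, 3]}
print(out['pred'].shape)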
@@ -5,6 +5,7 @@ import os
 import torch
 import torch.nn.functional as F

 from . import enas_utils as utils
 from .enas_utils import Node
@@ -1,17 +1,19 @@
-# Code Modified from https://github.com/carpedm20/ENAS-pytorch
-"""Module containing the shared RNN model."""
-import numpy as np
+"""
+Module containing the shared RNN model.
+Code Modified from https://github.com/carpedm20/ENAS-pytorch
+"""
 import collections
+import numpy as np
 import torch
-from torch import nn
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable

 from . import enas_utils as utils
 from .base_model import BaseModel


 def _get_dropped_weights(w_raw, dropout_p, is_training):
     """Drops out weights to implement DropConnect.
@@ -35,12 +37,13 @@ def _get_dropped_weights(w_raw, dropout_p, is_training):
     The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
     """
     dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)

     if isinstance(dropped_w, torch.nn.Parameter):
         dropped_w = dropped_w.clone()

     return dropped_w
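DropConnect drops individual weights rather than activations; the helper above is just dropout applied to a weight tensor at forward time. A minimal sketch of the idea:

import torch
import torch.nn.functional as F

w_raw = torch.nn.Parameter(torch.randn(4, 4))
# at train time, each weight is zeroed with p=0.5 and survivors are rescaled by 1/(1-p)
w = F.dropout(w_raw, p=0.5, training=True)
y = F.linear(torch.randn(2, 4), w)   # use the dropped weights for this forward pass only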
 class EmbeddingDropout(torch.nn.Embedding):
     """Class for dropping out embeddings by zero'ing out parameters in the
     embedding matrix.
@@ -53,6 +56,7 @@ class EmbeddingDropout(torch.nn.Embedding):
     See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
     Networks', (Gal and Ghahramani, 2016).
     """

     def __init__(self,
                  num_embeddings,
                  embedding_dim,
@@ -83,14 +87,14 @@ class EmbeddingDropout(torch.nn.Embedding):
         assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 '
                                                       'and < 1.0')
         self.scale = scale

     def forward(self, inputs):  # pylint:disable=arguments-differ
         """Embeds `inputs` with the dropped out embedding weight matrix."""

         if self.training:
             dropout = self.dropout
         else:
             dropout = 0

         if dropout:
             mask = self.weight.data.new(self.weight.size(0), 1)
             mask.bernoulli_(1 - dropout)
@@ -101,7 +105,7 @@ class EmbeddingDropout(torch.nn.Embedding):
             masked_weight = self.weight
         if self.scale and self.scale != 1:
             masked_weight = masked_weight * self.scale

         return F.embedding(inputs,
                            masked_weight,
                            max_norm=self.max_norm,
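The mask is sampled once per vocabulary row, so a dropped word type is zeroed everywhere it occurs in the batch. A standalone sketch of the masking step; the 1/(1-p) rescaling follows the standard inverted-dropout convention, and the lines elided by the hunk above may differ in detail:

import torch

vocab_size, dim, p = 10, 4, 0.2
weight = torch.randn(vocab_size, dim)
mask = weight.new_empty(vocab_size, 1).bernoulli_(1 - p) / (1 - p)  # keep & rescale
masked_weight = weight * mask.expand_as(weight)   # whole rows (word types) are zeroed
emb = torch.nn.functional.embedding(torch.tensor([[1, 2, 1]]), masked_weight)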
@@ -114,7 +118,7 @@ class LockedDropout(nn.Module):
     # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
     def __init__(self):
         super().__init__()

     def forward(self, x, dropout=0.5):
         if not self.training or not dropout:
             return x
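Locked (variational) dropout samples one mask per sequence and reuses it at every time step, instead of resampling per step. A sketch of the usual [time, batch, hidden] formulation; the body elided by the hunk above presumably does the same:

import torch

def locked_dropout(x, p=0.5):            # x: [time, batch, hidden]
    mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - p) / (1 - p)
    return x * mask                      # same mask broadcast across all time steps

y = locked_dropout(torch.randn(7, 2, 5))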
@@ -126,11 +130,12 @@ class LockedDropout(nn.Module):
 class ENASModel(BaseModel):
     """Shared RNN model."""

     def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000):
         super(ENASModel, self).__init__()

         self.use_cuda = cuda

         self.shared_hid = shared_hid
         self.num_blocks = num_blocks
         self.decoder = nn.Linear(self.shared_hid, num_classes)
@@ -139,16 +144,16 @@ class ENASModel(BaseModel):
                                         dropout=0.1)
         self.lockdrop = LockedDropout()
         self.dag = None

         # Tie weights
         # self.decoder.weight = self.encoder.weight

         # Since W^{x, c} and W^{h, c} are always summed, there
         # is no point duplicating their bias offset parameter. Likewise for
         # W^{x, h} and W^{h, h}.
         self.w_xc = nn.Linear(shared_embed, self.shared_hid)
         self.w_xh = nn.Linear(shared_embed, self.shared_hid)

         # The raw weights are stored here because the hidden-to-hidden weights
         # are weight dropped on the forward pass.
         self.w_hc_raw = torch.nn.Parameter(
@@ -157,10 +162,10 @@ class ENASModel(BaseModel):
             torch.Tensor(self.shared_hid, self.shared_hid))
         self.w_hc = None
         self.w_hh = None

         self.w_h = collections.defaultdict(dict)
         self.w_c = collections.defaultdict(dict)

         for idx in range(self.num_blocks):
             for jdx in range(idx + 1, self.num_blocks):
                 self.w_h[idx][jdx] = nn.Linear(self.shared_hid,
@@ -169,48 +174,47 @@ class ENASModel(BaseModel):
                 self.w_c[idx][jdx] = nn.Linear(self.shared_hid,
                                                self.shared_hid,
                                                bias=False)

         self._w_h = nn.ModuleList([self.w_h[idx][jdx]
                                    for idx in self.w_h
                                    for jdx in self.w_h[idx]])
         self._w_c = nn.ModuleList([self.w_c[idx][jdx]
                                    for idx in self.w_c
                                    for jdx in self.w_c[idx]])

         self.batch_norm = None
         # if args.mode == 'train':
         #     self.batch_norm = nn.BatchNorm1d(self.shared_hid)
         # else:
         #     self.batch_norm = None

         self.reset_parameters()
         self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

     def setDAG(self, dag):
         if self.dag is None:
             self.dag = dag

     def forward(self, word_seq, hidden=None):
         inputs = torch.transpose(word_seq, 0, 1)

         time_steps = inputs.size(0)
         batch_size = inputs.size(1)

         self.w_hh = _get_dropped_weights(self.w_hh_raw,
                                          0.5,
                                          self.training)
         self.w_hc = _get_dropped_weights(self.w_hc_raw,
                                          0.5,
                                          self.training)

         # hidden = self.static_init_hidden[batch_size] if hidden is None else hidden
         hidden = self.static_init_hidden[batch_size]

         embed = self.encoder(inputs)

         embed = self.lockdrop(embed, 0.65 if self.training else 0)

         # The norm of hidden states are clipped here because
         # otherwise ENAS is especially prone to exploding activations on the
         # forward pass. This could probably be fixed in a more elegant way, but
@@ -226,7 +230,7 @@ class ENASModel(BaseModel):
         for step in range(time_steps):
             x_t = embed[step]
             logit, hidden = self.cell(x_t, hidden, self.dag)

             hidden_norms = hidden.norm(dim=-1)
             max_norm = 25.0
             if hidden_norms.data.max() > max_norm:
@@ -237,60 +241,60 @@ class ENASModel(BaseModel):
                 # because the PyTorch slicing and slice assignment is too
                 # flaky.
                 hidden_norms = hidden_norms.data.cpu().numpy()

                 clipped_num += 1
                 if hidden_norms.max() > max_clipped_norm:
                     max_clipped_norm = hidden_norms.max()

                 clip_select = hidden_norms > max_norm
                 clip_norms = hidden_norms[clip_select]

                 mask = np.ones(hidden.size())
-                normalizer = max_norm/clip_norms
+                normalizer = max_norm / clip_norms
                 normalizer = normalizer[:, np.newaxis]

                 mask[clip_select] = normalizer

                 if self.use_cuda:
                     hidden *= torch.autograd.Variable(
                         torch.FloatTensor(mask).cuda(), requires_grad=False)
                 else:
                     hidden *= torch.autograd.Variable(
                         torch.FloatTensor(mask), requires_grad=False)
             logits.append(logit)
             h1tohT.append(hidden)

         h1tohT = torch.stack(h1tohT)
         output = torch.stack(logits)
         raw_output = output

         output = self.lockdrop(output, 0.4 if self.training else 0)

-        #Pooling
+        # Pooling
         output = torch.mean(output, 0)

         decoded = self.decoder(output)

         extra_out = {'dropped': decoded,
                      'hiddens': h1tohT,
                      'raw': raw_output}
         return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out}

     def cell(self, x, h_prev, dag):
         """Computes a single pass through the discovered RNN cell."""
         c = {}
         h = {}
         f = {}

         f[0] = self.get_f(dag[-1][0].name)
         c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None))
-        h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
-                (1 - c[0])*h_prev)
+        h[0] = (c[0] * f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
+                (1 - c[0]) * h_prev)

         leaf_node_ids = []
         q = collections.deque()
         q.append(0)

         # Computes connections from the parent nodes `node_id`
         # to their child nodes `next_id` recursively, skipping leaf nodes. A
         # leaf node is a node whose id == `self.num_blocks`.
@@ -306,10 +310,10 @@ class ENASModel(BaseModel):
         while True:
             if len(q) == 0:
                 break

             node_id = q.popleft()

             nodes = dag[node_id]
             for next_node in nodes:
                 next_id = next_node.id
                 if next_id == self.num_blocks:
@@ -317,38 +321,38 @@ class ENASModel(BaseModel):
                     assert len(nodes) == 1, ('parent of leaf node should have '
                                              'only one child')
                     continue

                 w_h = self.w_h[node_id][next_id]
                 w_c = self.w_c[node_id][next_id]

                 f[next_id] = self.get_f(next_node.name)
                 c[next_id] = torch.sigmoid(w_c(h[node_id]))
-                h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) +
-                              (1 - c[next_id])*h[node_id])
+                h[next_id] = (c[next_id] * f[next_id](w_h(h[node_id])) +
+                              (1 - c[next_id]) * h[node_id])

                 q.append(next_id)

         # Instead of averaging loose ends, perhaps there should
         # be a set of separate unshared weights for each "loose" connection
         # between each node in a cell and the output.
         #
         # As it stands, all weights W^h_{ij} are doing double duty by
         # connecting both from i to j, as well as from i to the output.

         # average all the loose ends
         leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
         output = torch.mean(torch.stack(leaf_nodes, 2), -1)

         # stabilizing the Updates of omega
         if self.batch_norm is not None:
             output = self.batch_norm(output)

         return output, h[self.num_blocks - 1]

     def init_hidden(self, batch_size):
         zeros = torch.zeros(batch_size, self.shared_hid)
         return utils.get_variable(zeros, self.use_cuda, requires_grad=False)

     def get_f(self, name):
         name = name.lower()
         if name == 'relu':
@@ -360,22 +364,21 @@ class ENASModel(BaseModel):
         elif name == 'sigmoid':
             f = torch.sigmoid
         return f

     @property
     def num_parameters(self):
         def size(p):
             return np.prod(p.size())

         return sum([size(param) for param in self.parameters()])

     def reset_parameters(self):
         init_range = 0.025
         # init_range = 0.025 if self.args.mode == 'train' else 0.04
         for param in self.parameters():
             param.data.uniform_(-init_range, init_range)
         self.decoder.bias.data.fill_(0)

     def predict(self, word_seq):
         """
@@ -1,12 +1,12 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
-import time
-from datetime import datetime
-from datetime import timedelta
-import math
 import numpy as np
+import time
 import torch
+import math
+from datetime import datetime, timedelta
+from torch.optim import Adam

 try:
     from tqdm.auto import tqdm
@@ -21,8 +21,6 @@ from ..core.utils import _move_dict_value_to_device
 from . import enas_utils as utils
 from ..core.utils import _build_args

-from torch.optim import Adam


 def _get_no_grad_ctx_mgr():
     """Returns the `torch.no_grad` context manager for PyTorch version >=
@@ -33,6 +31,7 @@ def _get_no_grad_ctx_mgr():
 class ENASTrainer(Trainer):
     """A class to wrap training code."""

     def __init__(self, train_data, model, controller, **kwargs):
         """Constructor for training algorithm.

         :param DataSet train_data: the training data
@@ -45,19 +44,19 @@ class ENASTrainer(Trainer):
         self.controller_step = 0
         self.shared_step = 0
         self.max_length = 35

         self.shared = model
         self.controller = controller

         self.shared_optim = Adam(
             self.shared.parameters(),
             lr=20.0,
             weight_decay=1e-7)

         self.controller_optim = Adam(
             self.controller.parameters(),
             lr=3.5e-4)

     def train(self, load_best_model=True):
         """
         :param bool load_best_model: only effective when dev_data was provided at initialization; if True, the trainer reloads the parameters that performed best on dev before returning
@@ -82,21 +81,22 @@ class ENASTrainer(Trainer):
         self.model = self.model.cuda()
         self._model_device = self.model.parameters().__next__().device
         self._mode(self.model, is_test=False)

         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
         start_time = time.time()
         print("training epochs started " + self.start_time, flush=True)

         try:
             self.callback_manager.on_train_begin()
             self._train()
             self.callback_manager.on_train_end()
         except (CallbackException, KeyboardInterrupt) as e:
             self.callback_manager.on_exception(e)

         if self.dev_data is not None:
-            print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
-                  self.tester._format_eval_results(self.best_dev_perf),)
+            print(
+                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
+                self.tester._format_eval_results(self.best_dev_perf), )
             results['best_eval'] = self.best_dev_perf
             results['best_epoch'] = self.best_dev_epoch
             results['best_step'] = self.best_dev_step
@@ -110,9 +110,9 @@ class ENASTrainer(Trainer):
         finally:
             pass

         results['seconds'] = round(time.time() - start_time, 2)

         return results

     def _train(self):
         if not self.use_tqdm:
             from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -126,21 +126,21 @@ class ENASTrainer(Trainer):
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
-            for epoch in range(1, self.n_epochs+1):
+            for epoch in range(1, self.n_epochs + 1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
                 if epoch == self.n_epochs + 1 - self.final_epochs:
                     print('Entering the final stage. (Only train the selected structure)')
                 # early stopping
                 self.callback_manager.on_epoch_begin()

                 # 1. Training the shared parameters omega of the child models
                 self.train_shared(pbar)

                 # 2. Training the controller parameters theta
                 if not last_stage:
                     self.train_controller()

                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                     (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                         and self.dev_data is not None:
@@ -149,16 +149,15 @@ class ENASTrainer(Trainer):
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                 total_steps) + \
                                self.tester._format_eval_results(eval_res)
                     pbar.write(eval_str)

                 # lr decay; early stopping
                 self.callback_manager.on_epoch_end()
             # =============== epochs end =================== #
             pbar.close()
         # ============ tqdm end ============== #

     def get_loss(self, inputs, targets, hidden, dags):
         """Computes the loss for the same batch for M models.

@@ -167,7 +166,7 @@ class ENASTrainer(Trainer):
         """
         if not isinstance(dags, list):
             dags = [dags]

         loss = 0
         for dag in dags:
             self.shared.setDAG(dag)
@@ -175,14 +174,14 @@ class ENASTrainer(Trainer):
             inputs['hidden'] = hidden
             result = self.shared(**inputs)
             output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out']

             self.callback_manager.on_loss_begin(targets, result)
             sample_loss = self._compute_loss(result, targets)
             loss += sample_loss

         assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`'
         return loss, hidden, extra_out

     def train_shared(self, pbar=None, max_step=None, dag=None):
         """Train the language model for 400 steps of minibatches of 64
         examples.
@@ -200,9 +199,9 @@ class ENASTrainer(Trainer):
         model = self.shared
         model.train()
         self.controller.eval()

         hidden = self.shared.init_hidden(self.batch_size)

         abs_max_grad = 0
         abs_max_hidden_norm = 0
         step = 0
@@ -211,15 +210,15 @@ class ENASTrainer(Trainer):
         train_idx = 0
         avg_loss = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                               prefetch=self.prefetch)

         for batch_x, batch_y in data_iterator:
             _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
             indices = data_iterator.get_batch_indices()
             # negative sampling; replace unknown; re-weight batch_y
             self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
             # prediction = self._data_forward(self.model, batch_x)

             dags = self.controller.sample(1)
             inputs, targets = batch_x, batch_y
             # self.callback_manager.on_loss_begin(batch_y, prediction)
@@ -228,18 +227,18 @@ class ENASTrainer(Trainer):
                                                     hidden,
                                                     dags)
             hidden.detach_()

             avg_loss += loss.item()

             # Is loss NaN or inf? requires_grad = False
             self.callback_manager.on_backward_begin(loss)
             self._grad_backward(loss)
             self.callback_manager.on_backward_end()

             self._update()
             self.callback_manager.on_step_end()

-            if (self.step+1) % self.print_every == 0:
+            if (self.step + 1) % self.print_every == 0:
                 if self.use_tqdm:
                     print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                     pbar.update(self.print_every)
@@ -255,30 +254,29 @@ class ENASTrainer(Trainer):
             self.shared_step += 1
             self.callback_manager.on_batch_end()
         # ================= mini-batch end ==================== #

     def get_reward(self, dag, entropies, hidden, valid_idx=0):
         """Computes the perplexity of a single sampled model on a minibatch of
         validation data.
         """
         if not isinstance(entropies, np.ndarray):
             entropies = entropies.data.cpu().numpy()

         data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                               prefetch=self.prefetch)

         for inputs, targets in data_iterator:
             valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
             valid_loss = utils.to_item(valid_loss.data)

         valid_ppl = math.exp(valid_loss)

         R = 80 / valid_ppl

         rewards = R + 1e-4 * entropies

         return rewards, hidden
def train_controller(self): | def train_controller(self): | ||||
"""Fixes the shared parameters and updates the controller parameters. | """Fixes the shared parameters and updates the controller parameters. | ||||
@@ -296,13 +294,13 @@ class ENASTrainer(Trainer):
        # Why can't we call shared.eval() here? Leads to loss
        # being uniformly zero for the controller.
        # self.shared.eval()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.batch_size)
        total_loss = 0
        valid_idx = 0
@@ -310,7 +308,7 @@ class ENASTrainer(Trainer):
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # No gradients should be backpropagated to the
@@ -320,40 +318,39 @@ class ENASTrainer(Trainer):
                                              np_entropies,
                                              hidden,
                                              valid_idx)
            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = 0.95
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
-           loss = -log_probs*utils.get_variable(adv,
-                                                'cuda' in self.device,
-                                                requires_grad=False)
+           loss = -log_probs * utils.get_variable(adv,
+                                                  'cuda' in self.device,
+                                                  requires_grad=False)

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % 50) == 0) and (step > 0):
                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1
            # prev_valid_idx = valid_idx
            # valid_idx = ((valid_idx + self.max_length) %
@@ -362,16 +359,16 @@ class ENASTrainer(Trainer):
            # # validation data, we reset the hidden states.
            # if prev_valid_idx > valid_idx:
            #     hidden = self.shared.init_hidden(self.batch_size)

    def derive(self, sample_num=10, valid_idx=0):
        """We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.batch_size)

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
@@ -379,5 +376,5 @@ class ENASTrainer(Trainer):
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        self.model.setDAG(best_dag)
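Stripped of the fastNLP plumbing, train_controller above is plain REINFORCE with a moving-average baseline: reward each sampled DAG by validation perplexity (R = 80 / ppl plus a small entropy bonus), subtract the baseline, and descend on -log_prob * advantage. A self-contained sketch of one such update, with hypothetical log-probs and rewards standing in for controller.sample() and get_reward() (not the fastNLP API):

import torch

# Stand-ins for controller.sample(with_details=True) and get_reward().
log_probs = torch.log(torch.rand(4, requires_grad=True))  # log-probs of sampled decisions
rewards = torch.tensor([0.8, 0.5, 0.9, 0.7])              # e.g. 80 / valid_ppl + 1e-4 * entropy

baseline = None
decay = 0.95
if baseline is None:
    baseline = rewards                  # first step: advantage is zero by construction
else:
    baseline = decay * baseline + (1 - decay) * rewards

adv = rewards - baseline                # detached from the reward path, as in the code above
loss = (-log_probs * adv).sum()         # or .mean(), per the comment in the diff
loss.backward()                         # gradients flow only through log_probs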
@@ -1,12 +1,9 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

-from __future__ import print_function

from collections import defaultdict
import collections

import numpy as np
import torch
from torch.autograd import Variable
@@ -1,11 +1,19 @@
+"""
+This module implements two sequence labeling models.
+"""
+__all__ = [
+    "SeqLabeling",
+    "AdvSeqLabel"
+]

import torch
+import torch.nn as nn

from .base_model import BaseModel
from ..modules import decoder, encoder
-from ..modules.decoder.CRF import allowed_transitions
+from ..modules.decoder.crf import allowed_transitions
from ..core.utils import seq_len_to_mask
from ..core.const import Const as C
-from torch import nn


class SeqLabeling(BaseModel):
@@ -27,7 +35,7 @@ class SeqLabeling(BaseModel):
        self.Embedding = encoder.embedding.Embedding(init_embed)
        self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size)
        self.Linear = nn.Linear(hidden_size, num_classes)
-       self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
+       self.Crf = decoder.crf.ConditionalRandomField(num_classes)
        self.mask = None

    def forward(self, words, seq_len, target):
@@ -133,9 +141,9 @@ class AdvSeqLabel(nn.Module):
        self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes)

        if id2words is None:
-           self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+           self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False)
        else:
-           self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
+           self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False,
                                                          allowed_transitions=allowed_transitions(id2words,
                                                                                                  encoding_type=encoding_type))
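For reference, the constrained branch above amounts to the following construction (hypothetical BIO tag vocabulary; import path is the renamed one from this diff):

from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions

id2words = {0: 'O', 1: 'B-PER', 2: 'I-PER'}  # hypothetical tag map
crf = ConditionalRandomField(len(id2words), include_start_end_trans=False,
                             allowed_transitions=allowed_transitions(id2words,
                                                                     encoding_type='bio'))
# Illegal jumps such as O -> I-PER receive a -10000 constraint score during decoding.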
@@ -1,3 +1,7 @@
+__all__ = [
+    "ESIM"
+]

import torch
import torch.nn as nn
@@ -8,7 +12,6 @@ from ..modules import encoder as Encoder
from ..modules import aggregator as Aggregator
from ..core.utils import seq_len_to_mask

my_inf = 10e12
@@ -26,7 +29,7 @@ class ESIM(BaseModel):
    :param int num_classes: number of labels, 3 by default
    :param numpy.array init_embedding: initial word embedding matrix of shape (vocab_size, embed_dim); defaults to None,
        i.e. the embeddings are randomly initialized
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, dropout=0.0, num_classes=3, init_embedding=None):
        super(ESIM, self).__init__()
@@ -35,35 +38,36 @@ class ESIM(BaseModel):
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.n_labels = num_classes

        self.drop = nn.Dropout(self.dropout)

        self.embedding = Encoder.Embedding(
            (self.vocab_size, self.embed_dim), dropout=self.dropout,
        )

        self.embedding_layer = nn.Linear(self.embed_dim, self.hidden_size)

        self.encoder = Encoder.LSTM(
            input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True,
            batch_first=True, bidirectional=True
        )

        self.bi_attention = Aggregator.BiAttention()
        self.mean_pooling = Aggregator.AvgPoolWithMask()
        self.max_pooling = Aggregator.MaxPoolWithMask()

        self.inference_layer = nn.Linear(self.hidden_size * 4, self.hidden_size)

        self.decoder = Encoder.LSTM(
            input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True,
            batch_first=True, bidirectional=True
        )

        self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)

    def forward(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
        """ Forward function

        :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token ids of the premise
        :param torch.Tensor words2: [B, hypothesis seq len(HL)] token ids of the hypothesis
        :param torch.LongTensor seq_len1: [B] lengths of the premises
@@ -71,10 +75,10 @@ class ESIM(BaseModel):
        :param torch.LongTensor target: [B] gold labels
        :return: dict prediction: [B, n_labels(N)] predicted scores
        """
        premise0 = self.embedding_layer(self.embedding(words1))
        hypothesis0 = self.embedding_layer(self.embedding(words2))

        if seq_len1 is not None:
            seq_len1 = seq_len_to_mask(seq_len1)
        else:
@@ -85,55 +89,55 @@ class ESIM(BaseModel):
        else:
            seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1))
            seq_len2 = (seq_len2.long()).to(device=hypothesis0.device)

        _BP, _PSL, _HP = premise0.size()
        _BH, _HSL, _HH = hypothesis0.size()
        _BPL, _PLL = seq_len1.size()
        _HPL, _HLL = seq_len2.size()

        assert _BP == _BH and _BPL == _HPL and _BP == _BPL
        assert _HP == _HH
        assert _PSL == _PLL and _HSL == _HLL

        B, PL, H = premise0.size()
        B, HL, H = hypothesis0.size()

        a0 = self.encoder(self.drop(premise0))  # a0: [B, PL, H * 2]
        b0 = self.encoder(self.drop(hypothesis0))  # b0: [B, HL, H * 2]

        a = torch.mean(a0.view(B, PL, -1, H), dim=2)  # a: [B, PL, H]
        b = torch.mean(b0.view(B, HL, -1, H), dim=2)  # b: [B, HL, H]

        ai, bi = self.bi_attention(a, b, seq_len1, seq_len2)

        ma = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 4 * H]
        mb = torch.cat((b, bi, b - bi, b * bi), dim=2)  # mb: [B, HL, 4 * H]

        f_ma = self.inference_layer(ma)
        f_mb = self.inference_layer(mb)

        vat = self.decoder(self.drop(f_ma))
        vbt = self.decoder(self.drop(f_mb))

        va = torch.mean(vat.view(B, PL, -1, H), dim=2)  # va: [B, PL, H]
        vb = torch.mean(vbt.view(B, HL, -1, H), dim=2)  # vb: [B, HL, H]

        va_ave = self.mean_pooling(va, seq_len1, dim=1)  # va_ave: [B, H]
        va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1)  # va_max: [B, H]
        vb_ave = self.mean_pooling(vb, seq_len2, dim=1)  # vb_ave: [B, H]
        vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1)  # vb_max: [B, H]

        v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1)  # v: [B, 4 * H]

        prediction = torch.tanh(self.output(v))  # prediction: [B, N]

        if target is not None:
            func = nn.CrossEntropyLoss()
            loss = func(prediction, target)
            return {Const.OUTPUT: prediction, Const.LOSS: loss}

        return {Const.OUTPUT: prediction}

    def predict(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
        """ Predict function

@@ -146,4 +150,3 @@ class ESIM(BaseModel):
        """
        prediction = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT]
        return {Const.OUTPUT: torch.argmax(prediction, dim=-1)}
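Note how forward reuses the seq_len1/seq_len2 arguments as token masks once seq_len_to_mask has expanded them. A quick sketch of that expansion (semantics as used above: 1 marks a real token, 0 padding; the exact dtype may vary by version):

import torch
from fastNLP.core.utils import seq_len_to_mask

seq_len = torch.tensor([3, 1])    # two sequences, of lengths 3 and 1
mask = seq_len_to_mask(seq_len)   # shape [2, 3]: padded out to the batch max length
# [[1, 1, 1],
#  [1, 0, 0]]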
@@ -1,17 +1,25 @@
-"""A PyTorch implementation of Star-Transformer.
"""
+A PyTorch implementation of Star-Transformer.
+"""
+__all__ = [
+    "StarTransEnc",
+    "STNLICls",
+    "STSeqCls",
+    "STSeqLabel",
+]
+import torch
+from torch import nn

from ..modules.encoder.star_transformer import StarTransformer
from ..core.utils import seq_len_to_mask
from ..modules.utils import get_embeddings
from ..core.const import Const

-import torch
-from torch import nn


class StarTransEnc(nn.Module):
    """
-   Alias: :class:`fastNLP.models.StarTransEnc` :class:`fastNLP.models.start_transformer.StarTransEnc`
+   Alias: :class:`fastNLP.models.StarTransEnc` :class:`fastNLP.models.star_transformer.StarTransEnc`

    Star-Transformer encoder with word embedding
@@ -28,6 +36,7 @@ class StarTransEnc(nn.Module):
    :param emb_dropout: dropout probability for the word embeddings.
    :param dropout: dropout probability for the rest of the model.
    """

    def __init__(self, init_embed,
                 hidden_size,
                 num_layers,
@@ -47,7 +56,7 @@ class StarTransEnc(nn.Module):
                                        head_dim=head_dim,
                                        dropout=dropout,
                                        max_len=max_len)

    def forward(self, x, mask):
        """
        :param FloatTensor data: [batch, length, hidden] the input sequence
@@ -72,7 +81,7 @@ class _Cls(nn.Module):
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x):
        h = self.fc(x)
        return h


@@ -83,20 +92,21 @@ class _NLICls(nn.Module):
        super(_NLICls, self).__init__()
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
-           nn.Linear(in_dim*4, hid_dim), #4
+           nn.Linear(in_dim * 4, hid_dim),  # 4
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x1, x2):
-       x = torch.cat([x1, x2, torch.abs(x1-x2), x1*x2], 1)
+       x = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], 1)
        h = self.fc(x)
        return h


class STSeqLabel(nn.Module):
    """
-   Alias: :class:`fastNLP.models.STSeqLabel` :class:`fastNLP.models.start_transformer.STSeqLabel`
+   Alias: :class:`fastNLP.models.STSeqLabel` :class:`fastNLP.models.star_transformer.STSeqLabel`

    Star-Transformer model for sequence labeling
@@ -112,6 +122,7 @@ class STSeqLabel(nn.Module):
    :param emb_dropout: dropout probability for the word embeddings. Default: 0.1
    :param dropout: dropout probability for the rest of the model. Default: 0.1
    """

    def __init__(self, init_embed, num_cls,
                 hidden_size=300,
                 num_layers=4,
@@ -120,7 +131,7 @@ class STSeqLabel(nn.Module):
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
-                dropout=0.1,):
+                dropout=0.1, ):
        super(STSeqLabel, self).__init__()
        self.enc = StarTransEnc(init_embed=init_embed,
                                hidden_size=hidden_size,
@@ -131,7 +142,7 @@ class STSeqLabel(nn.Module):
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)

    def forward(self, words, seq_len):
        """
@@ -142,9 +153,9 @@ class STSeqLabel(nn.Module):
        mask = seq_len_to_mask(seq_len)
        nodes, _ = self.enc(words, mask)
        output = self.cls(nodes)
-       output = output.transpose(1,2) # make hidden to be dim 1
-       return {Const.OUTPUT: output} # [bsz, n_cls, seq_len]
+       output = output.transpose(1, 2)  # make hidden to be dim 1
+       return {Const.OUTPUT: output}  # [bsz, n_cls, seq_len]

    def predict(self, words, seq_len):
        """
@@ -159,7 +170,7 @@ class STSeqLabel(nn.Module):
class STSeqCls(nn.Module):
    """
-   Alias: :class:`fastNLP.models.STSeqCls` :class:`fastNLP.models.start_transformer.STSeqCls`
+   Alias: :class:`fastNLP.models.STSeqCls` :class:`fastNLP.models.star_transformer.STSeqCls`

    Star-Transformer for classification tasks

@@ -175,7 +186,7 @@ class STSeqCls(nn.Module):
    :param emb_dropout: dropout probability for the word embeddings. Default: 0.1
    :param dropout: dropout probability for the rest of the model. Default: 0.1
    """

    def __init__(self, init_embed, num_cls,
                 hidden_size=300,
                 num_layers=4,
@@ -184,7 +195,7 @@ class STSeqCls(nn.Module):
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
-                dropout=0.1,):
+                dropout=0.1, ):
        super(STSeqCls, self).__init__()
        self.enc = StarTransEnc(init_embed=init_embed,
                                hidden_size=hidden_size,
@@ -195,7 +206,7 @@ class STSeqCls(nn.Module):
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)

    def forward(self, words, seq_len):
        """
@@ -206,9 +217,9 @@ class STSeqCls(nn.Module):
        mask = seq_len_to_mask(seq_len)
        nodes, relay = self.enc(words, mask)
        y = 0.5 * (relay + nodes.max(1)[0])
-       output = self.cls(y) # [bsz, n_cls]
+       output = self.cls(y)  # [bsz, n_cls]
        return {Const.OUTPUT: output}

    def predict(self, words, seq_len):
        """
@@ -223,7 +234,7 @@ class STSeqCls(nn.Module):
class STNLICls(nn.Module):
    """
-   Alias: :class:`fastNLP.models.STNLICls` :class:`fastNLP.models.start_transformer.STNLICls`
+   Alias: :class:`fastNLP.models.STNLICls` :class:`fastNLP.models.star_transformer.STNLICls`

    Star-Transformer for natural language inference (NLI)

@@ -239,7 +250,7 @@ class STNLICls(nn.Module):
    :param emb_dropout: dropout probability for the word embeddings. Default: 0.1
    :param dropout: dropout probability for the rest of the model. Default: 0.1
    """

    def __init__(self, init_embed, num_cls,
                 hidden_size=300,
                 num_layers=4,
@@ -248,7 +259,7 @@ class STNLICls(nn.Module):
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
-                dropout=0.1,):
+                dropout=0.1, ):
        super(STNLICls, self).__init__()
        self.enc = StarTransEnc(init_embed=init_embed,
                                hidden_size=hidden_size,
@@ -259,7 +270,7 @@ class STNLICls(nn.Module):
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size)

    def forward(self, words1, words2, seq_len1, seq_len2):
        """
@@ -271,14 +282,16 @@ class STNLICls(nn.Module):
        """
        mask1 = seq_len_to_mask(seq_len1)
        mask2 = seq_len_to_mask(seq_len2)

        def enc(seq, mask):
            nodes, relay = self.enc(seq, mask)
            return 0.5 * (relay + nodes.max(1)[0])

        y1 = enc(words1, mask1)
        y2 = enc(words2, mask2)
-       output = self.cls(y1, y2) # [bsz, n_cls]
+       output = self.cls(y1, y2)  # [bsz, n_cls]
        return {Const.OUTPUT: output}

    def predict(self, words1, words2, seq_len1, seq_len2):
        """
@@ -22,29 +22,35 @@
    +-----------------------+-----------------------+-----------------------+

"""
-from . import aggregator
-from . import decoder
-from . import encoder
-from .aggregator import *
-from .decoder import *
-from .dropout import TimestepDropout
-from .encoder import *
-from .utils import get_embeddings
__all__ = [
-   "LSTM",
-   "Embedding",
+   # "BertModel",
+   "ConvolutionCharEncoder",
+   "LSTMCharEncoder",
    "ConvMaxpool",
-   "BertModel",
+   "Embedding",
+   "LSTM",
+   "StarTransformer",
+   "TransformerEncoder",
+   "VarRNN",
+   "VarLSTM",
+   "VarGRU",
    "MaxPool",
    "MaxPoolWithMask",
    "AvgPool",
    "MultiHeadAttention",
-   "BiAttention",
    "MLP",
    "ConditionalRandomField",
    "viterbi_decode",
    "allowed_transitions",
]

+from . import aggregator
+from . import decoder
+from . import encoder
+from .aggregator import *
+from .decoder import *
+from .dropout import TimestepDropout
+from .encoder import *
+from .utils import get_embeddings
@@ -1,14 +1,14 @@
-from .pooling import MaxPool
-from .pooling import MaxPoolWithMask
-from .pooling import AvgPool
-from .pooling import AvgPoolWithMask
-from .attention import MultiHeadAttention, BiAttention
__all__ = [
    "MaxPool",
    "MaxPoolWithMask",
    "AvgPool",
    "MultiHeadAttention",
-   "BiAttention"
]

+from .pooling import MaxPool
+from .pooling import MaxPoolWithMask
+from .pooling import AvgPool
+from .pooling import AvgPoolWithMask
+from .attention import MultiHeadAttention
@@ -1,4 +1,7 @@
-__all__ =["MultiHeadAttention"]
+__all__ = [
+    "MultiHeadAttention"
+]

import math

import torch
@@ -15,6 +18,7 @@ class DotAttention(nn.Module):
    .. todo::
        Fill in the documentation.
    """

    def __init__(self, key_size, value_size, dropout=0):
        super(DotAttention, self).__init__()
        self.key_size = key_size
@@ -22,7 +26,7 @@ class DotAttention(nn.Module):
        self.scale = math.sqrt(key_size)
        self.drop = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, Q, K, V, mask_out=None):
        """
@@ -41,6 +45,8 @@ class DotAttention(nn.Module):
class MultiHeadAttention(nn.Module):
    """
+   Alias: :class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.aggregator.attention.MultiHeadAttention`

    :param input_size: int, the input dimension, which is also the output dimension.
    :param key_size: int, the dimension of each head.
@@ -48,13 +54,14 @@ class MultiHeadAttention(nn.Module):
    :param num_head: int, the number of heads.
    :param dropout: float.
    """

    def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.input_size = input_size
        self.key_size = key_size
        self.value_size = value_size
        self.num_head = num_head

        in_size = key_size * num_head
        self.q_in = nn.Linear(input_size, in_size)
        self.k_in = nn.Linear(input_size, in_size)
@@ -64,14 +71,14 @@ class MultiHeadAttention(nn.Module):
        self.out = nn.Linear(value_size * num_head, input_size)
        self.drop = TimestepDropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        sqrt = math.sqrt
        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
        nn.init.xavier_normal_(self.out.weight)

    def forward(self, Q, K, V, atte_mask_out=None):
        """
@@ -87,7 +94,7 @@ class MultiHeadAttention(nn.Module):
        q = self.q_in(Q).view(batch, sq, n_head, d_k)
        k = self.k_in(K).view(batch, sk, n_head, d_k)
        v = self.v_in(V).view(batch, sk, n_head, d_v)

        # transpose q, k and v to do batch attention
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
@@ -95,7 +102,7 @@ class MultiHeadAttention(nn.Module):
        if atte_mask_out is not None:
            atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)

        # concat all heads, do output linear
        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
        output = self.drop(self.out(atte))
@@ -104,6 +111,10 @@ class MultiHeadAttention(nn.Module):
class BiAttention(nn.Module):
    r"""Bi Attention module

+   .. todo::
+       The owner of this module should polish it further.

    Calculate Bi Attention matrix `e`

    .. math::
@@ -115,11 +126,11 @@ class BiAttention(nn.Module):
        \end{array}
    """

    def __init__(self):
        super(BiAttention, self).__init__()
        self.inf = 10e12

    def forward(self, in_x1, in_x2, x1_len, x2_len):
        """
        :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] features of the first sentence
@@ -130,36 +141,36 @@ class BiAttention(nn.Module):
            torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] the attended features of the second sentence
        """
        assert in_x1.size()[0] == in_x2.size()[0]
        assert in_x1.size()[2] == in_x2.size()[2]
        # The batch size and hidden size must be equal.
        assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
        # The seq len in in_x and x_len must be equal.
        assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]

        batch_size = in_x1.size()[0]
        x1_max_len = in_x1.size()[1]
        x2_max_len = in_x2.size()[1]

        in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]

        attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]

        a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
        a_mask = a_mask.view(batch_size, x1_max_len, -1)
        a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
        b_mask = x2_len.le(0.5).float() * -self.inf
        b_mask = b_mask.view(batch_size, -1, x2_max_len)
        b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]

        attention_a = F.softmax(attention_matrix + a_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
        attention_b = F.softmax(attention_matrix + b_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]

        out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
        attention_b_t = torch.transpose(attention_b, 1, 2)
        out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]

        return out_x1, out_x2
@@ -173,10 +184,10 @@ class SelfAttention(nn.Module):
    :param float drop: dropout probability, 0.5 by default
    :param str initial_method: parameter initialization method
    """

-   def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None,):
+   def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
        super(SelfAttention, self).__init__()

        self.attention_hops = attention_hops
        self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
        self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
@@ -185,7 +196,7 @@ class SelfAttention(nn.Module):
        self.drop = nn.Dropout(drop)
        self.tanh = nn.Tanh()
        initial_parameter(self, initial_method)

    def _penalization(self, attention):
        """
        compute the penalization term for attention module
@@ -199,7 +210,7 @@ class SelfAttention(nn.Module):
        mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)]
        ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
        return torch.sum(ret) / size[0]

    def forward(self, input, input_origin):
        """
        :param torch.Tensor input: [baz, senLen, h_dim] the matrix to attend over
@@ -209,15 +220,14 @@ class SelfAttention(nn.Module):
        """
        input = input.contiguous()
        size = input.size()  # [bsz, len, nhid]

        input_origin = input_origin.expand(self.attention_hops, -1, -1)  # [hops, baz, len]
        input_origin = input_origin.transpose(0, 1).contiguous()  # [baz, hops, len]

        y1 = self.tanh(self.ws1(self.drop(input)))  # [baz, len, dim] --> [bsz, len, attention-unit]
        attention = self.ws2(y1).transpose(1, 2).contiguous()
        # [bsz, len, attention-unit] --> [bsz, len, hop] --> [baz, hop, len]

        attention = attention + (-999999 * (input_origin == 0).float())  # remove the weight on padding token.
        attention = F.softmax(attention, 2)  # [baz, hop, len]
        return torch.bmm(attention, input), self._penalization(attention)  # output1 --> [baz, hop, nhid]
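A quick shape check for the MultiHeadAttention defined above (hypothetical sizes; self-attention, so Q = K = V; the final linear keeps input_size as the last dimension):

import torch
from fastNLP.modules.aggregator.attention import MultiHeadAttention

mha = MultiHeadAttention(input_size=64, key_size=16, value_size=16, num_head=4)
x = torch.randn(2, 10, 64)   # [batch, seq_len, input_size]
out = mha(x, x, x)           # no mask: atte_mask_out defaults to None
print(out.shape)             # torch.Size([2, 10, 64])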
@@ -1,4 +1,8 @@
-__all__ = ["MaxPool", "MaxPoolWithMask", "AvgPool"]
+__all__ = [
+    "MaxPool",
+    "MaxPoolWithMask",
+    "AvgPool"
+]

import torch
import torch.nn as nn
@@ -16,6 +20,7 @@ class MaxPool(nn.Module):
    :param kernel_size: window size for max pooling; defaults to the tensor's last k dimensions, where k = dimension
    :param ceil_mode:
    """

    def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
        super(MaxPool, self).__init__()
@@ -125,7 +130,7 @@ class AvgPoolWithMask(nn.Module):
    Given an input of shape [batch_size, max_len, hidden_size], performs avg pooling along max_len; the output has shape
    [batch_size, hidden_size], and only positions where the mask is 1 are taken into account.
    """

    def __init__(self):
        super(AvgPoolWithMask, self).__init__()
        self.inf = 10e12
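A tiny sketch of the masked-pooling semantics (hypothetical values; the (values, indices) unpacking mirrors how ESIM's forward uses max_pooling earlier in this diff):

import torch
from fastNLP.modules.aggregator.pooling import MaxPoolWithMask

pool = MaxPoolWithMask()
x = torch.tensor([[[1.0], [5.0], [3.0]]])   # [batch=1, max_len=3, hidden=1]
mask = torch.tensor([[1, 1, 0]])            # the last position is padding
values, idx = pool(x, mask, dim=1)          # masked positions get -inf before the max
print(values)                               # tensor([[5.]]): the padded 3.0 is ignored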
@@ -1,11 +1,11 @@
-from .CRF import ConditionalRandomField
-from .MLP import MLP
-from .utils import viterbi_decode
-from .CRF import allowed_transitions
__all__ = [
    "MLP",
    "ConditionalRandomField",
    "viterbi_decode",
    "allowed_transitions"
]

+from .crf import ConditionalRandomField
+from .mlp import MLP
+from .utils import viterbi_decode
+from .crf import allowed_transitions
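The CRF.py -> crf.py and MLP.py -> mlp.py renames above only change the private module paths; the package re-exports keep working. For callers (import paths as given in this diff):

# Unchanged: the names are re-exported by the package __init__.
from fastNLP.modules.decoder import ConditionalRandomField, MLP, viterbi_decode, allowed_transitions

# Changed: direct module imports must use the new lower-case file names.
from fastNLP.modules.decoder.crf import allowed_transitions   # was: fastNLP.modules.decoder.CRF
from fastNLP.modules.decoder.mlp import MLP                   # was: fastNLP.modules.decoder.MLP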
@@ -1,3 +1,8 @@
+__all__ = [
+    "ConditionalRandomField",
+    "allowed_transitions"
+]

import torch
from torch import nn
@@ -6,7 +11,7 @@ from ..utils import initial_parameter


def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
    """
-   Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.CRF.allowed_transitions`
+   Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`

    Given a mapping from id to label, return the list of all allowed (from_tag_id, to_tag_id) transitions.

@@ -15,8 +20,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
    :param str encoding_type: supports "bio", "bmes", "bmeso".
    :param bool include_start_end: whether to include start and end transitions. For example in BIO, b/o may appear at
        the beginning but i may not;
-       if True, the result includes (start_idx, b_idx) and (start_idx, o_idx) but not (start_idx, i_idx);
-       start_idx=len(id2label), end_idx=len(id2label)+1. If False, the result contains nothing related to start or end.
+       if True, the result includes (start_idx, b_idx) and (start_idx, o_idx) but not (start_idx, i_idx); start_idx=len(id2label), end_idx=len(id2label)+1. If False, the result contains nothing related to start or end.
    :return: List[Tuple(int, int)], each inner Tuple being an allowed (from_tag_id, to_tag_id) transition.
    """
    num_tags = len(id2target)
@@ -27,6 +31,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
    id_label_lst = list(id2target.items())
    if include_start_end:
        id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]

    def split_tag_label(from_label):
        from_label = from_label.lower()
        if from_label in ['start', 'end']:
@@ -36,7 +41,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
            from_tag = from_label[:1]
            from_label = from_label[2:]
        return from_tag, from_label

    for from_id, from_label in id_label_lst:
        if from_label in ['<pad>', '<unk>']:
            continue
@@ -60,7 +65,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
    :param str to_label: a label such as "PER" or "LOC"
    :return: bool, whether the transition is allowed
    """
-   if to_tag=='start' or from_tag=='end':
+   if to_tag == 'start' or from_tag == 'end':
        return False
    encoding_type = encoding_type.lower()
    if encoding_type == 'bio':
@@ -83,12 +88,12 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
        if from_tag == 'start':
            return to_tag in ('b', 'o')
        elif from_tag in ['b', 'i']:
-           return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label])
+           return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label])
        elif from_tag == 'o':
            return to_tag in ['end', 'b', 'o']
        else:
            raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))
    elif encoding_type == 'bmes':
        """
        The first row is to_tag and the first column is from_tag; y: always allowed, -: allowed only when the labels
        match, n: not allowed.
@@ -111,9 +116,9 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
        if from_tag == 'start':
            return to_tag in ['b', 's']
        elif from_tag == 'b':
-           return to_tag in ['m', 'e'] and from_label==to_label
+           return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag == 'm':
-           return to_tag in ['m', 'e'] and from_label==to_label
+           return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag in ['e', 's']:
            return to_tag in ['b', 's', 'end']
        else:
@@ -122,21 +127,21 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
        if from_tag == 'start':
            return to_tag in ['b', 's', 'o']
        elif from_tag == 'b':
-           return to_tag in ['m', 'e'] and from_label==to_label
+           return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag == 'm':
-           return to_tag in ['m', 'e'] and from_label==to_label
+           return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag in ['e', 's', 'o']:
            return to_tag in ['b', 's', 'end', 'o']
        else:
            raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))
    else:
        raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type))
class ConditionalRandomField(nn.Module):
    """
-   Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.CRF.ConditionalRandomField`
+   Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.crf.ConditionalRandomField`

    Conditional random field.
    Provides two methods, forward() and viterbi_decode(), used for training and inference respectively.

@@ -148,30 +153,31 @@ class ConditionalRandomField(nn.Module):
        can be obtained from the allowed_transitions() function; if None, all transitions are legal
    :param str initial_method: initialization method; see initial_parameter
    """

    def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None,
                 initial_method=None):
        super(ConditionalRandomField, self).__init__()

        self.include_start_end_trans = include_start_end_trans
        self.num_tags = num_tags

        # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
        self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
        if self.include_start_end_trans:
            self.start_scores = nn.Parameter(torch.randn(num_tags))
            self.end_scores = nn.Parameter(torch.randn(num_tags))

        if allowed_transitions is None:
            constrain = torch.zeros(num_tags + 2, num_tags + 2)
        else:
-           constrain = torch.full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float)
+           constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
            for from_tag_id, to_tag_id in allowed_transitions:
                constrain[from_tag_id, to_tag_id] = 0
        self._constrain = nn.Parameter(constrain, requires_grad=False)

        initial_parameter(self, initial_method)

    def _normalizer_likelihood(self, logits, mask):
        """Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.
@@ -184,21 +190,21 @@ class ConditionalRandomField(nn.Module):
        alpha = logits[0]
        if self.include_start_end_trans:
            alpha = alpha + self.start_scores.view(1, -1)

        flip_mask = mask.eq(0)

        for i in range(1, seq_len):
            emit_score = logits[i].view(batch_size, 1, n_tags)
            trans_score = self.trans_m.view(1, n_tags, n_tags)
            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
            alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                    alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)

        if self.include_start_end_trans:
            alpha = alpha + self.end_scores.view(1, -1)

        return torch.logsumexp(alpha, 1)

    def _gold_score(self, logits, tags, mask):
        """
        Compute the score for the gold path.
@@ -210,15 +216,15 @@ class ConditionalRandomField(nn.Module):
        seq_len, batch_size, _ = logits.size()
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

        # trans_score [L-1, B]
        mask = mask.byte()
        flip_mask = mask.eq(0)
-       trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
+       trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
        # emit_score [L, B]
-       emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0)
+       emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
        # score [L-1, B]
-       score = trans_score + emit_score[:seq_len-1, :]
+       score = trans_score + emit_score[:seq_len - 1, :]
        score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
        if self.include_start_end_trans:
            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
@@ -227,24 +233,24 @@ class ConditionalRandomField(nn.Module):
            score = score + st_scores + ed_scores
        # return [B,]
        return score

    def forward(self, feats, tags, mask):
        """
        Computes the CRF forward loss; returns a FloatTensor of shape (batch_size,), on which mean() may be needed to get the final loss.

-       :param torch.FloatTensor feats:batch_size x max_len x num_tags, the feature matrix.
+       :param torch.FloatTensor feats: batch_size x max_len x num_tags, the feature matrix.
        :param torch.LongTensor tags: batch_size x max_len, the tag matrix.
        :param torch.ByteTensor mask: batch_size x max_len, positions equal to 0 are treated as padding.
-       :return:torch.FloatTensor, (batch_size,)
+       :return: torch.FloatTensor, (batch_size,)
        """
        feats = feats.transpose(0, 1)
        tags = tags.transpose(0, 1).long()
        mask = mask.transpose(0, 1).float()

        all_path_score = self._normalizer_likelihood(feats, mask)
        gold_path_score = self._gold_score(feats, tags, mask)

        return all_path_score - gold_path_score
def viterbi_decode(self, logits, mask, unpad=False): | def viterbi_decode(self, logits, mask, unpad=False): | ||||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | ||||
@@ -259,9 +265,9 @@ class ConditionalRandomField(nn.Module): | |||||
""" | """ | ||||
batch_size, seq_len, n_tags = logits.size() | batch_size, seq_len, n_tags = logits.size() | ||||
logits = logits.transpose(0, 1).data # L, B, H | |||||
mask = mask.transpose(0, 1).data.byte() # L, B | |||||
logits = logits.transpose(0, 1).data # L, B, H | |||||
mask = mask.transpose(0, 1).data.byte() # L, B | |||||
# dp | # dp | ||||
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | ||||
vscore = logits[0] | vscore = logits[0] | ||||
@@ -269,8 +275,8 @@ class ConditionalRandomField(nn.Module): | |||||
transitions[:n_tags, :n_tags] += self.trans_m.data | transitions[:n_tags, :n_tags] += self.trans_m.data | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
transitions[n_tags, :n_tags] += self.start_scores.data | transitions[n_tags, :n_tags] += self.start_scores.data | ||||
transitions[:n_tags, n_tags+1] += self.end_scores.data | |||||
transitions[:n_tags, n_tags + 1] += self.end_scores.data | |||||
vscore += transitions[n_tags, :n_tags] | vscore += transitions[n_tags, :n_tags] | ||||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | ||||
for i in range(1, seq_len): | for i in range(1, seq_len): | ||||
@@ -280,30 +286,29 @@ class ConditionalRandomField(nn.Module): | |||||
best_score, best_dst = score.max(1) | best_score, best_dst = score.max(1) | ||||
vpath[i] = best_dst | vpath[i] = best_dst | ||||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | ||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||||
vscore += transitions[:n_tags, n_tags + 1].view(1, -1) | |||||
# backtrace | # backtrace | ||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | ||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | ||||
lens = (mask.long().sum(0) - 1) | lens = (mask.long().sum(0) - 1) | ||||
# idxes [L, B], batched idx from seq_len-1 to 0 | # idxes [L, B], batched idx from seq_len-1 to 0 | ||||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | |||||
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) | ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) | ||||
ans_score, last_tags = vscore.max(1) | ans_score, last_tags = vscore.max(1) | ||||
ans[idxes[0], batch_idx] = last_tags | ans[idxes[0], batch_idx] = last_tags | ||||
for i in range(seq_len - 1): | for i in range(seq_len - 1): | ||||
last_tags = vpath[idxes[i], batch_idx, last_tags] | last_tags = vpath[idxes[i], batch_idx, last_tags] | ||||
ans[idxes[i+1], batch_idx] = last_tags | |||||
ans[idxes[i + 1], batch_idx] = last_tags | |||||
ans = ans.transpose(0, 1) | ans = ans.transpose(0, 1) | ||||
if unpad: | if unpad: | ||||
paths = [] | paths = [] | ||||
for idx, seq_len in enumerate(lens): | for idx, seq_len in enumerate(lens): | ||||
paths.append(ans[idx, :seq_len+1].tolist()) | |||||
paths.append(ans[idx, :seq_len + 1].tolist()) | |||||
else: | else: | ||||
paths = ans | paths = ans | ||||
return paths, ans_score | return paths, ans_score | ||||
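Taken together, forward returns the per-sequence negative log-likelihood (the log-partition score minus the gold-path score, per the two methods above), and viterbi_decode recovers the best-scoring path. A minimal usage sketch, assuming a constructor that takes num_tags as its first argument (shapes follow the docstrings above):

>>> import torch
>>> from fastNLP.modules.decoder.crf import ConditionalRandomField
>>> crf = ConditionalRandomField(num_tags=5)
>>> feats = torch.randn(2, 7, 5)           # batch_size x max_len x num_tags
>>> tags = torch.randint(5, (2, 7))        # batch_size x max_len
>>> mask = torch.ones(2, 7).byte()         # 0 marks padding positions
>>> loss = crf(feats, tags, mask).mean()   # scalar loss, ready for backward()
>>> paths, scores = crf.viterbi_decode(feats, mask, unpad=True)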
@@ -1,3 +1,7 @@ | |||||
__all__ = [ | |||||
"MLP" | |||||
] | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -6,17 +10,16 @@ from ..utils import initial_parameter | |||||
class MLP(nn.Module): | class MLP(nn.Module): | ||||
""" | """ | ||||
Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.MLP.MLP` | |||||
Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.mlp.MLP` | |||||
Multilayer perceptron | Multilayer perceptron | ||||
:param list size_layer: a list of ints specifying the layers of the MLP; each entry is that layer's hidden size. The MLP has len(size_layer) - 1 layers. | |||||
:param str or list activation: | |||||
a string, a function, or a list of strings and functions specifying the activation of each hidden layer; supported strings are relu, tanh and sigmoid. Default: relu | |||||
:param str or function output_activation : a string or function specifying the activation of the output layer. Default: None, meaning the output layer has no activation | |||||
:param List[int] size_layer: a list of ints specifying the layers of the MLP; each entry is that layer's hidden size. The MLP has len(size_layer) - 1 layers. | |||||
:param Union[str,func,List[str]] activation: a string or a list of strings/functions specifying the activation of each hidden layer; supported strings are relu, tanh and sigmoid. Default: relu | |||||
:param Union[str,func] output_activation: a string or function specifying the activation of the output layer. Default: None, meaning the output layer has no activation | |||||
:param str initial_method: parameter initialization method | :param str initial_method: parameter initialization method | ||||
:param float dropout: dropout probability. Default: 0 | :param float dropout: dropout probability. Default: 0 | ||||
.. note:: | .. note:: | ||||
The hidden-layer activations are specified via activation: either a single str/function or a list of str/function may be passed in. | The hidden-layer activations are specified via activation: either a single str/function or a list of str/function may be passed in. | ||||
If a single str/function is passed, every hidden layer uses the activation it specifies; | If a single str/function is passed, every hidden layer uses the activation it specifies; | ||||
@@ -35,10 +38,8 @@ class MLP(nn.Module): | |||||
>>> y = net(x) | >>> y = net(x) | ||||
>>> print(x) | >>> print(x) | ||||
>>> print(y) | >>> print(y) | ||||
>>> | |||||
""" | """ | ||||
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | ||||
super(MLP, self).__init__() | super(MLP, self).__init__() | ||||
self.hiddens = nn.ModuleList() | self.hiddens = nn.ModuleList() | ||||
@@ -46,12 +47,12 @@ class MLP(nn.Module): | |||||
self.output_activation = output_activation | self.output_activation = output_activation | ||||
for i in range(1, len(size_layer)): | for i in range(1, len(size_layer)): | ||||
if i + 1 == len(size_layer): | if i + 1 == len(size_layer): | ||||
self.output = nn.Linear(size_layer[i-1], size_layer[i]) | |||||
self.output = nn.Linear(size_layer[i - 1], size_layer[i]) | |||||
else: | else: | ||||
self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i])) | |||||
self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i])) | |||||
self.dropout = nn.Dropout(p=dropout) | self.dropout = nn.Dropout(p=dropout) | ||||
actives = { | actives = { | ||||
'relu': nn.ReLU(), | 'relu': nn.ReLU(), | ||||
'tanh': nn.Tanh(), | 'tanh': nn.Tanh(), | ||||
@@ -80,7 +81,7 @@ class MLP(nn.Module): | |||||
else: | else: | ||||
raise ValueError("should set activation correctly: {}".format(activation)) | raise ValueError("should set activation correctly: {}".format(activation)) | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | """ | ||||
:param torch.Tensor x: the input accepted by the MLP | :param torch.Tensor x: the input accepted by the MLP | ||||
@@ -93,16 +94,3 @@ class MLP(nn.Module): | |||||
x = self.output_activation(x) | x = self.output_activation(x) | ||||
x = self.dropout(x) | x = self.dropout(x) | ||||
return x | return x | ||||
if __name__ == '__main__': | |||||
net1 = MLP([5, 10, 5]) | |||||
net2 = MLP([5, 10, 5], 'tanh') | |||||
net3 = MLP([5, 6, 7, 8, 5], 'tanh') | |||||
net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh') | |||||
net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh') | |||||
for net in [net1, net2, net3, net4, net5]: | |||||
x = torch.randn(5, 5) | |||||
y = net(x) | |||||
print(x) | |||||
print(y) |
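The __main__ demo removed above still shows the intended usage; restated as a doctest-style sketch against the constructor shown earlier:

>>> import torch
>>> from fastNLP.modules.decoder.mlp import MLP
>>> net = MLP([5, 10, 5], activation='tanh', output_activation='tanh')
>>> x = torch.randn(4, 5)   # last dim must equal size_layer[0]
>>> y = net(x)              # shape (4, 5), i.e. size_layer[-1]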
@@ -1,10 +1,12 @@ | |||||
__all__ = ["viterbi_decode"] | |||||
__all__ = [ | |||||
"viterbi_decode" | |||||
] | |||||
import torch | import torch | ||||
def viterbi_decode(logits, transitions, mask=None, unpad=False): | def viterbi_decode(logits, transitions, mask=None, unpad=False): | ||||
""" | |||||
Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode | |||||
r""" | |||||
Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode` | |||||
Given a feature matrix and a transition score matrix, compute the best path and its score | Given a feature matrix and a transition score matrix, compute the best path and its score | ||||
@@ -20,18 +22,19 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): | |||||
""" | """ | ||||
batch_size, seq_len, n_tags = logits.size() | batch_size, seq_len, n_tags = logits.size() | ||||
assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \ | |||||
"compatible." | |||||
assert n_tags == transitions.size(0) and n_tags == transitions.size( | |||||
1), "The shapes of transitions and feats are not " \ | |||||
"compatible." | |||||
logits = logits.transpose(0, 1).data # L, B, H | logits = logits.transpose(0, 1).data # L, B, H | ||||
if mask is not None: | if mask is not None: | ||||
mask = mask.transpose(0, 1).data.byte() # L, B | mask = mask.transpose(0, 1).data.byte() # L, B | ||||
else: | else: | ||||
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8) | mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8) | ||||
# dp | # dp | ||||
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | ||||
vscore = logits[0] | vscore = logits[0] | ||||
trans_score = transitions.view(1, n_tags, n_tags).data | trans_score = transitions.view(1, n_tags, n_tags).data | ||||
for i in range(1, seq_len): | for i in range(1, seq_len): | ||||
prev_score = vscore.view(batch_size, n_tags, 1) | prev_score = vscore.view(batch_size, n_tags, 1) | ||||
@@ -41,14 +44,14 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): | |||||
vpath[i] = best_dst | vpath[i] = best_dst | ||||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | ||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | vscore.masked_fill(mask[i].view(batch_size, 1), 0) | ||||
# backtrace | # backtrace | ||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | ||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | ||||
lens = (mask.long().sum(0) - 1) | lens = (mask.long().sum(0) - 1) | ||||
# idxes [L, B], batched idx from seq_len-1 to 0 | # idxes [L, B], batched idx from seq_len-1 to 0 | ||||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | ||||
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) | ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) | ||||
ans_score, last_tags = vscore.max(1) | ans_score, last_tags = vscore.max(1) | ||||
ans[idxes[0], batch_idx] = last_tags | ans[idxes[0], batch_idx] = last_tags | ||||
@@ -62,4 +65,4 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): | |||||
paths.append(ans[idx, :seq_len + 1].tolist()) | paths.append(ans[idx, :seq_len + 1].tolist()) | ||||
else: | else: | ||||
paths = ans | paths = ans | ||||
return paths, ans_score | |||||
return paths, ans_score |
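A short usage sketch of the standalone function above (shapes per its signature; illustrative only):

>>> import torch
>>> from fastNLP.modules.decoder.utils import viterbi_decode
>>> logits = torch.randn(2, 6, 4)      # batch_size x max_len x num_tags
>>> transitions = torch.randn(4, 4)    # num_tags x num_tags
>>> mask = torch.ones(2, 6).byte()     # 0 marks padding positions
>>> paths, scores = viterbi_decode(logits, transitions, mask, unpad=True)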
@@ -1,6 +1,8 @@ | |||||
import torch | |||||
__all__ = [] | __all__ = [] | ||||
import torch | |||||
class TimestepDropout(torch.nn.Dropout): | class TimestepDropout(torch.nn.Dropout): | ||||
""" | """ | ||||
Alias: :class:`fastNLP.modules.TimestepDropout` | Alias: :class:`fastNLP.modules.TimestepDropout` | ||||
@@ -8,7 +10,7 @@ class TimestepDropout(torch.nn.Dropout): | |||||
Accepts input of shape ``[batch_size, num_timesteps, embedding_dim]`` and uses a single mask (of shape ``(batch_size, embedding_dim)``) | Accepts input of shape ``[batch_size, num_timesteps, embedding_dim]`` and uses a single mask (of shape ``(batch_size, embedding_dim)``) | ||||
to apply dropout at every timestep. | to apply dropout at every timestep. | ||||
""" | """ | ||||
def forward(self, x): | def forward(self, x): | ||||
dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | ||||
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | ||||
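The key behavior is that one dropout mask of shape (batch_size, embedding_dim) is sampled and broadcast over the timestep dimension. A sketch, assuming the remainder of forward (cut off by this hunk) multiplies the input by the mask and returns it:

>>> import torch
>>> from fastNLP.modules import TimestepDropout
>>> drop = TimestepDropout(p=0.5)
>>> _ = drop.train()                       # dropout is only active in training mode
>>> x = torch.ones(2, 4, 8)                # batch_size, num_timesteps, embedding_dim
>>> out = drop(x)
>>> bool((out[:, 0] == out[:, 1]).all())   # the same mask is reused at every timestep
True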
@@ -1,11 +1,28 @@ | |||||
from .conv_maxpool import ConvMaxpool | |||||
from .embedding import Embedding | |||||
from .lstm import LSTM | |||||
from .bert import BertModel | |||||
__all__ = [ | __all__ = [ | ||||
"LSTM", | |||||
"Embedding", | |||||
# "BertModel", | |||||
"ConvolutionCharEncoder", | |||||
"LSTMCharEncoder", | |||||
"ConvMaxpool", | "ConvMaxpool", | ||||
"BertModel" | |||||
"Embedding", | |||||
"LSTM", | |||||
"StarTransformer", | |||||
"TransformerEncoder", | |||||
"VarRNN", | |||||
"VarLSTM", | |||||
"VarGRU" | |||||
] | ] | ||||
from .bert import BertModel | |||||
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | |||||
from .conv_maxpool import ConvMaxpool | |||||
from .embedding import Embedding | |||||
from .lstm import LSTM | |||||
from .star_transformer import StarTransformer | |||||
from .transformer import TransformerEncoder | |||||
from .variational_rnn import VarRNN, VarLSTM, VarGRU |
@@ -1,5 +1,9 @@ | |||||
__all__ = [ | |||||
"ConvolutionCharEncoder", | |||||
"LSTMCharEncoder" | |||||
] | |||||
import torch | import torch | ||||
from torch import nn | |||||
import torch.nn as nn | |||||
from ..utils import initial_parameter | from ..utils import initial_parameter | ||||
@@ -10,20 +14,22 @@ class ConvolutionCharEncoder(nn.Module): | |||||
Alias: :class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.ConvolutionCharEncoder` | Alias: :class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.ConvolutionCharEncoder` | ||||
A char-level convolutional encoder. | A char-level convolutional encoder. | ||||
:param int char_emb_size: dimension of the char-level embedding. Default: 50 | :param int char_emb_size: dimension of the char-level embedding. Default: 50 | ||||
Example: with 26 characters, each embedded as a 50-dimensional vector, the input vector dimension is 50. | |||||
:Example: with 26 characters, each embedded as a 50-dimensional vector, the input vector dimension is 50. | |||||
:param tuple feature_maps: a tuple of ints; its length is the number of char-level convolutions, and the `i`-th int is the number of filters of the `i`-th convolution. | :param tuple feature_maps: a tuple of ints; its length is the number of char-level convolutions, and the `i`-th int is the number of filters of the `i`-th convolution. | ||||
:param tuple kernels: a tuple of ints; its length is the number of char-level convolutions, and the `i`-th int is the kernel size of the `i`-th convolution. | :param tuple kernels: a tuple of ints; its length is the number of char-level convolutions, and the `i`-th int is the kernel size of the `i`-th convolution. | ||||
:param initial_method: parameter initialization method. Default: `xavier normal` | :param initial_method: parameter initialization method. Default: `xavier normal` | ||||
""" | """ | ||||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None): | def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None): | ||||
super(ConvolutionCharEncoder, self).__init__() | super(ConvolutionCharEncoder, self).__init__() | ||||
self.convs = nn.ModuleList([ | self.convs = nn.ModuleList([ | ||||
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | ||||
for i in range(len(kernels))]) | for i in range(len(kernels))]) | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | """ | ||||
:param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` the input character embeddings | :param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` the input character embeddings | ||||
@@ -34,7 +40,7 @@ class ConvolutionCharEncoder(nn.Module): | |||||
x = x.transpose(2, 3) | x = x.transpose(2, 3) | ||||
# [batch_size*sent_length, channel, height, width] | # [batch_size*sent_length, channel, height, width] | ||||
return self._convolute(x).unsqueeze(2) | return self._convolute(x).unsqueeze(2) | ||||
def _convolute(self, x): | def _convolute(self, x): | ||||
feats = [] | feats = [] | ||||
for conv in self.convs: | for conv in self.convs: | ||||
@@ -50,7 +56,14 @@ class ConvolutionCharEncoder(nn.Module): | |||||
class LSTMCharEncoder(nn.Module): | class LSTMCharEncoder(nn.Module): | ||||
"""char级别基于LSTM的encoder.""" | |||||
""" | |||||
Alias: :class:`fastNLP.modules.LSTMCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.LSTMCharEncoder` | |||||
A char-level LSTM-based encoder. | |||||
""" | |||||
def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): | def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): | ||||
""" | """ | ||||
:param int char_emb_size: dimension of the char-level embedding. Default: 50 | :param int char_emb_size: dimension of the char-level embedding. Default: 50 | ||||
@@ -60,14 +73,14 @@ class LSTMCharEncoder(nn.Module): | |||||
""" | """ | ||||
super(LSTMCharEncoder, self).__init__() | super(LSTMCharEncoder, self).__init__() | ||||
self.hidden_size = char_emb_size if hidden_size is None else hidden_size | self.hidden_size = char_emb_size if hidden_size is None else hidden_size | ||||
self.lstm = nn.LSTM(input_size=char_emb_size, | self.lstm = nn.LSTM(input_size=char_emb_size, | ||||
hidden_size=self.hidden_size, | hidden_size=self.hidden_size, | ||||
num_layers=1, | num_layers=1, | ||||
bias=True, | bias=True, | ||||
batch_first=True) | batch_first=True) | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | """ | ||||
:param torch.Tensor x: ``[n_batch*n_word, word_length, char_emb_size]`` the input character embeddings | :param torch.Tensor x: ``[n_batch*n_word, word_length, char_emb_size]`` the input character embeddings | ||||
@@ -78,6 +91,6 @@ class LSTMCharEncoder(nn.Module): | |||||
h0 = nn.init.orthogonal_(h0) | h0 = nn.init.orthogonal_(h0) | ||||
c0 = torch.empty(1, batch_size, self.hidden_size) | c0 = torch.empty(1, batch_size, self.hidden_size) | ||||
c0 = nn.init.orthogonal_(c0) | c0 = nn.init.orthogonal_(c0) | ||||
_, hidden = self.lstm(x, (h0, c0)) | _, hidden = self.lstm(x, (h0, c0)) | ||||
return hidden[0].squeeze().unsqueeze(2) | return hidden[0].squeeze().unsqueeze(2) |
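Both encoders consume per-character embeddings and emit one vector per word. A hedged sketch with the default hyperparameters (output shapes follow the shape comments above, not re-verified here):

>>> import torch
>>> from fastNLP.modules.encoder.char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
>>> conv_enc = ConvolutionCharEncoder()           # feature_maps=(40, 30, 30), kernels=(3, 4, 5)
>>> lstm_enc = LSTMCharEncoder(char_emb_size=50)
>>> chars = torch.randn(32, 10, 50)    # batch_size*sent_length, word_length, char_emb_size
>>> conv_out = conv_enc(chars)         # concatenated, max-pooled conv features per word
>>> lstm_out = lstm_enc(chars)         # final LSTM hidden state per word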
@@ -1,6 +1,6 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
__all__ = [ | |||||
"ConvMaxpool" | |||||
] | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
@@ -27,22 +27,24 @@ class ConvMaxpool(nn.Module): | |||||
:param str activation: the activation applied to the convolution outputs before max-pooling; supports relu, sigmoid and tanh | :param str activation: the activation applied to the convolution outputs before max-pooling; supports relu, sigmoid and tanh | ||||
:param str initial_method: parameter initialization method. | :param str initial_method: parameter initialization method. | ||||
""" | """ | ||||
def __init__(self, in_channels, out_channels, kernel_sizes, | def __init__(self, in_channels, out_channels, kernel_sizes, | ||||
stride=1, padding=0, dilation=1, | stride=1, padding=0, dilation=1, | ||||
groups=1, bias=True, activation="relu", initial_method=None): | groups=1, bias=True, activation="relu", initial_method=None): | ||||
super(ConvMaxpool, self).__init__() | super(ConvMaxpool, self).__init__() | ||||
# convolution | # convolution | ||||
if isinstance(kernel_sizes, (list, tuple, int)): | if isinstance(kernel_sizes, (list, tuple, int)): | ||||
if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | ||||
out_channels = [out_channels] | out_channels = [out_channels] | ||||
kernel_sizes = [kernel_sizes] | kernel_sizes = [kernel_sizes] | ||||
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): | elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): | ||||
assert len(out_channels)==len(kernel_sizes), "The number of out_channels should be equal to the number" \ | |||||
" of kernel_sizes." | |||||
assert len(out_channels) == len( | |||||
kernel_sizes), "The number of out_channels should be equal to the number" \ | |||||
" of kernel_sizes." | |||||
else: | else: | ||||
raise ValueError("The type of out_channels and kernel_sizes should be the same.") | raise ValueError("The type of out_channels and kernel_sizes should be the same.") | ||||
self.convs = nn.ModuleList([nn.Conv1d( | self.convs = nn.ModuleList([nn.Conv1d( | ||||
in_channels=in_channels, | in_channels=in_channels, | ||||
out_channels=oc, | out_channels=oc, | ||||
@@ -53,11 +55,11 @@ class ConvMaxpool(nn.Module): | |||||
groups=groups, | groups=groups, | ||||
bias=bias) | bias=bias) | ||||
for oc, ks in zip(out_channels, kernel_sizes)]) | for oc, ks in zip(out_channels, kernel_sizes)]) | ||||
else: | else: | ||||
raise Exception( | raise Exception( | ||||
'Incorrect kernel sizes: should be list, tuple or int') | 'Incorrect kernel sizes: should be list, tuple or int') | ||||
# activation function | # activation function | ||||
if activation == 'relu': | if activation == 'relu': | ||||
self.activation = F.relu | self.activation = F.relu | ||||
@@ -68,9 +70,9 @@ class ConvMaxpool(nn.Module): | |||||
else: | else: | ||||
raise Exception( | raise Exception( | ||||
"Undefined activation function: choose from: relu, tanh, sigmoid") | "Undefined activation function: choose from: relu, tanh, sigmoid") | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x, mask=None): | def forward(self, x, mask=None): | ||||
""" | """ | ||||
@@ -83,9 +85,9 @@ class ConvMaxpool(nn.Module): | |||||
# convolution | # convolution | ||||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | ||||
if mask is not None: | if mask is not None: | ||||
mask = mask.unsqueeze(1) # B x 1 x L | |||||
mask = mask.unsqueeze(1) # B x 1 x L | |||||
xs = [x.masked_fill_(mask, float('-inf')) for x in xs] | xs = [x.masked_fill_(mask, float('-inf')) for x in xs] | ||||
# max-pooling | # max-pooling | ||||
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | ||||
for i in xs] # [[N, C], ...] | for i in xs] # [[N, C], ...] | ||||
return torch.cat(xs, dim=-1) # [N, C] | |||||
return torch.cat(xs, dim=-1) # [N, C] |
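So each Conv1d output passes through the activation, optional masking, and a max-pool over the length dimension, and the per-kernel results are concatenated. A sketch, assuming forward accepts (batch, seq_len, in_channels) input that is transposed internally, as the [[N,C,L], ...] comment suggests:

>>> import torch
>>> from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool
>>> cm = ConvMaxpool(in_channels=100, out_channels=[30, 40, 50], kernel_sizes=[1, 3, 5])
>>> x = torch.randn(8, 20, 100)    # batch, seq_len, in_channels
>>> out = cm(x)                    # (8, 120): sum(out_channels) pooled features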
@@ -1,14 +1,18 @@ | |||||
__all__ = [ | |||||
"Embedding" | |||||
] | |||||
import torch.nn as nn | import torch.nn as nn | ||||
from ..utils import get_embeddings | from ..utils import get_embeddings | ||||
class Embedding(nn.Embedding): | class Embedding(nn.Embedding): | ||||
""" | """ | ||||
Alias: :class:`fastNLP.modules.Embedding` :class:`fastNLP.modules.encoder.embedding.Embedding` | Alias: :class:`fastNLP.modules.Embedding` :class:`fastNLP.modules.encoder.embedding.Embedding` | ||||
An Embedding component. The vocabulary size is available via self.num_embeddings and the embedding dimension via self.embedding_dim""" | An Embedding component. The vocabulary size is available via self.num_embeddings and the embedding dimension via self.embedding_dim""" | ||||
def __init__(self, init_embed, padding_idx=None, dropout=0.0, sparse=False, max_norm=None, norm_type=2, | def __init__(self, init_embed, padding_idx=None, dropout=0.0, sparse=False, max_norm=None, norm_type=2, | ||||
scale_grad_by_freq=False): | |||||
scale_grad_by_freq=False): | |||||
""" | """ | ||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: the size of the Embedding (when passing tuple(int, int), | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: the size of the Embedding (when passing tuple(int, int), | ||||
@@ -22,14 +26,14 @@ class Embedding(nn.Embedding): | |||||
""" | """ | ||||
embed = get_embeddings(init_embed) | embed = get_embeddings(init_embed) | ||||
num_embeddings, embedding_dim = embed.weight.size() | num_embeddings, embedding_dim = embed.weight.size() | ||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, | super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, | ||||
max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, | |||||
sparse=sparse, _weight=embed.weight.data) | |||||
max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, | |||||
sparse=sparse, _weight=embed.weight.data) | |||||
del embed | del embed | ||||
self.dropout = nn.Dropout(dropout) | self.dropout = nn.Dropout(dropout) | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | """ | ||||
:param torch.LongTensor x: [batch, seq_len] | :param torch.LongTensor x: [batch, seq_len] | ||||
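Since init_embed accepts several types, a usage sketch for the class above under the documented tuple(int, int) form, which builds a fresh embedding matrix of that size:

>>> import torch
>>> from fastNLP.modules.encoder.embedding import Embedding
>>> embed = Embedding((1000, 50), dropout=0.1)   # vocab size 1000, dim 50
>>> x = torch.randint(1000, (4, 7))              # batch, seq_len of word indices
>>> out = embed(x)                               # (4, 7, 50), with dropout applied
>>> embed.num_embeddings, embed.embedding_dim
(1000, 50)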
@@ -1,6 +1,11 @@ | |||||
"""轻量封装的 Pytorch LSTM 模块. | |||||
""" | |||||
A lightweight wrapper around the Pytorch LSTM module. | |||||
Sequence lengths can be passed in at forward time, and padding is then handled appropriately and automatically. | Sequence lengths can be passed in at forward time, and padding is then handled appropriately and automatically. | ||||
""" | """ | ||||
__all__ = [ | |||||
"LSTM" | |||||
] | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.utils.rnn as rnn | import torch.nn.utils.rnn as rnn | ||||
@@ -23,6 +28,7 @@ class LSTM(nn.Module): | |||||
:(batch, seq, feature). Default: ``False`` | :(batch, seq, feature). Default: ``False`` | ||||
:param bias: if ``False``, the model uses no bias. Default: ``True`` | :param bias: if ``False``, the model uses no bias. Default: ``True`` | ||||
""" | """ | ||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | ||||
bidirectional=False, bias=True, initial_method=None): | bidirectional=False, bias=True, initial_method=None): | ||||
super(LSTM, self).__init__() | super(LSTM, self).__init__() | ||||
@@ -30,7 +36,7 @@ class LSTM(nn.Module): | |||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | ||||
dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x, seq_len=None, h0=None, c0=None): | def forward(self, x, seq_len=None, h0=None, c0=None): | ||||
""" | """ | ||||
@@ -1,9 +1,14 @@ | |||||
"""Star-Transformer 的encoder部分的 Pytorch 实现 | |||||
""" | """ | ||||
A Pytorch implementation of the encoder part of the Star-Transformer | |||||
""" | |||||
__all__ = [ | |||||
"StarTransformer" | |||||
] | |||||
import numpy as NP | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
from torch.nn import functional as F | from torch.nn import functional as F | ||||
import numpy as NP | |||||
class StarTransformer(nn.Module): | class StarTransformer(nn.Module): | ||||
@@ -24,10 +29,11 @@ class StarTransformer(nn.Module): | |||||
the model adds a position embedding to the input sequence. | the model adds a position embedding to the input sequence. | ||||
If `None`, the position-embedding step is skipped. Default: `None` | If `None`, the position-embedding step is skipped. Default: `None` | ||||
""" | """ | ||||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | ||||
super(StarTransformer, self).__init__() | super(StarTransformer, self).__init__() | ||||
self.iters = num_layers | self.iters = num_layers | ||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | ||||
self.ring_att = nn.ModuleList( | self.ring_att = nn.ModuleList( | ||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | ||||
@@ -35,12 +41,12 @@ class StarTransformer(nn.Module): | |||||
self.star_att = nn.ModuleList( | self.star_att = nn.ModuleList( | ||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | ||||
for _ in range(self.iters)]) | for _ in range(self.iters)]) | ||||
if max_len is not None: | if max_len is not None: | ||||
self.pos_emb = nn.Embedding(max_len, hidden_size) | self.pos_emb = nn.Embedding(max_len, hidden_size) | ||||
else: | else: | ||||
self.pos_emb = None | self.pos_emb = None | ||||
def forward(self, data, mask): | def forward(self, data, mask): | ||||
""" | """ | ||||
:param FloatTensor data: [batch, length, hidden] the input sequence | :param FloatTensor data: [batch, length, hidden] the input sequence | ||||
@@ -50,20 +56,21 @@ class StarTransformer(nn.Module): | |||||
[batch, hidden] the global relay node; see the paper for details | [batch, hidden] the global relay node; see the paper for details | ||||
""" | """ | ||||
def norm_func(f, x): | def norm_func(f, x): | ||||
# B, H, L, 1 | # B, H, L, 1 | ||||
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | ||||
B, L, H = data.size() | B, L, H = data.size() | ||||
mask = (mask == 0) # flip the mask for masked_fill_ | |||||
mask = (mask == 0) # flip the mask for masked_fill_ | |||||
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | ||||
embs = data.permute(0, 2, 1)[:,:,:,None] # B H L 1 | |||||
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | |||||
if self.pos_emb: | if self.pos_emb: | ||||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device)\ | |||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | |||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||||
embs = embs + P | embs = embs + P | ||||
nodes = embs | nodes = embs | ||||
relay = embs.mean(2, keepdim=True) | relay = embs.mean(2, keepdim=True) | ||||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ||||
@@ -72,11 +79,11 @@ class StarTransformer(nn.Module): | |||||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ||||
nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | ||||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | ||||
nodes = nodes.masked_fill_(ex_mask, 0) | nodes = nodes.masked_fill_(ex_mask, 0) | ||||
nodes = nodes.view(B, H, L).permute(0, 2, 1) | nodes = nodes.view(B, H, L).permute(0, 2, 1) | ||||
return nodes, relay.view(B, H) | return nodes, relay.view(B, H) | ||||
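A usage sketch of the interface just defined; note that mask uses 1 for real tokens and is flipped internally for masked_fill_:

>>> import torch
>>> from fastNLP.modules.encoder.star_transformer import StarTransformer
>>> encoder = StarTransformer(hidden_size=64, num_layers=2, num_head=4, head_dim=16, max_len=100)
>>> data = torch.randn(2, 10, 64)       # batch, length, hidden
>>> mask = torch.ones(2, 10).byte()     # 1 = real token, 0 = padding
>>> nodes, relay = encoder(data, mask)  # nodes: (2, 10, 64), relay: (2, 64)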
@@ -89,37 +96,37 @@ class _MSA1(nn.Module): | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | ||||
self.drop = nn.Dropout(dropout) | self.drop = nn.Dropout(dropout) | ||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | ||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | ||||
def forward(self, x, ax=None): | def forward(self, x, ax=None): | ||||
# x: B, H, L, 1, ax : B, H, X, L append features | # x: B, H, L, 1, ax : B, H, X, L append features | ||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | ||||
B, H, L, _ = x.shape | B, H, L, _ = x.shape | ||||
q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | ||||
if ax is not None: | if ax is not None: | ||||
aL = ax.shape[2] | aL = ax.shape[2] | ||||
ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | ||||
av = self.WV(ax).view(B, nhead, head_dim, aL, L) | av = self.WV(ax).view(B, nhead, head_dim, aL, L) | ||||
q = q.view(B, nhead, head_dim, 1, L) | q = q.view(B, nhead, head_dim, 1, L) | ||||
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
if ax is not None: | if ax is not None: | ||||
k = torch.cat([k, ak], 3) | k = torch.cat([k, ak], 3) | ||||
v = torch.cat([v, av], 3) | v = torch.cat([v, av], 3) | ||||
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3)) # B N L 1 U | alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3)) # B N L 1 U | ||||
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | ||||
ret = self.WO(att) | ret = self.WO(att) | ||||
return ret | return ret | ||||
@@ -131,19 +138,19 @@ class _MSA2(nn.Module): | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | ||||
self.drop = nn.Dropout(dropout) | self.drop = nn.Dropout(dropout) | ||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | ||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | ||||
def forward(self, x, y, mask=None): | def forward(self, x, y, mask=None): | ||||
# x: B, H, 1, 1, 1 y: B H L 1 | # x: B, H, 1, 1, 1 y: B H L 1 | ||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | ||||
B, H, L, _ = y.shape | B, H, L, _ = y.shape | ||||
q, k, v = self.WQ(x), self.WK(y), self.WV(y) | q, k, v = self.WQ(x), self.WK(y), self.WV(y) | ||||
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | ||||
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | ||||
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | ||||
@@ -1,3 +1,6 @@ | |||||
__all__ = [ | |||||
"TransformerEncoder" | |||||
] | |||||
from torch import nn | from torch import nn | ||||
from ..aggregator.attention import MultiHeadAttention | from ..aggregator.attention import MultiHeadAttention | ||||
@@ -19,6 +22,7 @@ class TransformerEncoder(nn.Module): | |||||
:param int num_head: the number of heads. | :param int num_head: the number of heads. | ||||
:param float dropout: dropout probability. Default: 0.1 | :param float dropout: dropout probability. Default: 0.1 | ||||
""" | """ | ||||
class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | ||||
super(TransformerEncoder.SubLayer, self).__init__() | super(TransformerEncoder.SubLayer, self).__init__() | ||||
@@ -27,9 +31,9 @@ class TransformerEncoder(nn.Module): | |||||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | ||||
nn.ReLU(), | nn.ReLU(), | ||||
nn.Linear(inner_size, model_size), | nn.Linear(inner_size, model_size), | ||||
TimestepDropout(dropout),) | |||||
TimestepDropout(dropout), ) | |||||
self.norm2 = nn.LayerNorm(model_size) | self.norm2 = nn.LayerNorm(model_size) | ||||
def forward(self, input, seq_mask=None, atte_mask_out=None): | def forward(self, input, seq_mask=None, atte_mask_out=None): | ||||
""" | """ | ||||
@@ -44,11 +48,11 @@ class TransformerEncoder(nn.Module): | |||||
output = self.norm2(output + norm_atte) | output = self.norm2(output + norm_atte) | ||||
output *= seq_mask | output *= seq_mask | ||||
return output | return output | ||||
def __init__(self, num_layers, **kargs): | def __init__(self, num_layers, **kargs): | ||||
super(TransformerEncoder, self).__init__() | super(TransformerEncoder, self).__init__() | ||||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | ||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
""" | """ | ||||
:param x: [batch, seq_len, model_size] the input sequence | :param x: [batch, seq_len, model_size] the input sequence | ||||
@@ -60,8 +64,8 @@ class TransformerEncoder(nn.Module): | |||||
if seq_mask is None: | if seq_mask is None: | ||||
atte_mask_out = None | atte_mask_out = None | ||||
else: | else: | ||||
atte_mask_out = (seq_mask < 1)[:,None,:] | |||||
seq_mask = seq_mask[:,:,None] | |||||
atte_mask_out = (seq_mask < 1)[:, None, :] | |||||
seq_mask = seq_mask[:, :, None] | |||||
for layer in self.layers: | for layer in self.layers: | ||||
output = layer(output, seq_mask, atte_mask_out) | output = layer(output, seq_mask, atte_mask_out) | ||||
return output | return output |
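The constructor forwards **kargs to every SubLayer, so the SubLayer arguments are passed at the top level. A hedged construction sketch based on the signatures above:

>>> import torch
>>> from fastNLP.modules.encoder.transformer import TransformerEncoder
>>> enc = TransformerEncoder(num_layers=2, model_size=128, inner_size=256,
...                          key_size=16, value_size=16, num_head=8, dropout=0.1)
>>> x = torch.randn(2, 12, 128)    # batch, seq_len, model_size
>>> seq_mask = torch.ones(2, 12)   # 1 = real token
>>> out = enc(x, seq_mask)         # (2, 12, 128)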
@@ -1,9 +1,15 @@ | |||||
"""Variational RNN 的 Pytorch 实现 | |||||
""" | """ | ||||
A Pytorch implementation of Variational RNN | |||||
""" | |||||
__all__ = [ | |||||
"VarRNN", | |||||
"VarLSTM", | |||||
"VarGRU" | |||||
] | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | ||||
from ..utils import initial_parameter | |||||
try: | try: | ||||
from torch import flip | from torch import flip | ||||
@@ -11,21 +17,25 @@ except ImportError: | |||||
def flip(x, dims): | def flip(x, dims): | ||||
indices = [slice(None)] * x.dim() | indices = [slice(None)] * x.dim() | ||||
for dim in dims: | for dim in dims: | ||||
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | |||||
indices[dim] = torch.arange( | |||||
x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | |||||
return x[tuple(indices)] | return x[tuple(indices)] | ||||
from ..utils import initial_parameter | |||||
class VarRnnCellWrapper(nn.Module): | class VarRnnCellWrapper(nn.Module): | ||||
"""Wrapper for normal RNN Cells, make it support variational dropout | |||||
""" | """ | ||||
A wrapper for normal RNN cells, making them support variational dropout | |||||
""" | |||||
def __init__(self, cell, hidden_size, input_p, hidden_p): | def __init__(self, cell, hidden_size, input_p, hidden_p): | ||||
super(VarRnnCellWrapper, self).__init__() | super(VarRnnCellWrapper, self).__init__() | ||||
self.cell = cell | self.cell = cell | ||||
self.hidden_size = hidden_size | self.hidden_size = hidden_size | ||||
self.input_p = input_p | self.input_p = input_p | ||||
self.hidden_p = hidden_p | self.hidden_p = hidden_p | ||||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | ||||
""" | """ | ||||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | :param PackedSequence input_x: [seq_len, batch_size, input_size] | ||||
@@ -37,11 +47,13 @@ class VarRnnCellWrapper(nn.Module): | |||||
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | ||||
for other RNN, h_n, [batch_size, hidden_size] | for other RNN, h_n, [batch_size, hidden_size] | ||||
""" | """ | ||||
def get_hi(hi, h0, size): | def get_hi(hi, h0, size): | ||||
h0_size = size - hi.size(0) | h0_size = size - hi.size(0) | ||||
if h0_size > 0: | if h0_size > 0: | ||||
return torch.cat([hi, h0[:h0_size]], dim=0) | return torch.cat([hi, h0[:h0_size]], dim=0) | ||||
return hi[:size] | return hi[:size] | ||||
is_lstm = isinstance(hidden, tuple) | is_lstm = isinstance(hidden, tuple) | ||||
input, batch_sizes = input_x.data, input_x.batch_sizes | input, batch_sizes = input_x.data, input_x.batch_sizes | ||||
output = [] | output = [] | ||||
@@ -52,7 +64,7 @@ class VarRnnCellWrapper(nn.Module): | |||||
else: | else: | ||||
batch_iter = batch_sizes | batch_iter = batch_sizes | ||||
idx = 0 | idx = 0 | ||||
if is_lstm: | if is_lstm: | ||||
hn = (hidden[0].clone(), hidden[1].clone()) | hn = (hidden[0].clone(), hidden[1].clone()) | ||||
else: | else: | ||||
@@ -60,15 +72,16 @@ class VarRnnCellWrapper(nn.Module): | |||||
hi = hidden | hi = hidden | ||||
for size in batch_iter: | for size in batch_iter: | ||||
if is_reversed: | if is_reversed: | ||||
input_i = input[idx-size: idx] * mask_x[:size] | |||||
input_i = input[idx - size: idx] * mask_x[:size] | |||||
idx -= size | idx -= size | ||||
else: | else: | ||||
input_i = input[idx: idx+size] * mask_x[:size] | |||||
input_i = input[idx: idx + size] * mask_x[:size] | |||||
idx += size | idx += size | ||||
mask_hi = mask_h[:size] | mask_hi = mask_h[:size] | ||||
if is_lstm: | if is_lstm: | ||||
hx, cx = hi | hx, cx = hi | ||||
hi = (get_hi(hx, hidden[0], size) * mask_hi, get_hi(cx, hidden[1], size)) | |||||
hi = (get_hi(hx, hidden[0], size) * | |||||
mask_hi, get_hi(cx, hidden[1], size)) | |||||
hi = cell(input_i, hi) | hi = cell(input_i, hi) | ||||
hn[0][:size] = hi[0] | hn[0][:size] = hi[0] | ||||
hn[1][:size] = hi[1] | hn[1][:size] = hi[1] | ||||
@@ -78,7 +91,7 @@ class VarRnnCellWrapper(nn.Module): | |||||
hi = cell(input_i, hi) | hi = cell(input_i, hi) | ||||
hn[:size] = hi | hn[:size] = hi | ||||
output.append(hi) | output.append(hi) | ||||
if is_reversed: | if is_reversed: | ||||
output = list(reversed(output)) | output = list(reversed(output)) | ||||
output = torch.cat(output, dim=0) | output = torch.cat(output, dim=0) | ||||
@@ -86,7 +99,9 @@ class VarRnnCellWrapper(nn.Module): | |||||
class VarRNNBase(nn.Module): | class VarRNNBase(nn.Module): | ||||
"""Variational Dropout RNN 实现. | |||||
""" | |||||
An implementation of Variational Dropout RNN. | |||||
论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | 论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | ||||
https://arxiv.org/abs/1512.05287`. | https://arxiv.org/abs/1512.05287`. | ||||
@@ -102,7 +117,7 @@ class VarRNNBase(nn.Module): | |||||
:param hidden_dropout: dropout probability for each hidden state. Default: 0 | :param hidden_dropout: dropout probability for each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | ||||
bias=True, batch_first=False, | bias=True, batch_first=False, | ||||
input_dropout=0, hidden_dropout=0, bidirectional=False): | input_dropout=0, hidden_dropout=0, bidirectional=False): | ||||
@@ -122,18 +137,20 @@ class VarRNNBase(nn.Module): | |||||
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | ||||
cell = Cell(input_size, self.hidden_size, bias) | cell = Cell(input_size, self.hidden_size, bias) | ||||
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | |||||
self._all_cells.append(VarRnnCellWrapper( | |||||
cell, self.hidden_size, input_dropout, hidden_dropout)) | |||||
initial_parameter(self) | initial_parameter(self) | ||||
self.is_lstm = (self.mode == "LSTM") | self.is_lstm = (self.mode == "LSTM") | ||||
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | ||||
is_lstm = self.is_lstm | is_lstm = self.is_lstm | ||||
idx = self.num_directions * n_layer + n_direction | idx = self.num_directions * n_layer + n_direction | ||||
cell = self._all_cells[idx] | cell = self._all_cells[idx] | ||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | ||||
output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | |||||
output_x, hidden_x = cell( | |||||
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | |||||
return output_x, hidden_x | return output_x, hidden_x | ||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
""" | """ | ||||
@@ -147,31 +164,38 @@ class VarRNNBase(nn.Module): | |||||
if not is_packed: | if not is_packed: | ||||
seq_len = x.size(1) if self.batch_first else x.size(0) | seq_len = x.size(1) if self.batch_first else x.size(0) | ||||
max_batch_size = x.size(0) if self.batch_first else x.size(1) | max_batch_size = x.size(0) if self.batch_first else x.size(1) | ||||
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) | |||||
input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) | |||||
seq_lens = torch.LongTensor( | |||||
[seq_len for _ in range(max_batch_size)]) | |||||
x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) | |||||
else: | else: | ||||
max_batch_size = int(input.batch_sizes[0]) | |||||
input, batch_sizes = input.data, input.batch_sizes | |||||
max_batch_size = int(x.batch_sizes[0]) | |||||
x, batch_sizes = x.data, x.batch_sizes | |||||
if hx is None: | if hx is None: | ||||
hx = x.new_zeros(self.num_layers * self.num_directions, | hx = x.new_zeros(self.num_layers * self.num_directions, | ||||
max_batch_size, self.hidden_size, requires_grad=True) | max_batch_size, self.hidden_size, requires_grad=True) | ||||
if is_lstm: | if is_lstm: | ||||
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | ||||
mask_x = x.new_ones((max_batch_size, self.input_size)) | mask_x = x.new_ones((max_batch_size, self.input_size)) | ||||
mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions)) | |||||
mask_out = x.new_ones( | |||||
(max_batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | ||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | |||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | |||||
hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, | |||||
training=self.training, inplace=True) | |||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, | |||||
training=self.training, inplace=True) | |||||
hidden = x.new_zeros( | |||||
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
if is_lstm: | if is_lstm: | ||||
cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
cellstate = x.new_zeros( | |||||
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
for layer in range(self.num_layers): | for layer in range(self.num_layers): | ||||
output_list = [] | output_list = [] | ||||
input_seq = PackedSequence(x, batch_sizes) | input_seq = PackedSequence(x, batch_sizes) | ||||
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||||
mask_h = nn.functional.dropout( | |||||
mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||||
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | ||||
mask_x if layer == 0 else mask_out, mask_h) | mask_x if layer == 0 else mask_out, mask_h) | ||||
@@ -183,18 +207,19 @@ class VarRNNBase(nn.Module): | |||||
else: | else: | ||||
hidden[idx] = hidden_x | hidden[idx] = hidden_x | ||||
x = torch.cat(output_list, dim=-1) | x = torch.cat(output_list, dim=-1) | ||||
if is_lstm: | if is_lstm: | ||||
hidden = (hidden, cellstate) | hidden = (hidden, cellstate) | ||||
if is_packed: | if is_packed: | ||||
output = PackedSequence(x, batch_sizes) | output = PackedSequence(x, batch_sizes) | ||||
else: | else: | ||||
x = PackedSequence(x, batch_sizes) | x = PackedSequence(x, batch_sizes) | ||||
output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | ||||
return output, hidden | return output, hidden | ||||
class VarLSTM(VarRNNBase): | class VarLSTM(VarRNNBase): | ||||
""" | """ | ||||
Alias: :class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.variational_rnn.VarLSTM` | Alias: :class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.variational_rnn.VarLSTM` | ||||
@@ -211,10 +236,11 @@ class VarLSTM(VarRNNBase): | |||||
:param hidden_dropout: dropout probability for each hidden state. Default: 0 | :param hidden_dropout: dropout probability for each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | |||||
super(VarLSTM, self).__init__( | |||||
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | |||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
return super(VarLSTM, self).forward(x, hx) | return super(VarLSTM, self).forward(x, hx) | ||||
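With mode and Cell fixed here, the remaining VarRNNBase arguments configure the layer; batch_first defaults to False per the base class above. A sketch:

>>> import torch
>>> from fastNLP.modules.encoder.variational_rnn import VarLSTM
>>> lstm = VarLSTM(input_size=10, hidden_size=20, num_layers=2,
...                input_dropout=0.2, hidden_dropout=0.2, bidirectional=True)
>>> x = torch.randn(7, 3, 10)   # seq_len, batch, input_size
>>> output, (h_n, c_n) = lstm(x)
>>> output.shape                # hidden_size * num_directions in the last dim
torch.Size([7, 3, 40])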
@@ -235,13 +261,15 @@ class VarRNN(VarRNNBase): | |||||
:param hidden_dropout: dropout probability for each hidden state. Default: 0 | :param hidden_dropout: dropout probability for each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||||
super(VarRNN, self).__init__( | |||||
mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
return super(VarRNN, self).forward(x, hx) | return super(VarRNN, self).forward(x, hx) | ||||
class VarGRU(VarRNNBase): | class VarGRU(VarRNNBase): | ||||
""" | """ | ||||
Alias: :class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.variational_rnn.VarGRU` | Alias: :class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.variational_rnn.VarGRU` | ||||
@@ -258,10 +286,10 @@ class VarGRU(VarRNNBase): | |||||
:param hidden_dropout: dropout probability for each hidden state. Default: 0 | :param hidden_dropout: dropout probability for each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | |||||
super(VarGRU, self).__init__( | |||||
mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | |||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
return super(VarGRU, self).forward(x, hx) | return super(VarGRU, self).forward(x, hx) | ||||
@@ -1,5 +1,5 @@ | |||||
from functools import reduce | from functools import reduce | ||||
from collections import OrderedDict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -3,7 +3,7 @@ import torch | |||||
from torch import nn | from torch import nn | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.modules.decoder.mlp import MLP | |||||
from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask | from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask | ||||
@@ -120,8 +120,8 @@ class CWSBiLSTMSegApp(BaseModel): | |||||
return {'pred_tags': pred_tags} | return {'pred_tags': pred_tags} | ||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField | |||||
from fastNLP.modules.decoder.crf import allowed_transitions | |||||
class CWSBiLSTMCRF(BaseModel): | class CWSBiLSTMCRF(BaseModel): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | ||||
@@ -10,8 +10,8 @@ from torch import nn | |||||
import torch | import torch | ||||
# from fastNLP.modules.encoder.transformer import TransformerEncoder | # from fastNLP.modules.encoder.transformer import TransformerEncoder | ||||
from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder | from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder | ||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask | |||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField, seq_len_to_byte_mask | |||||
from fastNLP.modules.decoder.crf import allowed_transitions | |||||
class TransformerCWS(nn.Module): | class TransformerCWS(nn.Module): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | ||||
@@ -7,7 +7,7 @@ from fastNLP.io.config_io import ConfigSection | |||||
from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader | from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.modules.decoder.mlp import MLP | |||||
from fastNLP.modules.encoder.embedding import Embedding as Embedding | from fastNLP.modules.encoder.embedding import Embedding as Embedding | ||||
from fastNLP.modules.encoder.lstm import LSTM | from fastNLP.modules.encoder.lstm import LSTM | ||||
@@ -5,7 +5,7 @@ import unittest | |||||
class TestCRF(unittest.TestCase): | class TestCRF(unittest.TestCase): | ||||
def test_case1(self): | def test_case1(self): | ||||
# Check that allowed_transitions() can be used correctly | # Check that allowed_transitions() can be used correctly | ||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.decoder.crf import allowed_transitions | |||||
id2label = {0: 'B', 1: 'I', 2:'O'} | id2label = {0: 'B', 1: 'I', 2:'O'} | ||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | ||||
@@ -43,7 +43,7 @@ class TestCRF(unittest.TestCase): | |||||
# Test that the CRF avoids decoding illegal transitions; verified against allennlp. | # Test that the CRF avoids decoding illegal transitions; verified against allennlp. | ||||
pass | pass | ||||
# import torch | # import torch | ||||
# from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask | |||||
# from fastNLP.modules.decoder.crf import seq_len_to_byte_mask | |||||
# | # | ||||
# labels = ['O'] | # labels = ['O'] | ||||
# for label in ['X', 'Y']: | # for label in ['X', 'Y']: | ||||
@@ -63,7 +63,7 @@ class TestCRF(unittest.TestCase): | |||||
# mask = seq_len_to_byte_mask(seq_lens) | # mask = seq_len_to_byte_mask(seq_lens) | ||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | # allen_res = allen_CRF.viterbi_tags(logits, mask) | ||||
# | # | ||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | ||||
# fast_CRF.trans_m = trans_m | # fast_CRF.trans_m = trans_m | ||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) | # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) | ||||
@@ -91,7 +91,7 @@ class TestCRF(unittest.TestCase): | |||||
# mask = seq_len_to_byte_mask(seq_lens) | # mask = seq_len_to_byte_mask(seq_lens) | ||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | # allen_res = allen_CRF.viterbi_tags(logits, mask) | ||||
# | # | ||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | ||||
# encoding_type='BMES')) | # encoding_type='BMES')) | ||||
# fast_CRF.trans_m = trans_m | # fast_CRF.trans_m = trans_m | ||||
@@ -104,7 +104,7 @@ class TestCRF(unittest.TestCase): | |||||
def test_case3(self): | def test_case3(self): | ||||
# Test that the crf loss never becomes negative | # Test that the crf loss never becomes negative | ||||
import torch | import torch | ||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField | |||||
from fastNLP.core.utils import seq_len_to_mask | from fastNLP.core.utils import seq_len_to_mask | ||||
from torch import optim | from torch import optim | ||||
from torch import nn | from torch import nn | ||||