@@ -6,13 +6,14 @@ | |||||
 |  | ||||
[](http://fastnlp.readthedocs.io/?badge=latest) | [](http://fastnlp.readthedocs.io/?badge=latest) | ||||
fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个命名实体识别(NER)、中文分词或文本分类任务; 也可以使用他构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||||
fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner/)、POS-Tagging等)、中文分词、文本分类、[Matching](reproduction/matching/)、指代消解、摘要等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码。 | |||||
- 各种方便的NLP工具,例如预处理embedding加载; 中间数据cache等; | |||||
- 详尽的中文文档以供查阅; | |||||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码; | |||||
- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等; | |||||
- 各种方便的NLP工具,例如预处理embedding加载(包括EMLo和BERT); 中间数据cache等; | |||||
- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、教程以供查阅; | |||||
- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等; | - 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等; | ||||
- 封装CNNText,Biaffine等模型可供直接使用; | |||||
- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用; [详细链接](reproduction/) | |||||
- 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。 | - 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。 | ||||
@@ -20,13 +21,14 @@ fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地 | |||||
fastNLP 依赖如下包: | fastNLP 依赖如下包: | ||||
+ numpy | |||||
+ torch>=0.4.0 | |||||
+ tqdm | |||||
+ nltk | |||||
+ numpy>=1.14.2 | |||||
+ torch>=1.0.0 | |||||
+ tqdm>=4.28.1 | |||||
+ nltk>=3.4.1 | |||||
+ requests | |||||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 PyTorch 官网 。 | |||||
在依赖包安装完成的情况,您可以在命令行执行如下指令完成安装 | |||||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。 | |||||
在依赖包安装完成后,您可以在命令行执行如下指令完成安装 | |||||
```shell | ```shell | ||||
pip install fastNLP | pip install fastNLP | ||||
@@ -77,8 +79,8 @@ fastNLP 在 modules 模块中内置了三种模块的诸多组件,可以帮助 | |||||
fastNLP 为不同的 NLP 任务实现了许多完整的模型,它们都经过了训练和测试。 | fastNLP 为不同的 NLP 任务实现了许多完整的模型,它们都经过了训练和测试。 | ||||
你可以在以下两个地方查看相关信息 | 你可以在以下两个地方查看相关信息 | ||||
- [介绍](reproduction/) | |||||
- [源码](fastNLP/models/) | |||||
- [模型介绍](reproduction/) | |||||
- [模型源码](fastNLP/models/) | |||||
## 项目结构 | ## 项目结构 | ||||
@@ -93,7 +95,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下: | |||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.core </b></td> | <td><b> fastNLP.core </b></td> | ||||
<td> 实现了核心功能,包括数据处理组件、训练器、测速器等 </td> | |||||
<td> 实现了核心功能,包括数据处理组件、训练器、测试器等 </td> | |||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.models </b></td> | <td><b> fastNLP.models </b></td> | ||||
@@ -0,0 +1,88 @@ | |||||
import threading | |||||
import torch | |||||
from torch.nn.parallel.parallel_apply import get_a_var | |||||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather | |||||
from torch.nn.parallel.replicate import replicate | |||||
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): | |||||
r"""Applies each `module` in :attr:`modules` in parallel on arguments | |||||
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) | |||||
on each of :attr:`devices`. | |||||
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and | |||||
:attr:`devices` (if given) should all have same length. Moreover, each | |||||
element of :attr:`inputs` can either be a single object as the only argument | |||||
to a module, or a collection of positional arguments. | |||||
""" | |||||
assert len(modules) == len(inputs) | |||||
if kwargs_tup is not None: | |||||
assert len(modules) == len(kwargs_tup) | |||||
else: | |||||
kwargs_tup = ({},) * len(modules) | |||||
if devices is not None: | |||||
assert len(modules) == len(devices) | |||||
else: | |||||
devices = [None] * len(modules) | |||||
lock = threading.Lock() | |||||
results = {} | |||||
grad_enabled = torch.is_grad_enabled() | |||||
def _worker(i, module, input, kwargs, device=None): | |||||
torch.set_grad_enabled(grad_enabled) | |||||
if device is None: | |||||
device = get_a_var(input).get_device() | |||||
try: | |||||
with torch.cuda.device(device): | |||||
# this also avoids accidental slicing of `input` if it is a Tensor | |||||
if not isinstance(input, (list, tuple)): | |||||
input = (input,) | |||||
output = getattr(module, func_name)(*input, **kwargs) | |||||
with lock: | |||||
results[i] = output | |||||
except Exception as e: | |||||
with lock: | |||||
results[i] = e | |||||
if len(modules) > 1: | |||||
threads = [threading.Thread(target=_worker, | |||||
args=(i, module, input, kwargs, device)) | |||||
for i, (module, input, kwargs, device) in | |||||
enumerate(zip(modules, inputs, kwargs_tup, devices))] | |||||
for thread in threads: | |||||
thread.start() | |||||
for thread in threads: | |||||
thread.join() | |||||
else: | |||||
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) | |||||
outputs = [] | |||||
for i in range(len(inputs)): | |||||
output = results[i] | |||||
if isinstance(output, Exception): | |||||
raise output | |||||
outputs.append(output) | |||||
return outputs | |||||
def _data_parallel_wrapper(func_name, device_ids, output_device): | |||||
""" | |||||
这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 | |||||
:param str, func_name: 对network中的这个函数进行多卡运行 | |||||
:param device_ids: nn.DataParallel中的device_ids | |||||
:param output_device: nn.DataParallel中的output_device | |||||
:return: | |||||
""" | |||||
def wrapper(network, *inputs, **kwargs): | |||||
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) | |||||
if len(device_ids) == 1: | |||||
return getattr(network, func_name)(*inputs[0], **kwargs[0]) | |||||
replicas = replicate(network, device_ids[:len(inputs)]) | |||||
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) | |||||
return gather(outputs, output_device) | |||||
return wrapper |
@@ -113,7 +113,7 @@ class Callback(object): | |||||
@property | @property | ||||
def n_steps(self): | def n_steps(self): | ||||
"""Trainer一共会运行多少步""" | |||||
"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数""" | |||||
return self._trainer.n_steps | return self._trainer.n_steps | ||||
@property | @property | ||||
@@ -181,7 +181,7 @@ class Callback(object): | |||||
:param dict batch_x: DataSet中被设置为input的field的batch。 | :param dict batch_x: DataSet中被设置为input的field的batch。 | ||||
:param dict batch_y: DataSet中被设置为target的field的batch。 | :param dict batch_y: DataSet中被设置为target的field的batch。 | ||||
:param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 | :param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 | ||||
情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。 | |||||
情况下可以帮助定位是哪个Sample导致了错误。仅当num_workers=0时有效。 | |||||
:return: | :return: | ||||
""" | """ | ||||
pass | pass | ||||
@@ -399,10 +399,11 @@ class GradientClipCallback(Callback): | |||||
self.clip_value = clip_value | self.clip_value = clip_value | ||||
def on_backward_end(self): | def on_backward_end(self): | ||||
if self.parameters is None: | |||||
self.clip_fun(self.model.parameters(), self.clip_value) | |||||
else: | |||||
self.clip_fun(self.parameters, self.clip_value) | |||||
if self.step%self.update_every==0: | |||||
if self.parameters is None: | |||||
self.clip_fun(self.model.parameters(), self.clip_value) | |||||
else: | |||||
self.clip_fun(self.parameters, self.clip_value) | |||||
class EarlyStopCallback(Callback): | class EarlyStopCallback(Callback): | ||||
@@ -20,6 +20,7 @@ from collections import defaultdict | |||||
import torch | import torch | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from ..core.const import Const | |||||
from .utils import _CheckError | from .utils import _CheckError | ||||
from .utils import _CheckRes | from .utils import _CheckRes | ||||
from .utils import _build_args | from .utils import _build_args | ||||
@@ -28,6 +29,7 @@ from .utils import _check_function_or_method | |||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .utils import seq_len_to_mask | from .utils import seq_len_to_mask | ||||
class LossBase(object): | class LossBase(object): | ||||
""" | """ | ||||
所有loss的基类。如果想了解其中的原理,请查看源码。 | 所有loss的基类。如果想了解其中的原理,请查看源码。 | ||||
@@ -95,22 +97,7 @@ class LossBase(object): | |||||
# if func_spect.varargs: | # if func_spect.varargs: | ||||
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | ||||
# f"positional argument.).") | # f"positional argument.).") | ||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||||
such as pred_dict has one element, target_dict has one element | |||||
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||||
""" | |||||
fast_param = {} | |||||
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict, check=False): | def __call__(self, pred_dict, target_dict, check=False): | ||||
""" | """ | ||||
:param dict pred_dict: 模型的forward函数返回的dict | :param dict pred_dict: 模型的forward函数返回的dict | ||||
@@ -118,11 +105,7 @@ class LossBase(object): | |||||
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | :param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | ||||
:return: | :return: | ||||
""" | """ | ||||
fast_param = self._fast_param_map(pred_dict, target_dict) | |||||
if fast_param: | |||||
loss = self.get_loss(**fast_param) | |||||
return loss | |||||
if not self._checked: | if not self._checked: | ||||
# 1. check consistence between signature and _param_map | # 1. check consistence between signature and _param_map | ||||
func_spect = inspect.getfullargspec(self.get_loss) | func_spect = inspect.getfullargspec(self.get_loss) | ||||
@@ -212,7 +195,6 @@ class LossFunc(LossBase): | |||||
if not isinstance(key_map, dict): | if not isinstance(key_map, dict): | ||||
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | ||||
self._init_param_map(key_map, **kwargs) | self._init_param_map(key_map, **kwargs) | ||||
class CrossEntropyLoss(LossBase): | class CrossEntropyLoss(LossBase): | ||||
@@ -226,6 +208,7 @@ class CrossEntropyLoss(LossBase): | |||||
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | :param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | ||||
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | ||||
传入seq_len. | 传入seq_len. | ||||
:param str reduction: 支持'mean','sum'和'none'. | |||||
Example:: | Example:: | ||||
@@ -233,21 +216,25 @@ class CrossEntropyLoss(LossBase): | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): | |||||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'): | |||||
super(CrossEntropyLoss, self).__init__() | super(CrossEntropyLoss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | self._init_param_map(pred=pred, target=target, seq_len=seq_len) | ||||
self.padding_idx = padding_idx | self.padding_idx = padding_idx | ||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | |||||
def get_loss(self, pred, target, seq_len=None): | def get_loss(self, pred, target, seq_len=None): | ||||
if pred.dim()>2: | |||||
pred = pred.view(-1, pred.size(-1)) | |||||
target = target.view(-1) | |||||
if pred.dim() > 2: | |||||
if pred.size(1) != target.size(1): | |||||
pred = pred.transpose(1, 2) | |||||
pred = pred.reshape(-1, pred.size(-1)) | |||||
target = target.reshape(-1) | |||||
if seq_len is not None: | if seq_len is not None: | ||||
mask = seq_len_to_mask(seq_len).view(-1).eq(0) | |||||
mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) | |||||
target = target.masked_fill(mask, self.padding_idx) | target = target.masked_fill(mask, self.padding_idx) | ||||
return F.cross_entropy(input=pred, target=target, | return F.cross_entropy(input=pred, target=target, | ||||
ignore_index=self.padding_idx) | |||||
ignore_index=self.padding_idx, reduction=self.reduction) | |||||
class L1Loss(LossBase): | class L1Loss(LossBase): | ||||
@@ -258,15 +245,18 @@ class L1Loss(LossBase): | |||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target` | :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target` | ||||
:param str reduction: 支持'mean','sum'和'none'. | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None): | |||||
def __init__(self, pred=None, target=None, reduction='mean'): | |||||
super(L1Loss, self).__init__() | super(L1Loss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target) | self._init_param_map(pred=pred, target=target) | ||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | |||||
def get_loss(self, pred, target): | def get_loss(self, pred, target): | ||||
return F.l1_loss(input=pred, target=target) | |||||
return F.l1_loss(input=pred, target=target, reduction=self.reduction) | |||||
class BCELoss(LossBase): | class BCELoss(LossBase): | ||||
@@ -277,14 +267,17 @@ class BCELoss(LossBase): | |||||
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | ||||
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | ||||
:param str reduction: 支持'mean','sum'和'none'. | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None): | |||||
def __init__(self, pred=None, target=None, reduction='mean'): | |||||
super(BCELoss, self).__init__() | super(BCELoss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target) | self._init_param_map(pred=pred, target=target) | ||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | |||||
def get_loss(self, pred, target): | def get_loss(self, pred, target): | ||||
return F.binary_cross_entropy(input=pred, target=target) | |||||
return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction) | |||||
class NLLLoss(LossBase): | class NLLLoss(LossBase): | ||||
@@ -295,14 +288,20 @@ class NLLLoss(LossBase): | |||||
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | ||||
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | ||||
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 | |||||
传入seq_len. | |||||
:param str reduction: 支持'mean','sum'和'none'. | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None): | |||||
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | |||||
super(NLLLoss, self).__init__() | super(NLLLoss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target) | self._init_param_map(pred=pred, target=target) | ||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | |||||
self.ignore_idx = ignore_idx | |||||
def get_loss(self, pred, target): | def get_loss(self, pred, target): | ||||
return F.nll_loss(input=pred, target=target) | |||||
return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction) | |||||
class LossInForward(LossBase): | class LossInForward(LossBase): | ||||
@@ -314,7 +313,7 @@ class LossInForward(LossBase): | |||||
:param str loss_key: 在forward函数中loss的键名,默认为loss | :param str loss_key: 在forward函数中loss的键名,默认为loss | ||||
""" | """ | ||||
def __init__(self, loss_key='loss'): | |||||
def __init__(self, loss_key=Const.LOSS): | |||||
super().__init__() | super().__init__() | ||||
if not isinstance(loss_key, str): | if not isinstance(loss_key, str): | ||||
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | ||||
@@ -9,6 +9,9 @@ __all__ = [ | |||||
] | ] | ||||
import torch | import torch | ||||
import math | |||||
import torch | |||||
from torch.optim.optimizer import Optimizer as TorchOptimizer | |||||
class Optimizer(object): | class Optimizer(object): | ||||
@@ -97,3 +100,110 @@ class Adam(Optimizer): | |||||
return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings) | return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings) | ||||
else: | else: | ||||
return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings) | return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings) | ||||
class AdamW(TorchOptimizer): | |||||
r"""对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入 | |||||
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. | |||||
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. | |||||
Arguments: | |||||
params (iterable): iterable of parameters to optimize or dicts defining | |||||
parameter groups | |||||
lr (float, optional): learning rate (default: 1e-3) | |||||
betas (Tuple[float, float], optional): coefficients used for computing | |||||
running averages of gradient and its square (default: (0.9, 0.99)) | |||||
eps (float, optional): term added to the denominator to improve | |||||
numerical stability (default: 1e-8) | |||||
weight_decay (float, optional): weight decay coefficient (default: 1e-2) | |||||
amsgrad (boolean, optional): whether to use the AMSGrad variant of this | |||||
algorithm from the paper `On the Convergence of Adam and Beyond`_ | |||||
(default: False) | |||||
.. _Adam\: A Method for Stochastic Optimization: | |||||
https://arxiv.org/abs/1412.6980 | |||||
.. _Decoupled Weight Decay Regularization: | |||||
https://arxiv.org/abs/1711.05101 | |||||
.. _On the Convergence of Adam and Beyond: | |||||
https://openreview.net/forum?id=ryQu7f-RZ | |||||
""" | |||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, | |||||
weight_decay=1e-2, amsgrad=False): | |||||
if not 0.0 <= lr: | |||||
raise ValueError("Invalid learning rate: {}".format(lr)) | |||||
if not 0.0 <= eps: | |||||
raise ValueError("Invalid epsilon value: {}".format(eps)) | |||||
if not 0.0 <= betas[0] < 1.0: | |||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | |||||
if not 0.0 <= betas[1] < 1.0: | |||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | |||||
defaults = dict(lr=lr, betas=betas, eps=eps, | |||||
weight_decay=weight_decay, amsgrad=amsgrad) | |||||
super(AdamW, self).__init__(params, defaults) | |||||
def __setstate__(self, state): | |||||
super(AdamW, self).__setstate__(state) | |||||
for group in self.param_groups: | |||||
group.setdefault('amsgrad', False) | |||||
def step(self, closure=None): | |||||
"""Performs a single optimization step. | |||||
Arguments: | |||||
closure (callable, optional): A closure that reevaluates the model | |||||
and returns the loss. | |||||
""" | |||||
loss = None | |||||
if closure is not None: | |||||
loss = closure() | |||||
for group in self.param_groups: | |||||
for p in group['params']: | |||||
if p.grad is None: | |||||
continue | |||||
# Perform stepweight decay | |||||
p.data.mul_(1 - group['lr'] * group['weight_decay']) | |||||
# Perform optimization step | |||||
grad = p.grad.data | |||||
if grad.is_sparse: | |||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') | |||||
amsgrad = group['amsgrad'] | |||||
state = self.state[p] | |||||
# State initialization | |||||
if len(state) == 0: | |||||
state['step'] = 0 | |||||
# Exponential moving average of gradient values | |||||
state['exp_avg'] = torch.zeros_like(p.data) | |||||
# Exponential moving average of squared gradient values | |||||
state['exp_avg_sq'] = torch.zeros_like(p.data) | |||||
if amsgrad: | |||||
# Maintains max of all exp. moving avg. of sq. grad. values | |||||
state['max_exp_avg_sq'] = torch.zeros_like(p.data) | |||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] | |||||
if amsgrad: | |||||
max_exp_avg_sq = state['max_exp_avg_sq'] | |||||
beta1, beta2 = group['betas'] | |||||
state['step'] += 1 | |||||
# Decay the first and second moment running average coefficient | |||||
exp_avg.mul_(beta1).add_(1 - beta1, grad) | |||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) | |||||
if amsgrad: | |||||
# Maintains the maximum of all 2nd moment running avg. till now | |||||
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) | |||||
# Use the max. for normalizing running avg. of gradient | |||||
denom = max_exp_avg_sq.sqrt().add_(group['eps']) | |||||
else: | |||||
denom = exp_avg_sq.sqrt().add_(group['eps']) | |||||
bias_correction1 = 1 - beta1 ** state['step'] | |||||
bias_correction2 = 1 - beta2 ** state['step'] | |||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 | |||||
p.data.addcdiv_(-step_size, exp_avg, denom) | |||||
return loss |
@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||||
class Predictor(object): | class Predictor(object): | ||||
""" | """ | ||||
An interface for predicting outputs based on trained models. | |||||
一个根据训练模型预测输出的预测器(Predictor) | |||||
It does not care about evaluations of the model, which is different from Tester. | |||||
This is a high-level model wrapper to be called by FastNLP. | |||||
This class does not share any operations with Trainer and Tester. | |||||
Currently, Predictor does not support GPU. | |||||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | |||||
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 | |||||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||||
""" | """ | ||||
def __init__(self, network): | def __init__(self, network): | ||||
@@ -30,18 +30,19 @@ class Predictor(object): | |||||
self.batch_size = 1 | self.batch_size = 1 | ||||
self.batch_output = [] | self.batch_output = [] | ||||
def predict(self, data, seq_len_field_name=None): | |||||
"""Perform inference using the trained model. | |||||
def predict(self, data: DataSet, seq_len_field_name=None): | |||||
"""用已经训练好的模型进行inference. | |||||
:param data: a DataSet object. | |||||
:param str seq_len_field_name: field name indicating sequence lengths | |||||
:return: list of batch outputs | |||||
:param fastNLP.DataSet data: 待预测的数据集 | |||||
:param str seq_len_field_name: 表示序列长度信息的field名字 | |||||
:return: dict dict里面的内容为模型预测的结果 | |||||
""" | """ | ||||
if not isinstance(data, DataSet): | if not isinstance(data, DataSet): | ||||
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | ||||
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | ||||
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | ||||
prev_training = self.network.training | |||||
self.network.eval() | self.network.eval() | ||||
network_device = _get_model_device(self.network) | network_device = _get_model_device(self.network) | ||||
batch_output = defaultdict(list) | batch_output = defaultdict(list) | ||||
@@ -74,4 +75,5 @@ class Predictor(object): | |||||
else: | else: | ||||
batch_output[key].append(value) | batch_output[key].append(value) | ||||
self.network.train(prev_training) | |||||
return batch_output | return batch_output |
@@ -32,8 +32,6 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation | |||||
""" | """ | ||||
import warnings | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -48,6 +46,8 @@ from .utils import _move_dict_value_to_device | |||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .utils import _get_model_device | from .utils import _get_model_device | ||||
from .utils import _move_model_to_device | from .utils import _move_model_to_device | ||||
from ._parallel_utils import _data_parallel_wrapper | |||||
from functools import partial | |||||
__all__ = [ | __all__ = [ | ||||
"Tester" | "Tester" | ||||
@@ -104,26 +104,27 @@ class Tester(object): | |||||
self.data_iterator = data | self.data_iterator = data | ||||
else: | else: | ||||
raise TypeError("data type {} not support".format(type(data))) | raise TypeError("data type {} not support".format(type(data))) | ||||
# 如果是DataParallel将没有办法使用predict方法 | |||||
if isinstance(self._model, nn.DataParallel): | |||||
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): | |||||
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," | |||||
" while DataParallel has no predict() function.") | |||||
self._model = self._model.module | |||||
# check predict | # check predict | ||||
if hasattr(self._model, 'predict'): | |||||
self._predict_func = self._model.predict | |||||
if not callable(self._predict_func): | |||||
_model_name = model.__class__.__name__ | |||||
raise TypeError(f"`{_model_name}.predict` must be callable to be used " | |||||
f"for evaluation, not `{type(self._predict_func)}`.") | |||||
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ | |||||
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and | |||||
callable(self._model.module.predict)): | |||||
if isinstance(self._model, nn.DataParallel): | |||||
self._predict_func_wrapper = partial(_data_parallel_wrapper('predict', | |||||
self._model.device_ids, | |||||
self._model.output_device), | |||||
network=self._model.module) | |||||
self._predict_func = self._model.module.predict | |||||
else: | |||||
self._predict_func = self._model.predict | |||||
self._predict_func_wrapper = self._model.predict | |||||
else: | else: | ||||
if isinstance(model, nn.DataParallel): | |||||
if isinstance(self._model, nn.DataParallel): | |||||
self._predict_func_wrapper = self._model.forward | |||||
self._predict_func = self._model.module.forward | self._predict_func = self._model.module.forward | ||||
else: | else: | ||||
self._predict_func = self._model.forward | self._predict_func = self._model.forward | ||||
self._predict_func_wrapper = self._model.forward | |||||
def test(self): | def test(self): | ||||
"""开始进行验证,并返回验证结果。 | """开始进行验证,并返回验证结果。 | ||||
@@ -180,7 +181,7 @@ class Tester(object): | |||||
def _data_forward(self, func, x): | def _data_forward(self, func, x): | ||||
"""A forward pass of the model. """ | """A forward pass of the model. """ | ||||
x = _build_args(func, **x) | x = _build_args(func, **x) | ||||
y = func(**x) | |||||
y = self._predict_func_wrapper(**x) | |||||
return y | return y | ||||
def _format_eval_results(self, results): | def _format_eval_results(self, results): | ||||
@@ -454,7 +454,7 @@ class Trainer(object): | |||||
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | ||||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | ||||
metric_key=metric_key, check_level=check_code_level, | |||||
metric_key=self.metric_key, check_level=check_code_level, | |||||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | ||||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | ||||
self.model = _move_model_to_device(model, device=device) | self.model = _move_model_to_device(model, device=device) | ||||
@@ -473,7 +473,7 @@ class Trainer(object): | |||||
self.best_dev_step = None | self.best_dev_step = None | ||||
self.best_dev_perf = None | self.best_dev_perf = None | ||||
self.n_steps = (len(self.train_data) // self.batch_size + int( | self.n_steps = (len(self.train_data) // self.batch_size + int( | ||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs | |||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
@@ -17,7 +17,6 @@ import numpy as np | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs']) | 'varargs']) | ||||
@@ -278,6 +277,7 @@ def _move_model_to_device(model, device): | |||||
return model | return model | ||||
def _get_model_device(model): | def _get_model_device(model): | ||||
""" | """ | ||||
传入一个nn.Module的模型,获取它所在的device | 传入一个nn.Module的模型,获取它所在的device | ||||
@@ -11,21 +11,35 @@ | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
'EmbedLoader', | 'EmbedLoader', | ||||
'DataInfo', | |||||
'DataSetLoader', | 'DataSetLoader', | ||||
'CSVLoader', | 'CSVLoader', | ||||
'JsonLoader', | 'JsonLoader', | ||||
'ConllLoader', | 'ConllLoader', | ||||
'SNLILoader', | |||||
'SSTLoader', | |||||
'PeopleDailyCorpusLoader', | 'PeopleDailyCorpusLoader', | ||||
'Conll2003Loader', | 'Conll2003Loader', | ||||
'ModelLoader', | 'ModelLoader', | ||||
'ModelSaver', | 'ModelSaver', | ||||
'SSTLoader', | |||||
'MatchingLoader', | |||||
'SNLILoader', | |||||
'MNLILoader', | |||||
'QNLILoader', | |||||
'QuoraLoader', | |||||
'RTELoader', | |||||
] | ] | ||||
from .embed_loader import EmbedLoader | from .embed_loader import EmbedLoader | ||||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \ | |||||
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader | |||||
from .base_loader import DataInfo, DataSetLoader | |||||
from .dataset_loader import CSVLoader, JsonLoader, ConllLoader, \ | |||||
PeopleDailyCorpusLoader, Conll2003Loader | |||||
from .model_io import ModelLoader, ModelSaver | from .model_io import ModelLoader, ModelSaver | ||||
from .data_loader.sst import SSTLoader | |||||
from .data_loader.matching import MatchingLoader, SNLILoader, \ | |||||
MNLILoader, QNLILoader, QuoraLoader, RTELoader |
@@ -10,6 +10,7 @@ from typing import Union, Dict | |||||
import os | import os | ||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
class BaseLoader(object): | class BaseLoader(object): | ||||
""" | """ | ||||
各个 Loader 的基类,提供了 API 的参考。 | 各个 Loader 的基类,提供了 API 的参考。 | ||||
@@ -55,8 +56,6 @@ class BaseLoader(object): | |||||
return obj | return obj | ||||
def _download_from_url(url, path): | def _download_from_url(url, path): | ||||
try: | try: | ||||
from tqdm.auto import tqdm | from tqdm.auto import tqdm | ||||
@@ -115,13 +114,11 @@ class DataInfo: | |||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | ||||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | ||||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` | |||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | ||||
""" | """ | ||||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): | |||||
def __init__(self, vocabs: dict = None, datasets: dict = None): | |||||
self.vocabs = vocabs or {} | self.vocabs = vocabs or {} | ||||
self.embeddings = embeddings or {} | |||||
self.datasets = datasets or {} | self.datasets = datasets or {} | ||||
def __repr__(self): | def __repr__(self): | ||||
@@ -133,6 +130,7 @@ class DataInfo: | |||||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | _str += '\t{} has {} entries.\n'.format(name, len(vocab)) | ||||
return _str | return _str | ||||
class DataSetLoader: | class DataSetLoader: | ||||
""" | """ | ||||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | ||||
@@ -213,7 +211,6 @@ class DataSetLoader: | |||||
返回的 :class:`DataInfo` 对象有如下属性: | 返回的 :class:`DataInfo` 对象有如下属性: | ||||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | - vocabs: 由从数据集中获取的词表组成的字典,每个词表 | ||||
- embeddings: (可选) 数据集对应的词嵌入 | |||||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | ||||
:param paths: 原始数据读取的路径 | :param paths: 原始数据读取的路径 | ||||
@@ -0,0 +1,19 @@ | |||||
""" | |||||
用于读数据集的模块, 具体包括: | |||||
这些模块的使用方法如下: | |||||
""" | |||||
__all__ = [ | |||||
'SSTLoader', | |||||
'MatchingLoader', | |||||
'SNLILoader', | |||||
'MNLILoader', | |||||
'QNLILoader', | |||||
'QuoraLoader', | |||||
'RTELoader', | |||||
] | |||||
from .sst import SSTLoader | |||||
from .matching import MatchingLoader, SNLILoader, \ | |||||
MNLILoader, QNLILoader, QuoraLoader, RTELoader |
@@ -0,0 +1,430 @@ | |||||
import os | |||||
from typing import Union, Dict | |||||
from ...core.const import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ..base_loader import DataInfo, DataSetLoader | |||||
from ..dataset_loader import JsonLoader, CSVLoader | |||||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||||
from ...modules.encoder._bert import BertTokenizer | |||||
class MatchingLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||||
读取Matching任务的数据集 | |||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
self.paths = paths | |||||
def _load(self, path): | |||||
""" | |||||
:param str path: 待读取数据集的路径名 | |||||
:return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 | |||||
的原始字符串文本,第三个为标签 | |||||
""" | |||||
raise NotImplementedError | |||||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||||
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | |||||
""" | |||||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||||
对应的全路径文件名。 | |||||
:param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 | |||||
这个数据集的名字,如果不定义则默认为train。 | |||||
:param bool to_lower: 是否将文本自动转为小写。默认值为False。 | |||||
:param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : | |||||
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 | |||||
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len | |||||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||||
:param bool get_index: 是否需要根据词表将文本转为index | |||||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||||
:param str auto_pad_token: 自动pad的内容 | |||||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||||
于此同时其他field不会被设置为input。默认值为True。 | |||||
:param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 | |||||
:param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。 | |||||
如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 | |||||
传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. | |||||
:return: | |||||
""" | |||||
if isinstance(set_input, str): | |||||
set_input = [set_input] | |||||
if isinstance(set_target, str): | |||||
set_target = [set_target] | |||||
if isinstance(set_input, bool): | |||||
auto_set_input = set_input | |||||
else: | |||||
auto_set_input = False | |||||
if isinstance(set_target, bool): | |||||
auto_set_target = set_target | |||||
else: | |||||
auto_set_target = False | |||||
if isinstance(paths, str): | |||||
if os.path.isdir(paths): | |||||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||||
else: | |||||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||||
else: | |||||
path = paths | |||||
data_info = DataInfo() | |||||
for data_name in path.keys(): | |||||
data_info.datasets[data_name] = self._load(path[data_name]) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if auto_set_input: | |||||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||||
if auto_set_target: | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.set_target(Const.TARGET) | |||||
if to_lower: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||||
is_input=auto_set_input) | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||||
is_input=auto_set_input) | |||||
if bert_tokenizer is not None: | |||||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
# 检查是否存在 | |||||
elif os.path.isdir(bert_tokenizer): | |||||
model_dir = bert_tokenizer | |||||
else: | |||||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||||
lines = f.readlines() | |||||
lines = [line.strip() for line in lines] | |||||
words_vocab.add_word_lst(lines) | |||||
words_vocab.build_vocab() | |||||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if isinstance(concat, bool): | |||||
concat = 'default' if concat else None | |||||
if concat is not None: | |||||
if isinstance(concat, str): | |||||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||||
'default': ['', '<sep>', '', '']} | |||||
if concat.lower() in CONCAT_MAP: | |||||
concat = CONCAT_MAP[concat] | |||||
else: | |||||
concat = 4 * [concat] | |||||
assert len(concat) == 4, \ | |||||
f'Please choose a list with 4 symbols which at the beginning of first sentence ' \ | |||||
f'the end of first sentence, the begin of second sentence, and the end of second' \ | |||||
f'sentence. Your input is {concat}' | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||||
is_input=auto_set_input) | |||||
if seq_len_type is not None: | |||||
if seq_len_type == 'seq_len': # | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'mask': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [1] * len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'bert': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if Const.INPUT not in data_set.get_field_names(): | |||||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||||
f'got {data_set.get_field_names()}') | |||||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||||
if auto_pad_length is not None: | |||||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||||
if cut_text is not None: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
data_set_list = [d for n, d in data_info.datasets.items()] | |||||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||||
if bert_tokenizer is None: | |||||
words_vocab = Vocabulary(padding=auto_pad_token) | |||||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=[n for n in data_set_list[0].get_field_names() | |||||
if (Const.INPUT in n)], | |||||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||||
if 'train' not in n]) | |||||
target_vocab = Vocabulary(padding=None, unknown=None) | |||||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=Const.TARGET) | |||||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||||
if get_index: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||||
is_input=auto_set_input, is_target=auto_set_target) | |||||
if auto_pad_length is not None: | |||||
if seq_len_type == 'seq_len': | |||||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||||
new_field_name=fields, is_input=auto_set_input) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if isinstance(set_input, list): | |||||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||||
if isinstance(set_target, list): | |||||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||||
return data_info | |||||
class SNLILoader(MatchingLoader, JsonLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
paths = paths if paths is not None else { | |||||
'train': 'snli_1.0_train.jsonl', | |||||
'dev': 'snli_1.0_dev.jsonl', | |||||
'test': 'snli_1.0_test.jsonl'} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
JsonLoader.__init__(self, fields=fields) | |||||
def _load(self, path): | |||||
ds = JsonLoader._load(self, path) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | |||||
class RTELoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` | |||||
读取RTE数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv' # test set has not label | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'sentence1': Const.INPUTS(0), | |||||
'sentence2': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds | |||||
class QNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` | |||||
读取QNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv' # test set has not label | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'question': Const.INPUTS(0), | |||||
'sentence': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds | |||||
class MNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev_matched': 'dev_matched.tsv', | |||||
'dev_mismatched': 'dev_mismatched.tsv', | |||||
'test_matched': 'test_matched.tsv', | |||||
'test_mismatched': 'test_mismatched.tsv', | |||||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t') | |||||
self.fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
if ds[0][Const.TARGET] == 'hidden': | |||||
ds.delete_field(Const.TARGET) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | |||||
class QuoraLoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv', | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
return ds |
@@ -1,11 +1,14 @@ | |||||
from typing import Iterable | from typing import Iterable | ||||
from nltk import Tree | from nltk import Tree | ||||
import spacy | |||||
from ..base_loader import DataInfo, DataSetLoader | from ..base_loader import DataInfo, DataSetLoader | ||||
from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ..embed_loader import EmbeddingOption, EmbedLoader | from ..embed_loader import EmbeddingOption, EmbedLoader | ||||
spacy.prefer_gpu() | |||||
sptk = spacy.load('en') | |||||
class SSTLoader(DataSetLoader): | class SSTLoader(DataSetLoader): | ||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | ||||
@@ -56,8 +59,8 @@ class SSTLoader(DataSetLoader): | |||||
def _get_one(data, subtree): | def _get_one(data, subtree): | ||||
tree = Tree.fromstring(data) | tree = Tree.fromstring(data) | ||||
if subtree: | if subtree: | ||||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||||
return [(tree.leaves(), tree.label())] | |||||
return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] | |||||
return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())] | |||||
def process(self, | def process(self, | ||||
paths, | paths, | ||||
@@ -16,8 +16,6 @@ __all__ = [ | |||||
'CSVLoader', | 'CSVLoader', | ||||
'JsonLoader', | 'JsonLoader', | ||||
'ConllLoader', | 'ConllLoader', | ||||
'SNLILoader', | |||||
'SSTLoader', | |||||
'PeopleDailyCorpusLoader', | 'PeopleDailyCorpusLoader', | ||||
'Conll2003Loader', | 'Conll2003Loader', | ||||
] | ] | ||||
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet | |||||
from ..core.instance import Instance | from ..core.instance import Instance | ||||
from .file_reader import _read_csv, _read_json, _read_conll | from .file_reader import _read_csv, _read_json, _read_conll | ||||
from .base_loader import DataSetLoader, DataInfo | from .base_loader import DataSetLoader, DataInfo | ||||
from .data_loader.sst import SSTLoader | |||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..modules.encoder._bert import BertTokenizer | from ..modules.encoder._bert import BertTokenizer | ||||
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
else: | else: | ||||
instance = Instance(words=sent_words) | instance = Instance(words=sent_words) | ||||
data_set.append(instance) | data_set.append(instance) | ||||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") | |||||
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) | |||||
return data_set | return data_set | ||||
@@ -249,42 +246,6 @@ class JsonLoader(DataSetLoader): | |||||
return ds | return ds | ||||
class SNLILoader(JsonLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||||
""" | |||||
def __init__(self): | |||||
fields = { | |||||
'sentence1_parse': Const.INPUTS(0), | |||||
'sentence2_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
super(SNLILoader, self).__init__(fields=fields) | |||||
def _load(self, path): | |||||
ds = super(SNLILoader, self)._load(path) | |||||
def parse_tree(x): | |||||
t = Tree.fromstring(x) | |||||
return t.leaves() | |||||
ds.apply(lambda ins: parse_tree( | |||||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: parse_tree( | |||||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | |||||
class CSVLoader(DataSetLoader): | class CSVLoader(DataSetLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | 别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | ||||
@@ -104,7 +104,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
except Exception as e: | except Exception as e: | ||||
if dropna: | if dropna: | ||||
continue | continue | ||||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||||
raise ValueError('invalid instance ends at line: {}'.format(line_idx)) | |||||
elif line.startswith('#'): | elif line.startswith('#'): | ||||
continue | continue | ||||
else: | else: | ||||
@@ -117,5 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
except Exception as e: | except Exception as e: | ||||
if dropna: | if dropna: | ||||
return | return | ||||
print('invalid instance at line: {}'.format(line_idx)) | |||||
print('invalid instance ends at line: {}'.format(line_idx)) | |||||
raise e | raise e |
@@ -8,35 +8,7 @@ from torch import nn | |||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..modules.encoder import BertModel | from ..modules.encoder import BertModel | ||||
class BertConfig: | |||||
def __init__( | |||||
self, | |||||
vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02 | |||||
): | |||||
self.vocab_size = vocab_size | |||||
self.hidden_size = hidden_size | |||||
self.num_hidden_layers = num_hidden_layers | |||||
self.num_attention_heads = num_attention_heads | |||||
self.intermediate_size = intermediate_size | |||||
self.hidden_act = hidden_act | |||||
self.hidden_dropout_prob = hidden_dropout_prob | |||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||||
self.max_position_embeddings = max_position_embeddings | |||||
self.type_vocab_size = type_vocab_size | |||||
self.initializer_range = initializer_range | |||||
from ..modules.encoder._bert import BertConfig | |||||
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
@@ -84,11 +56,17 @@ class BertForSequenceClassification(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
@classmethod | |||||
def from_pretrained(cls, num_labels, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
pooled_output = self.dropout(pooled_output) | pooled_output = self.dropout(pooled_output) | ||||
@@ -151,11 +129,17 @@ class BertForMultipleChoice(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, 1) | self.classifier = nn.Linear(config.hidden_size, 1) | ||||
@classmethod | |||||
def from_pretrained(cls, num_choices, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | ||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | ||||
@@ -224,11 +208,17 @@ class BertForTokenClassification(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
@classmethod | |||||
def from_pretrained(cls, num_labels, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
sequence_output = self.dropout(sequence_output) | sequence_output = self.dropout(sequence_output) | ||||
@@ -302,12 +292,18 @@ class BertForQuestionAnswering(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | ||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob) | # self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.qa_outputs = nn.Linear(config.hidden_size, 2) | self.qa_outputs = nn.Linear(config.hidden_size, 2) | ||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | ||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
logits = self.qa_outputs(sequence_output) | logits = self.qa_outputs(sequence_output) | ||||
@@ -46,7 +46,7 @@ class StarTransEnc(nn.Module): | |||||
super(StarTransEnc, self).__init__() | super(StarTransEnc, self).__init__() | ||||
self.embedding = get_embeddings(init_embed) | self.embedding = get_embeddings(init_embed) | ||||
emb_dim = self.embedding.embedding_dim | emb_dim = self.embedding.embedding_dim | ||||
self.emb_fc = nn.Linear(emb_dim, hidden_size) | |||||
#self.emb_fc = nn.Linear(emb_dim, hidden_size) | |||||
self.emb_drop = nn.Dropout(emb_dropout) | self.emb_drop = nn.Dropout(emb_dropout) | ||||
self.encoder = StarTransformer(hidden_size=hidden_size, | self.encoder = StarTransformer(hidden_size=hidden_size, | ||||
num_layers=num_layers, | num_layers=num_layers, | ||||
@@ -65,7 +65,7 @@ class StarTransEnc(nn.Module): | |||||
[batch, hidden] 全局 relay 节点, 详见论文 | [batch, hidden] 全局 relay 节点, 详见论文 | ||||
""" | """ | ||||
x = self.embedding(x) | x = self.embedding(x) | ||||
x = self.emb_fc(self.emb_drop(x)) | |||||
#x = self.emb_fc(self.emb_drop(x)) | |||||
nodes, relay = self.encoder(x, mask) | nodes, relay = self.encoder(x, mask) | ||||
return nodes, relay | return nodes, relay | ||||
@@ -205,7 +205,7 @@ class STSeqCls(nn.Module): | |||||
max_len=max_len, | max_len=max_len, | ||||
emb_dropout=emb_dropout, | emb_dropout=emb_dropout, | ||||
dropout=dropout) | dropout=dropout) | ||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | |||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout) | |||||
def forward(self, words, seq_len): | def forward(self, words, seq_len): | ||||
""" | """ | ||||
@@ -15,7 +15,8 @@ class MLP(nn.Module): | |||||
多层感知器 | 多层感知器 | ||||
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 | :param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 | ||||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和sigmoid,默认值为relu | |||||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 | |||||
sigmoid,默认值为relu | |||||
:param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 | :param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 | ||||
:param str initial_method: 参数初始化方式 | :param str initial_method: 参数初始化方式 | ||||
:param float dropout: dropout概率,默认值为0 | :param float dropout: dropout概率,默认值为0 | ||||
@@ -2,7 +2,8 @@ | |||||
""" | """ | ||||
这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码 | |||||
这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 | |||||
有用,也请引用一下他们。 | |||||
""" | """ | ||||
@@ -11,7 +12,6 @@ from ...core.vocabulary import Vocabulary | |||||
import collections | import collections | ||||
import unicodedata | import unicodedata | ||||
from ...io.file_utils import _get_base_url, cached_path | |||||
import numpy as np | import numpy as np | ||||
from itertools import chain | from itertools import chain | ||||
import copy | import copy | ||||
@@ -22,9 +22,106 @@ import os | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
import glob | import glob | ||||
import sys | |||||
CONFIG_FILE = 'bert_config.json' | CONFIG_FILE = 'bert_config.json' | ||||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||||
class BertConfig(object): | |||||
"""Configuration class to store the configuration of a `BertModel`. | |||||
""" | |||||
def __init__(self, | |||||
vocab_size_or_config_json_file, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02, | |||||
layer_norm_eps=1e-12): | |||||
"""Constructs BertConfig. | |||||
Args: | |||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. | |||||
hidden_size: Size of the encoder layers and the pooler layer. | |||||
num_hidden_layers: Number of hidden layers in the Transformer encoder. | |||||
num_attention_heads: Number of attention heads for each attention layer in | |||||
the Transformer encoder. | |||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward) | |||||
layer in the Transformer encoder. | |||||
hidden_act: The non-linear activation function (function or string) in the | |||||
encoder and pooler. If string, "gelu", "relu" and "swish" are supported. | |||||
hidden_dropout_prob: The dropout probabilitiy for all fully connected | |||||
layers in the embeddings, encoder, and pooler. | |||||
attention_probs_dropout_prob: The dropout ratio for the attention | |||||
probabilities. | |||||
max_position_embeddings: The maximum sequence length that this model might | |||||
ever be used with. Typically set this to something large just in case | |||||
(e.g., 512 or 1024 or 2048). | |||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into | |||||
`BertModel`. | |||||
initializer_range: The sttdev of the truncated_normal_initializer for | |||||
initializing all weight matrices. | |||||
layer_norm_eps: The epsilon used by LayerNorm. | |||||
""" | |||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 | |||||
and isinstance(vocab_size_or_config_json_file, unicode)): | |||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: | |||||
json_config = json.loads(reader.read()) | |||||
for key, value in json_config.items(): | |||||
self.__dict__[key] = value | |||||
elif isinstance(vocab_size_or_config_json_file, int): | |||||
self.vocab_size = vocab_size_or_config_json_file | |||||
self.hidden_size = hidden_size | |||||
self.num_hidden_layers = num_hidden_layers | |||||
self.num_attention_heads = num_attention_heads | |||||
self.hidden_act = hidden_act | |||||
self.intermediate_size = intermediate_size | |||||
self.hidden_dropout_prob = hidden_dropout_prob | |||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||||
self.max_position_embeddings = max_position_embeddings | |||||
self.type_vocab_size = type_vocab_size | |||||
self.initializer_range = initializer_range | |||||
self.layer_norm_eps = layer_norm_eps | |||||
else: | |||||
raise ValueError("First argument must be either a vocabulary size (int)" | |||||
"or the path to a pretrained model config file (str)") | |||||
@classmethod | |||||
def from_dict(cls, json_object): | |||||
"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||||
config = BertConfig(vocab_size_or_config_json_file=-1) | |||||
for key, value in json_object.items(): | |||||
config.__dict__[key] = value | |||||
return config | |||||
@classmethod | |||||
def from_json_file(cls, json_file): | |||||
"""Constructs a `BertConfig` from a json file of parameters.""" | |||||
with open(json_file, "r", encoding='utf-8') as reader: | |||||
text = reader.read() | |||||
return cls.from_dict(json.loads(text)) | |||||
def __repr__(self): | |||||
return str(self.to_json_string()) | |||||
def to_dict(self): | |||||
"""Serializes this instance to a Python dictionary.""" | |||||
output = copy.deepcopy(self.__dict__) | |||||
return output | |||||
def to_json_string(self): | |||||
"""Serializes this instance to a JSON string.""" | |||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" | |||||
def to_json_file(self, json_file_path): | |||||
""" Save this instance to a json file.""" | |||||
with open(json_file_path, "w", encoding='utf-8') as writer: | |||||
writer.write(self.to_json_string()) | |||||
def gelu(x): | def gelu(x): | ||||
@@ -40,6 +137,8 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||||
class BertLayerNorm(nn.Module): | class BertLayerNorm(nn.Module): | ||||
def __init__(self, hidden_size, eps=1e-12): | def __init__(self, hidden_size, eps=1e-12): | ||||
"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||||
""" | |||||
super(BertLayerNorm, self).__init__() | super(BertLayerNorm, self).__init__() | ||||
self.weight = nn.Parameter(torch.ones(hidden_size)) | self.weight = nn.Parameter(torch.ones(hidden_size)) | ||||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | self.bias = nn.Parameter(torch.zeros(hidden_size)) | ||||
@@ -53,16 +152,18 @@ class BertLayerNorm(nn.Module): | |||||
class BertEmbeddings(nn.Module): | class BertEmbeddings(nn.Module): | ||||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||||
"""Construct the embeddings from word, position and token_type embeddings. | |||||
""" | |||||
def __init__(self, config): | |||||
super(BertEmbeddings, self).__init__() | super(BertEmbeddings, self).__init__() | ||||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) | |||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) | |||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) | |||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | ||||
# any TensorFlow checkpoint file | # any TensorFlow checkpoint file | ||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
def forward(self, input_ids, token_type_ids=None): | def forward(self, input_ids, token_type_ids=None): | ||||
seq_length = input_ids.size(1) | seq_length = input_ids.size(1) | ||||
@@ -82,21 +183,21 @@ class BertEmbeddings(nn.Module): | |||||
class BertSelfAttention(nn.Module): | class BertSelfAttention(nn.Module): | ||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||||
def __init__(self, config): | |||||
super(BertSelfAttention, self).__init__() | super(BertSelfAttention, self).__init__() | ||||
if hidden_size % num_attention_heads != 0: | |||||
if config.hidden_size % config.num_attention_heads != 0: | |||||
raise ValueError( | raise ValueError( | ||||
"The hidden size (%d) is not a multiple of the number of attention " | "The hidden size (%d) is not a multiple of the number of attention " | ||||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||||
self.num_attention_heads = num_attention_heads | |||||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)) | |||||
self.num_attention_heads = config.num_attention_heads | |||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) | |||||
self.all_head_size = self.num_attention_heads * self.attention_head_size | self.all_head_size = self.num_attention_heads * self.attention_head_size | ||||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||||
self.query = nn.Linear(config.hidden_size, self.all_head_size) | |||||
self.key = nn.Linear(config.hidden_size, self.all_head_size) | |||||
self.value = nn.Linear(config.hidden_size, self.all_head_size) | |||||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |||||
def transpose_for_scores(self, x): | def transpose_for_scores(self, x): | ||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | ||||
@@ -133,11 +234,11 @@ class BertSelfAttention(nn.Module): | |||||
class BertSelfOutput(nn.Module): | class BertSelfOutput(nn.Module): | ||||
def __init__(self, hidden_size, hidden_dropout_prob): | |||||
def __init__(self, config): | |||||
super(BertSelfOutput, self).__init__() | super(BertSelfOutput, self).__init__() | ||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | def forward(self, hidden_states, input_tensor): | ||||
hidden_states = self.dense(hidden_states) | hidden_states = self.dense(hidden_states) | ||||
@@ -147,10 +248,10 @@ class BertSelfOutput(nn.Module): | |||||
class BertAttention(nn.Module): | class BertAttention(nn.Module): | ||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||||
def __init__(self, config): | |||||
super(BertAttention, self).__init__() | super(BertAttention, self).__init__() | ||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||||
self.self = BertSelfAttention(config) | |||||
self.output = BertSelfOutput(config) | |||||
def forward(self, input_tensor, attention_mask): | def forward(self, input_tensor, attention_mask): | ||||
self_output = self.self(input_tensor, attention_mask) | self_output = self.self(input_tensor, attention_mask) | ||||
@@ -159,11 +260,13 @@ class BertAttention(nn.Module): | |||||
class BertIntermediate(nn.Module): | class BertIntermediate(nn.Module): | ||||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||||
def __init__(self, config): | |||||
super(BertIntermediate, self).__init__() | super(BertIntermediate, self).__init__() | ||||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||||
if isinstance(hidden_act, str) else hidden_act | |||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) | |||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): | |||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] | |||||
else: | |||||
self.intermediate_act_fn = config.hidden_act | |||||
def forward(self, hidden_states): | def forward(self, hidden_states): | ||||
hidden_states = self.dense(hidden_states) | hidden_states = self.dense(hidden_states) | ||||
@@ -172,11 +275,11 @@ class BertIntermediate(nn.Module): | |||||
class BertOutput(nn.Module): | class BertOutput(nn.Module): | ||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||||
def __init__(self, config): | |||||
super(BertOutput, self).__init__() | super(BertOutput, self).__init__() | ||||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) | |||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | def forward(self, hidden_states, input_tensor): | ||||
hidden_states = self.dense(hidden_states) | hidden_states = self.dense(hidden_states) | ||||
@@ -186,13 +289,11 @@ class BertOutput(nn.Module): | |||||
class BertLayer(nn.Module): | class BertLayer(nn.Module): | ||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
def __init__(self, config): | |||||
super(BertLayer, self).__init__() | super(BertLayer, self).__init__() | ||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob) | |||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||||
self.attention = BertAttention(config) | |||||
self.intermediate = BertIntermediate(config) | |||||
self.output = BertOutput(config) | |||||
def forward(self, hidden_states, attention_mask): | def forward(self, hidden_states, attention_mask): | ||||
attention_output = self.attention(hidden_states, attention_mask) | attention_output = self.attention(hidden_states, attention_mask) | ||||
@@ -202,13 +303,10 @@ class BertLayer(nn.Module): | |||||
class BertEncoder(nn.Module): | class BertEncoder(nn.Module): | ||||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
def __init__(self, config): | |||||
super(BertEncoder, self).__init__() | super(BertEncoder, self).__init__() | ||||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act) | |||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||||
layer = BertLayer(config) | |||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) | |||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | ||||
all_encoder_layers = [] | all_encoder_layers = [] | ||||
@@ -222,9 +320,9 @@ class BertEncoder(nn.Module): | |||||
class BertPooler(nn.Module): | class BertPooler(nn.Module): | ||||
def __init__(self, hidden_size): | |||||
def __init__(self, config): | |||||
super(BertPooler, self).__init__() | super(BertPooler, self).__init__() | ||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||||
self.activation = nn.Tanh() | self.activation = nn.Tanh() | ||||
def forward(self, hidden_states): | def forward(self, hidden_states): | ||||
@@ -242,13 +340,19 @@ class BertModel(nn.Module): | |||||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | 如果你想使用预训练好的权重矩阵,请在以下网址下载. | ||||
sources:: | sources:: | ||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", | |||||
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin", | |||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", | |||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", | |||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin" | |||||
用预训练权重矩阵来建立BERT模型:: | 用预训练权重矩阵来建立BERT模型:: | ||||
@@ -272,34 +376,30 @@ class BertModel(nn.Module): | |||||
:param int initializer_range: 初始化权重范围,默认值为0.02 | :param int initializer_range: 初始化权重范围,默认值为0.02 | ||||
""" | """ | ||||
def __init__(self, vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02): | |||||
def __init__(self, config, *inputs, **kwargs): | |||||
super(BertModel, self).__init__() | super(BertModel, self).__init__() | ||||
self.hidden_size = hidden_size | |||||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||||
type_vocab_size, hidden_dropout_prob) | |||||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||||
hidden_act) | |||||
self.pooler = BertPooler(hidden_size) | |||||
self.initializer_range = initializer_range | |||||
if not isinstance(config, BertConfig): | |||||
raise ValueError( | |||||
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. " | |||||
"To create a model from a Google pretrained model use " | |||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( | |||||
self.__class__.__name__, self.__class__.__name__ | |||||
)) | |||||
super(BertModel, self).__init__() | |||||
self.config = config | |||||
self.hidden_size = self.config.hidden_size | |||||
self.embeddings = BertEmbeddings(config) | |||||
self.encoder = BertEncoder(config) | |||||
self.pooler = BertPooler(config) | |||||
self.apply(self.init_bert_weights) | self.apply(self.init_bert_weights) | ||||
def init_bert_weights(self, module): | def init_bert_weights(self, module): | ||||
""" Initialize the weights. | |||||
""" | |||||
if isinstance(module, (nn.Linear, nn.Embedding)): | if isinstance(module, (nn.Linear, nn.Embedding)): | ||||
# Slightly different from the TF version which uses truncated_normal for initialization | # Slightly different from the TF version which uses truncated_normal for initialization | ||||
# cf https://github.com/pytorch/pytorch/pull/5617 | # cf https://github.com/pytorch/pytorch/pull/5617 | ||||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) | |||||
elif isinstance(module, BertLayerNorm): | elif isinstance(module, BertLayerNorm): | ||||
module.bias.data.zero_() | module.bias.data.zero_() | ||||
module.weight.data.fill_(1.0) | module.weight.data.fill_(1.0) | ||||
@@ -338,14 +438,19 @@ class BertModel(nn.Module): | |||||
return encoded_layers, pooled_output | return encoded_layers, pooled_output | ||||
@classmethod | @classmethod | ||||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||||
def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs): | |||||
state_dict = kwargs.get('state_dict', None) | |||||
kwargs.pop('state_dict', None) | |||||
cache_dir = kwargs.get('cache_dir', None) | |||||
kwargs.pop('cache_dir', None) | |||||
from_tf = kwargs.get('from_tf', False) | |||||
kwargs.pop('from_tf', None) | |||||
# Load config | # Load config | ||||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | ||||
config = json.load(open(config_file, "r")) | |||||
# config = BertConfig.from_json_file(config_file) | |||||
config = BertConfig.from_json_file(config_file) | |||||
# logger.info("Model config {}".format(config)) | # logger.info("Model config {}".format(config)) | ||||
# Instantiate model. | # Instantiate model. | ||||
model = cls(*inputs, **config, **kwargs) | |||||
model = cls(config, *inputs, **kwargs) | |||||
if state_dict is None: | if state_dict is None: | ||||
files = glob.glob(os.path.join(pretrained_model_dir, '*.bin')) | files = glob.glob(os.path.join(pretrained_model_dir, '*.bin')) | ||||
if len(files)==0: | if len(files)==0: | ||||
@@ -353,7 +458,7 @@ class BertModel(nn.Module): | |||||
elif len(files)>1: | elif len(files)>1: | ||||
raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}") | raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}") | ||||
weights_path = files[0] | weights_path = files[0] | ||||
state_dict = torch.load(weights_path) | |||||
state_dict = torch.load(weights_path, map_location='cpu') | |||||
old_keys = [] | old_keys = [] | ||||
new_keys = [] | new_keys = [] | ||||
@@ -464,6 +569,7 @@ class WordpieceTokenizer(object): | |||||
output_tokens.extend(sub_tokens) | output_tokens.extend(sub_tokens) | ||||
return output_tokens | return output_tokens | ||||
def load_vocab(vocab_file): | def load_vocab(vocab_file): | ||||
"""Loads a vocabulary file into a dictionary.""" | """Loads a vocabulary file into a dictionary.""" | ||||
vocab = collections.OrderedDict() | vocab = collections.OrderedDict() | ||||
@@ -594,6 +700,7 @@ class BasicTokenizer(object): | |||||
output.append(char) | output.append(char) | ||||
return "".join(output) | return "".join(output) | ||||
def _is_whitespace(char): | def _is_whitespace(char): | ||||
"""Checks whether `chars` is a whitespace character.""" | """Checks whether `chars` is a whitespace character.""" | ||||
# \t, \n, and \r are technically contorl characters but we treat them | # \t, \n, and \r are technically contorl characters but we treat them | ||||
@@ -840,6 +947,7 @@ class _WordBertModel(nn.Module): | |||||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) | word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) | ||||
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) | word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) | ||||
attn_masks[i, :len(word_pieces_i)+2].fill_(1) | attn_masks[i, :len(word_pieces_i)+2].fill_(1) | ||||
# TODO 截掉长度超过的部分。 | |||||
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 | # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 | ||||
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | ||||
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, | bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, | ||||
@@ -202,18 +202,12 @@ class StaticEmbedding(TokenEmbedding): | |||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
# 读取embedding | # 读取embedding | ||||
embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, | |||||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, | |||||
normalize=normalize) | normalize=normalize) | ||||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | ||||
padding_idx=vocab.padding_idx, | padding_idx=vocab.padding_idx, | ||||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | max_norm=None, norm_type=2, scale_grad_by_freq=False, | ||||
sparse=False, _weight=embedding) | sparse=False, _weight=embedding) | ||||
if vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk | |||||
words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) | |||||
for word, idx in vocab: | |||||
if vocab._is_word_no_create_entry(word) and not hit_flags[idx]: | |||||
words_to_words[idx] = vocab.unknown_idx | |||||
self.words_to_words = words_to_words | |||||
self._embed_size = self.embedding.weight.size(1) | self._embed_size = self.embedding.weight.size(1) | ||||
self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
@@ -268,10 +262,8 @@ class StaticEmbedding(TokenEmbedding): | |||||
else: | else: | ||||
dim = len(parts) - 1 | dim = len(parts) - 1 | ||||
f.seek(0) | f.seek(0) | ||||
matrix = torch.zeros(len(vocab), dim) | |||||
if init_method is not None: | |||||
init_method(matrix) | |||||
hit_flags = np.zeros(len(vocab), dtype=bool) | |||||
matrix = {} | |||||
found_count = 0 | |||||
for idx, line in enumerate(f, start_idx): | for idx, line in enumerate(f, start_idx): | ||||
try: | try: | ||||
parts = line.strip().split() | parts = line.strip().split() | ||||
@@ -285,28 +277,49 @@ class StaticEmbedding(TokenEmbedding): | |||||
if word in vocab: | if word in vocab: | ||||
index = vocab.to_index(word) | index = vocab.to_index(word) | ||||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | ||||
hit_flags[index] = True | |||||
found_count += 1 | |||||
except Exception as e: | except Exception as e: | ||||
if error == 'ignore': | if error == 'ignore': | ||||
warnings.warn("Error occurred at the {} line.".format(idx)) | warnings.warn("Error occurred at the {} line.".format(idx)) | ||||
else: | else: | ||||
print("Error occurred at the {} line.".format(idx)) | print("Error occurred at the {} line.".format(idx)) | ||||
raise e | raise e | ||||
found_count = sum(hit_flags) | |||||
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | ||||
if init_method is None: | |||||
if len(vocab)-found_count>0 and found_count>0: # 有的没找到 | |||||
found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()] | |||||
mean = found_vecs.mean(dim=0, keepdim=True) | |||||
std = found_vecs.std(dim=0, keepdim=True) | |||||
unfound_vec_num = np.sum(hit_flags==False) | |||||
unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean | |||||
matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs | |||||
for word, index in vocab: | |||||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||||
if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化 | |||||
matrix[index] = matrix[vocab.unknown_idx] | |||||
else: | |||||
matrix[index] = None | |||||
vectors = torch.zeros(len(matrix), dim) | |||||
if init_method: | |||||
init_method(vectors) | |||||
else: | |||||
nn.init.uniform_(vectors, -np.sqrt(3/dim), np.sqrt(3/dim)) | |||||
if vocab._no_create_word_length>0: | |||||
if vocab.unknown is None: # 创建一个专门的unknown | |||||
unknown_idx = len(matrix) | |||||
vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous() | |||||
else: | |||||
unknown_idx = vocab.unknown_idx | |||||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||||
requires_grad=False) | |||||
for order, (index, vec) in enumerate(matrix.items()): | |||||
if vec is not None: | |||||
vectors[order] = vec | |||||
words_to_words[index] = order | |||||
self.words_to_words = words_to_words | |||||
else: | |||||
for index, vec in matrix.items(): | |||||
if vec is not None: | |||||
vectors[index] = vec | |||||
if normalize: | if normalize: | ||||
matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12) | |||||
vectors /= (torch.norm(vectors, dim=1, keepdim=True) + 1e-12) | |||||
return matrix, hit_flags | |||||
return vectors | |||||
def forward(self, words): | def forward(self, words): | ||||
""" | """ | ||||
@@ -35,11 +35,13 @@ class StarTransformer(nn.Module): | |||||
self.iters = num_layers | self.iters = num_layers | ||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | ||||
self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) | |||||
self.emb_drop = nn.Dropout(dropout) | |||||
self.ring_att = nn.ModuleList( | self.ring_att = nn.ModuleList( | ||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||||
for _ in range(self.iters)]) | for _ in range(self.iters)]) | ||||
self.star_att = nn.ModuleList( | self.star_att = nn.ModuleList( | ||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||||
for _ in range(self.iters)]) | for _ in range(self.iters)]) | ||||
if max_len is not None: | if max_len is not None: | ||||
@@ -66,18 +68,19 @@ class StarTransformer(nn.Module): | |||||
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | ||||
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | ||||
if self.pos_emb: | |||||
if self.pos_emb and False: | |||||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | ||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | ||||
embs = embs + P | embs = embs + P | ||||
embs = norm_func(self.emb_drop, embs) | |||||
nodes = embs | nodes = embs | ||||
relay = embs.mean(2, keepdim=True) | relay = embs.mean(2, keepdim=True) | ||||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ||||
r_embs = embs.view(B, H, 1, L) | r_embs = embs.view(B, H, 1, L) | ||||
for i in range(self.iters): | for i in range(self.iters): | ||||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ||||
nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||||
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||||
#nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) | |||||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | ||||
nodes = nodes.masked_fill_(ex_mask, 0) | nodes = nodes.masked_fill_(ex_mask, 0) | ||||
@@ -3,6 +3,8 @@ | |||||
复现的模型有: | 复现的模型有: | ||||
- [Star-Transformer](Star_transformer/) | - [Star-Transformer](Star_transformer/) | ||||
- [Biaffine](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/biaffine_parser.py#L239) | |||||
- [CNNText](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/cnn_text_classification.py#L12) | |||||
- ... | - ... | ||||
# 任务复现 | # 任务复现 | ||||
@@ -11,11 +13,11 @@ | |||||
## Matching (自然语言推理/句子匹配) | ## Matching (自然语言推理/句子匹配) | ||||
- still in progress | |||||
- [Matching 任务复现](matching) | |||||
## Sequence Labeling (序列标注) | ## Sequence Labeling (序列标注) | ||||
- still in progress | |||||
- [NER](seqence_labelling/ner) | |||||
## Coreference resolution (指代消解) | ## Coreference resolution (指代消解) | ||||
@@ -2,7 +2,8 @@ import torch | |||||
import json | import json | ||||
import os | import os | ||||
from fastNLP import Vocabulary | from fastNLP import Vocabulary | ||||
from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader | |||||
from fastNLP.io.dataset_loader import ConllLoader | |||||
from fastNLP.io.data_loader import SSTLoader, SNLILoader | |||||
from fastNLP.core import Const as C | from fastNLP.core import Const as C | ||||
import numpy as np | import numpy as np | ||||
@@ -50,13 +51,15 @@ def load_sst(path, files): | |||||
for sub in [True, False, False]] | for sub in [True, False, False]] | ||||
ds_list = [loader.load(os.path.join(path, fn)) | ds_list = [loader.load(os.path.join(path, fn)) | ||||
for fn, loader in zip(files, loaders)] | for fn, loader in zip(files, loaders)] | ||||
word_v = Vocabulary(min_freq=2) | |||||
word_v = Vocabulary(min_freq=0) | |||||
tag_v = Vocabulary(unknown=None, padding=None) | tag_v = Vocabulary(unknown=None, padding=None) | ||||
for ds in ds_list: | for ds in ds_list: | ||||
ds.apply(lambda x: [w.lower() | ds.apply(lambda x: [w.lower() | ||||
for w in x['words']], new_field_name='words') | for w in x['words']], new_field_name='words') | ||||
ds_list[0].drop(lambda x: len(x['words']) < 3) | |||||
#ds_list[0].drop(lambda x: len(x['words']) < 3) | |||||
update_v(word_v, ds_list[0], 'words') | update_v(word_v, ds_list[0], 'words') | ||||
update_v(word_v, ds_list[1], 'words') | |||||
update_v(word_v, ds_list[2], 'words') | |||||
ds_list[0].apply(lambda x: tag_v.add_word( | ds_list[0].apply(lambda x: tag_v.add_word( | ||||
x['target']), new_field_name=None) | x['target']), new_field_name=None) | ||||
@@ -151,7 +154,10 @@ class EmbedLoader: | |||||
# some words from vocab are missing in pre-trained embedding | # some words from vocab are missing in pre-trained embedding | ||||
# we normally sample each dimension | # we normally sample each dimension | ||||
vocab_embed = embedding_matrix[np.where(hit_flags)] | vocab_embed = embedding_matrix[np.where(hit_flags)] | ||||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||||
#sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||||
# size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||||
sampled_vectors = np.random.uniform(-0.01, 0.01, | |||||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | size=(len(vocab) - np.sum(hit_flags), emb_dim)) | ||||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | ||||
return embedding_matrix | return embedding_matrix |
@@ -1,5 +1,5 @@ | |||||
#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 & | #python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 & | ||||
#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 & | #python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 & | ||||
#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log & | |||||
python -u train.py --task cls --ds sst --mode train --gpu 0 --lr 1e-4 --w_decay 5e-5 --lr_decay 1.0 --drop 0.4 --ep 20 --bsz 64 > sst_cls.log & | |||||
#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log & | #python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log & | ||||
python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & | |||||
#python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & |
@@ -1,4 +1,6 @@ | |||||
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args | from util import get_argparser, set_gpu, set_rng_seeds, add_model_args | ||||
seed = set_rng_seeds(15360) | |||||
print('RNG SEED {}'.format(seed)) | |||||
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN | from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch | import torch | ||||
@@ -7,8 +9,8 @@ import fastNLP as FN | |||||
from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls | from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls | ||||
from fastNLP.core.const import Const as C | from fastNLP.core.const import Const as C | ||||
import sys | import sys | ||||
sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') | |||||
#sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') | |||||
pre_dir = '/home/ec2-user/fast_data/' | |||||
g_model_select = { | g_model_select = { | ||||
'pos': STSeqLabel, | 'pos': STSeqLabel, | ||||
@@ -17,8 +19,8 @@ g_model_select = { | |||||
'nli': STNLICls, | 'nli': STNLICls, | ||||
} | } | ||||
g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt', | |||||
'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'} | |||||
g_emb_file_path = {'en': pre_dir + 'glove.840B.300d.txt', | |||||
'zh': pre_dir + 'cc.zh.300.vec'} | |||||
g_args = None | g_args = None | ||||
g_model_cfg = None | g_model_cfg = None | ||||
@@ -53,7 +55,7 @@ def get_conll2012_ner(): | |||||
def get_sst(): | def get_sst(): | ||||
path = '/remote-home/yfshao/workdir/datasets/SST' | |||||
path = pre_dir + 'sst' | |||||
files = ['train.txt', 'dev.txt', 'test.txt'] | files = ['train.txt', 'dev.txt', 'test.txt'] | ||||
return load_sst(path, files) | return load_sst(path, files) | ||||
@@ -94,6 +96,7 @@ class MyCallback(FN.core.callback.Callback): | |||||
nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0) | nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0) | ||||
def on_step_end(self): | def on_step_end(self): | ||||
return | |||||
warm_steps = 6000 | warm_steps = 6000 | ||||
# learning rate warm-up & decay | # learning rate warm-up & decay | ||||
if self.step <= warm_steps: | if self.step <= warm_steps: | ||||
@@ -108,12 +111,11 @@ class MyCallback(FN.core.callback.Callback): | |||||
def train(): | def train(): | ||||
seed = set_rng_seeds(1234) | |||||
print('RNG SEED {}'.format(seed)) | |||||
print('loading data') | print('loading data') | ||||
ds_list, word_v, tag_v = g_datasets['{}-{}'.format( | ds_list, word_v, tag_v = g_datasets['{}-{}'.format( | ||||
g_args.ds, g_args.task)]() | g_args.ds, g_args.task)]() | ||||
print(ds_list[0][:2]) | print(ds_list[0][:2]) | ||||
print(len(ds_list[0]), len(ds_list[1]), len(ds_list[2])) | |||||
embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en') | embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en') | ||||
g_model_cfg['num_cls'] = len(tag_v) | g_model_cfg['num_cls'] = len(tag_v) | ||||
print(g_model_cfg) | print(g_model_cfg) | ||||
@@ -123,11 +125,14 @@ def train(): | |||||
def init_model(model): | def init_model(model): | ||||
for p in model.parameters(): | for p in model.parameters(): | ||||
if p.size(0) != len(word_v): | if p.size(0) != len(word_v): | ||||
nn.init.normal_(p, 0.0, 0.05) | |||||
if len(p.size())<2: | |||||
nn.init.constant_(p, 0.0) | |||||
else: | |||||
nn.init.normal_(p, 0.0, 0.05) | |||||
init_model(model) | init_model(model) | ||||
train_data = ds_list[0] | train_data = ds_list[0] | ||||
dev_data = ds_list[2] | |||||
test_data = ds_list[1] | |||||
dev_data = ds_list[1] | |||||
test_data = ds_list[2] | |||||
print(tag_v.word2idx) | print(tag_v.word2idx) | ||||
if g_args.task in ['pos', 'ner']: | if g_args.task in ['pos', 'ner']: | ||||
@@ -145,14 +150,26 @@ def train(): | |||||
} | } | ||||
metric_key, metric = metrics[g_args.task] | metric_key, metric = metrics[g_args.task] | ||||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||||
ex_param = [x for x in model.parameters( | |||||
) if x.requires_grad and x.size(0) != len(word_v)] | |||||
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||||
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | |||||
trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss, | |||||
batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric, | |||||
metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False, | |||||
device=device, callbacks=[MyCallback()]) | |||||
params = [(x,y) for x,y in list(model.named_parameters()) if y.requires_grad and y.size(0) != len(word_v)] | |||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] | |||||
print([n for n,p in params]) | |||||
optim_cfg = [ | |||||
#{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||||
{'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 1.0*g_args.w_decay}, | |||||
{'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 0.0*g_args.w_decay} | |||||
] | |||||
print(model) | |||||
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=loss, metrics=metric, metric_key=metric_key, | |||||
optimizer=torch.optim.Adam(optim_cfg), | |||||
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=100, validate_every=1000, | |||||
device=device, | |||||
use_tqdm=False, prefetch=False, | |||||
save_path=g_args.log, | |||||
sampler=FN.BucketSampler(100, g_args.bsz, C.INPUT_LEN), | |||||
callbacks=[MyCallback()]) | |||||
trainer.train() | trainer.train() | ||||
tester = FN.Tester(data=test_data, model=model, metrics=metric, | tester = FN.Tester(data=test_data, model=model, metrics=metric, | ||||
@@ -195,12 +212,12 @@ def main(): | |||||
'init_embed': (None, 300), | 'init_embed': (None, 300), | ||||
'num_cls': None, | 'num_cls': None, | ||||
'hidden_size': g_args.hidden, | 'hidden_size': g_args.hidden, | ||||
'num_layers': 4, | |||||
'num_layers': 2, | |||||
'num_head': g_args.nhead, | 'num_head': g_args.nhead, | ||||
'head_dim': g_args.hdim, | 'head_dim': g_args.hdim, | ||||
'max_len': MAX_LEN, | 'max_len': MAX_LEN, | ||||
'cls_hidden_size': 600, | |||||
'emb_dropout': 0.3, | |||||
'cls_hidden_size': 200, | |||||
'emb_dropout': g_args.drop, | |||||
'dropout': g_args.drop, | 'dropout': g_args.drop, | ||||
} | } | ||||
run_select[g_args.mode.lower()]() | run_select[g_args.mode.lower()]() | ||||
@@ -0,0 +1,68 @@ | |||||
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance | |||||
from fastNLP.io.file_reader import _read_json | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.base_loader import DataInfo | |||||
from reproduction.coreference_resolution.model.config import Config | |||||
import reproduction.coreference_resolution.model.preprocess as preprocess | |||||
class CRLoader(JsonLoader): | |||||
def __init__(self, fields=None, dropna=False): | |||||
super().__init__(fields, dropna) | |||||
def _load(self, path): | |||||
""" | |||||
加载数据 | |||||
:param path: | |||||
:return: | |||||
""" | |||||
dataset = DataSet() | |||||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||||
if self.fields: | |||||
ins = {self.fields[k]: v for k, v in d.items()} | |||||
else: | |||||
ins = d | |||||
dataset.append(Instance(**ins)) | |||||
return dataset | |||||
def process(self, paths, **kwargs): | |||||
data_info = DataInfo() | |||||
for name in ['train', 'test', 'dev']: | |||||
data_info.datasets[name] = self.load(paths[name]) | |||||
config = Config() | |||||
vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences') | |||||
vocab.build_vocab() | |||||
word2id = vocab.word2idx | |||||
char_dict = preprocess.get_char_dict(config.char_path) | |||||
data_info.vocabs = vocab | |||||
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} | |||||
for name, ds in data_info.datasets.items(): | |||||
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||||
config.max_sentences, is_train=name=='train')[0], | |||||
new_field_name='doc_np') | |||||
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||||
config.max_sentences, is_train=name=='train')[1], | |||||
new_field_name='char_index') | |||||
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||||
config.max_sentences, is_train=name=='train')[2], | |||||
new_field_name='seq_len') | |||||
ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'), | |||||
new_field_name='speaker_ids_np') | |||||
ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre') | |||||
ds.set_ignore_type('clusters') | |||||
ds.set_padder('clusters', None) | |||||
ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len") | |||||
ds.set_target("clusters") | |||||
# train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False) | |||||
# train, dev = train_dev.split(343 / (2802 + 343), shuffle=False) | |||||
return data_info | |||||
@@ -0,0 +1,54 @@ | |||||
class Config(): | |||||
def __init__(self): | |||||
self.is_training = True | |||||
# path | |||||
self.glove = 'data/glove.840B.300d.txt.filtered' | |||||
self.turian = 'data/turian.50d.txt' | |||||
self.train_path = "data/train.english.jsonlines" | |||||
self.dev_path = "data/dev.english.jsonlines" | |||||
self.test_path = "data/test.english.jsonlines" | |||||
self.char_path = "data/char_vocab.english.txt" | |||||
self.cuda = "0" | |||||
self.max_word = 1500 | |||||
self.epoch = 200 | |||||
# config | |||||
# self.use_glove = True | |||||
# self.use_turian = True #No | |||||
self.use_elmo = False | |||||
self.use_CNN = True | |||||
self.model_heads = True #Yes | |||||
self.use_width = True # Yes | |||||
self.use_distance = True #Yes | |||||
self.use_metadata = True #Yes | |||||
self.mention_ratio = 0.4 | |||||
self.max_sentences = 50 | |||||
self.span_width = 10 | |||||
self.feature_size = 20 #宽度信息emb的size | |||||
self.lr = 0.001 | |||||
self.lr_decay = 1e-3 | |||||
self.max_antecedents = 100 # 这个参数在mention detection中没有用 | |||||
self.atten_hidden_size = 150 | |||||
self.mention_hidden_size = 150 | |||||
self.sa_hidden_size = 150 | |||||
self.char_emb_size = 8 | |||||
self.filter = [3,4,5] | |||||
# decay = 1e-5 | |||||
def __str__(self): | |||||
d = self.__dict__ | |||||
out = 'config==============\n' | |||||
for i in list(d): | |||||
out += i+":" | |||||
out += str(d[i])+"\n" | |||||
out+="config==============\n" | |||||
return out | |||||
if __name__=="__main__": | |||||
config = Config() | |||||
print(config) |
@@ -0,0 +1,163 @@ | |||||
from fastNLP.core.metrics import MetricBase | |||||
import numpy as np | |||||
from collections import Counter | |||||
from sklearn.utils.linear_assignment_ import linear_assignment | |||||
""" | |||||
Mostly borrowed from https://github.com/clarkkev/deep-coref/blob/master/evaluation.py | |||||
""" | |||||
class CRMetric(MetricBase): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] | |||||
# TODO 改名为evaluate,输入也 | |||||
def evaluate(self, predicted, mention_to_predicted,clusters): | |||||
for e in self.evaluators: | |||||
e.update(predicted,mention_to_predicted, clusters) | |||||
def get_f1(self): | |||||
return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) | |||||
def get_recall(self): | |||||
return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) | |||||
def get_precision(self): | |||||
return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) | |||||
# TODO 原本的getprf | |||||
def get_metric(self,reset=False): | |||||
res = {"pre":self.get_precision(), "rec":self.get_recall(), "f":self.get_f1()} | |||||
self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] | |||||
return res | |||||
class Evaluator(): | |||||
def __init__(self, metric, beta=1): | |||||
self.p_num = 0 | |||||
self.p_den = 0 | |||||
self.r_num = 0 | |||||
self.r_den = 0 | |||||
self.metric = metric | |||||
self.beta = beta | |||||
def update(self, predicted,mention_to_predicted,gold): | |||||
gold = gold[0].tolist() | |||||
gold = [tuple(tuple(m) for m in gc) for gc in gold] | |||||
mention_to_gold = {} | |||||
for gc in gold: | |||||
for mention in gc: | |||||
mention_to_gold[mention] = gc | |||||
if self.metric == ceafe: | |||||
pn, pd, rn, rd = self.metric(predicted, gold) | |||||
else: | |||||
pn, pd = self.metric(predicted, mention_to_gold) | |||||
rn, rd = self.metric(gold, mention_to_predicted) | |||||
self.p_num += pn | |||||
self.p_den += pd | |||||
self.r_num += rn | |||||
self.r_den += rd | |||||
def get_f1(self): | |||||
return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) | |||||
def get_recall(self): | |||||
return 0 if self.r_num == 0 else self.r_num / float(self.r_den) | |||||
def get_precision(self): | |||||
return 0 if self.p_num == 0 else self.p_num / float(self.p_den) | |||||
def get_prf(self): | |||||
return self.get_precision(), self.get_recall(), self.get_f1() | |||||
def get_counts(self): | |||||
return self.p_num, self.p_den, self.r_num, self.r_den | |||||
def b_cubed(clusters, mention_to_gold): | |||||
num, dem = 0, 0 | |||||
for c in clusters: | |||||
if len(c) == 1: | |||||
continue | |||||
gold_counts = Counter() | |||||
correct = 0 | |||||
for m in c: | |||||
if m in mention_to_gold: | |||||
gold_counts[tuple(mention_to_gold[m])] += 1 | |||||
for c2, count in gold_counts.items(): | |||||
if len(c2) != 1: | |||||
correct += count * count | |||||
num += correct / float(len(c)) | |||||
dem += len(c) | |||||
return num, dem | |||||
def muc(clusters, mention_to_gold): | |||||
tp, p = 0, 0 | |||||
for c in clusters: | |||||
p += len(c) - 1 | |||||
tp += len(c) | |||||
linked = set() | |||||
for m in c: | |||||
if m in mention_to_gold: | |||||
linked.add(mention_to_gold[m]) | |||||
else: | |||||
tp -= 1 | |||||
tp -= len(linked) | |||||
return tp, p | |||||
def phi4(c1, c2): | |||||
return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) | |||||
def ceafe(clusters, gold_clusters): | |||||
clusters = [c for c in clusters if len(c) != 1] | |||||
scores = np.zeros((len(gold_clusters), len(clusters))) | |||||
for i in range(len(gold_clusters)): | |||||
for j in range(len(clusters)): | |||||
scores[i, j] = phi4(gold_clusters[i], clusters[j]) | |||||
matching = linear_assignment(-scores) | |||||
similarity = sum(scores[matching[:, 0], matching[:, 1]]) | |||||
return similarity, len(clusters), similarity, len(gold_clusters) | |||||
def lea(clusters, mention_to_gold): | |||||
num, dem = 0, 0 | |||||
for c in clusters: | |||||
if len(c) == 1: | |||||
continue | |||||
common_links = 0 | |||||
all_links = len(c) * (len(c) - 1) / 2.0 | |||||
for i, m in enumerate(c): | |||||
if m in mention_to_gold: | |||||
for m2 in c[i + 1:]: | |||||
if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: | |||||
common_links += 1 | |||||
num += len(c) * common_links / float(all_links) | |||||
dem += len(c) | |||||
return num, dem | |||||
def f1(p_num, p_den, r_num, r_den, beta=1): | |||||
p = 0 if p_den == 0 else p_num / float(p_den) | |||||
r = 0 if r_den == 0 else r_num / float(r_den) | |||||
return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) |
@@ -0,0 +1,576 @@ | |||||
import torch | |||||
import numpy as np | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from allennlp.commands.elmo import ElmoEmbedder | |||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
from reproduction.coreference_resolution.model import preprocess | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
import random | |||||
# 设置seed | |||||
torch.manual_seed(0) # cpu | |||||
torch.cuda.manual_seed(0) # gpu | |||||
np.random.seed(0) # numpy | |||||
random.seed(0) | |||||
class ffnn(nn.Module): | |||||
def __init__(self, input_size, hidden_size, output_size): | |||||
super(ffnn, self).__init__() | |||||
self.f = nn.Sequential( | |||||
# 多少层数 | |||||
nn.Linear(input_size, hidden_size), | |||||
nn.ReLU(inplace=True), | |||||
nn.Dropout(p=0.2), | |||||
nn.Linear(hidden_size, hidden_size), | |||||
nn.ReLU(inplace=True), | |||||
nn.Dropout(p=0.2), | |||||
nn.Linear(hidden_size, output_size) | |||||
) | |||||
self.reset_param() | |||||
def reset_param(self): | |||||
for name, param in self.named_parameters(): | |||||
if param.dim() > 1: | |||||
nn.init.xavier_normal_(param) | |||||
# param.data = torch.tensor(np.random.randn(*param.shape)).float() | |||||
else: | |||||
nn.init.zeros_(param) | |||||
def forward(self, input): | |||||
return self.f(input).squeeze() | |||||
class Model(BaseModel): | |||||
def __init__(self, vocab, config): | |||||
word2id = vocab.word2idx | |||||
super(Model, self).__init__() | |||||
vocab_num = len(word2id) | |||||
self.word2id = word2id | |||||
self.config = config | |||||
self.char_dict = preprocess.get_char_dict('data/char_vocab.english.txt') | |||||
self.genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} | |||||
self.device = torch.device("cuda:" + config.cuda) | |||||
self.emb = nn.Embedding(vocab_num, 350) | |||||
emb1 = EmbedLoader().load_with_vocab(config.glove, vocab,normalize=False) | |||||
emb2 = EmbedLoader().load_with_vocab(config.turian, vocab ,normalize=False) | |||||
pre_emb = np.concatenate((emb1, emb2), axis=1) | |||||
pre_emb /= (np.linalg.norm(pre_emb, axis=1, keepdims=True) + 1e-12) | |||||
if pre_emb is not None: | |||||
self.emb.weight = nn.Parameter(torch.from_numpy(pre_emb).float()) | |||||
for param in self.emb.parameters(): | |||||
param.requires_grad = False | |||||
self.emb_dropout = nn.Dropout(inplace=True) | |||||
if config.use_elmo: | |||||
self.elmo = ElmoEmbedder(options_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_options.json', | |||||
weight_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5', | |||||
cuda_device=int(config.cuda)) | |||||
print("elmo load over.") | |||||
self.elmo_args = torch.randn((3), requires_grad=True).to(self.device) | |||||
self.char_emb = nn.Embedding(len(self.char_dict), config.char_emb_size) | |||||
self.conv1 = nn.Conv1d(config.char_emb_size, 50, 3) | |||||
self.conv2 = nn.Conv1d(config.char_emb_size, 50, 4) | |||||
self.conv3 = nn.Conv1d(config.char_emb_size, 50, 5) | |||||
self.feature_emb = nn.Embedding(config.span_width, config.feature_size) | |||||
self.feature_emb_dropout = nn.Dropout(p=0.2, inplace=True) | |||||
self.mention_distance_emb = nn.Embedding(10, config.feature_size) | |||||
self.distance_drop = nn.Dropout(p=0.2, inplace=True) | |||||
self.genre_emb = nn.Embedding(7, config.feature_size) | |||||
self.speaker_emb = nn.Embedding(2, config.feature_size) | |||||
self.bilstm = VarLSTM(input_size=350+150*config.use_CNN+config.use_elmo*1024,hidden_size=200,bidirectional=True,batch_first=True,hidden_dropout=0.2) | |||||
# self.bilstm = nn.LSTM(input_size=500, hidden_size=200, bidirectional=True, batch_first=True) | |||||
self.h0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device) | |||||
self.c0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device) | |||||
self.bilstm_drop = nn.Dropout(p=0.2, inplace=True) | |||||
self.atten = ffnn(input_size=400, hidden_size=config.atten_hidden_size, output_size=1) | |||||
self.mention_score = ffnn(input_size=1320, hidden_size=config.mention_hidden_size, output_size=1) | |||||
self.sa = ffnn(input_size=3980+40*config.use_metadata, hidden_size=config.sa_hidden_size, output_size=1) | |||||
self.mention_start_np = None | |||||
self.mention_end_np = None | |||||
def _reorder_lstm(self, word_emb, seq_lens): | |||||
sort_ind = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True) | |||||
seq_lens_re = [seq_lens[i] for i in sort_ind] | |||||
emb_seq = self.reorder_sequence(word_emb, sort_ind, batch_first=True) | |||||
packed_seq = nn.utils.rnn.pack_padded_sequence(emb_seq, seq_lens_re, batch_first=True) | |||||
h0 = self.h0.repeat(1, len(seq_lens), 1) | |||||
c0 = self.c0.repeat(1, len(seq_lens), 1) | |||||
packed_out, final_states = self.bilstm(packed_seq, (h0, c0)) | |||||
lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) | |||||
back_map = {ind: i for i, ind in enumerate(sort_ind)} | |||||
reorder_ind = [back_map[i] for i in range(len(seq_lens_re))] | |||||
lstm_out = self.reorder_sequence(lstm_out, reorder_ind, batch_first=True) | |||||
return lstm_out | |||||
def reorder_sequence(self, sequence_emb, order, batch_first=True): | |||||
""" | |||||
sequence_emb: [T, B, D] if not batch_first | |||||
order: list of sequence length | |||||
""" | |||||
batch_dim = 0 if batch_first else 1 | |||||
assert len(order) == sequence_emb.size()[batch_dim] | |||||
order = torch.LongTensor(order) | |||||
order = order.to(sequence_emb).long() | |||||
sorted_ = sequence_emb.index_select(index=order, dim=batch_dim) | |||||
del order | |||||
return sorted_ | |||||
def flat_lstm(self, lstm_out, seq_lens): | |||||
batch = lstm_out.shape[0] | |||||
seq = lstm_out.shape[1] | |||||
dim = lstm_out.shape[2] | |||||
l = [j + i * seq for i, seq_len in enumerate(seq_lens) for j in range(seq_len)] | |||||
flatted = torch.index_select(lstm_out.view(batch * seq, dim), 0, torch.LongTensor(l).to(self.device)) | |||||
return flatted | |||||
def potential_mention_index(self, word_index, max_sent_len): | |||||
# get mention index [3,2]:the first sentence is 3 and secend 2 | |||||
# [0,0,0,1,1] --> [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [3, 3], [3, 4], [4, 4]] (max =2) | |||||
potential_mention = [] | |||||
for i in range(len(word_index)): | |||||
for j in range(i, i + max_sent_len): | |||||
if (j < len(word_index) and word_index[i] == word_index[j]): | |||||
potential_mention.append([i, j]) | |||||
return potential_mention | |||||
def get_mention_start_end(self, seq_lens): | |||||
# 序列长度转换成mention | |||||
# [3,2] --> [0,0,0,1,1] | |||||
word_index = [0] * sum(seq_lens) | |||||
sent_index = 0 | |||||
index = 0 | |||||
for length in seq_lens: | |||||
for l in range(length): | |||||
word_index[index] = sent_index | |||||
index += 1 | |||||
sent_index += 1 | |||||
# [0,0,0,1,1]-->[[0,0],[0,1],[0,2]....] | |||||
mention_id = self.potential_mention_index(word_index, self.config.span_width) | |||||
mention_start = np.array(mention_id, dtype=int)[:, 0] | |||||
mention_end = np.array(mention_id, dtype=int)[:, 1] | |||||
return mention_start, mention_end | |||||
def get_mention_emb(self, flatten_lstm, mention_start, mention_end): | |||||
mention_start_tensor = torch.from_numpy(mention_start).to(self.device) | |||||
mention_end_tensor = torch.from_numpy(mention_end).to(self.device) | |||||
emb_start = flatten_lstm.index_select(dim=0, index=mention_start_tensor) # [mention_num,embed] | |||||
emb_end = flatten_lstm.index_select(dim=0, index=mention_end_tensor) # [mention_num,embed] | |||||
return emb_start, emb_end | |||||
def get_mask(self, mention_start, mention_end): | |||||
# big mask for attention | |||||
mention_num = mention_start.shape[0] | |||||
mask = np.zeros((mention_num, self.config.span_width)) # [mention_num,span_width] | |||||
for i in range(mention_num): | |||||
start = mention_start[i] | |||||
end = mention_end[i] | |||||
# 实际上是宽度 | |||||
for j in range(end - start + 1): | |||||
mask[i][j] = 1 | |||||
mask = torch.from_numpy(mask) # [mention_num,max_mention] | |||||
# 0-->-inf 1-->0 | |||||
log_mask = torch.log(mask) | |||||
return log_mask | |||||
def get_mention_index(self, mention_start, max_mention): | |||||
# TODO 后面可能要改 | |||||
assert len(mention_start.shape) == 1 | |||||
mention_start_tensor = torch.from_numpy(mention_start) | |||||
num_mention = mention_start_tensor.shape[0] | |||||
mention_index = mention_start_tensor.expand(max_mention, num_mention).transpose(0, | |||||
1) # [num_mention,max_mention] | |||||
assert mention_index.shape[0] == num_mention | |||||
assert mention_index.shape[1] == max_mention | |||||
range_add = torch.arange(0, max_mention).expand(num_mention, max_mention).long() # [num_mention,max_mention] | |||||
mention_index = mention_index + range_add | |||||
mention_index = torch.min(mention_index, torch.LongTensor([mention_start[-1]]).expand(num_mention, max_mention)) | |||||
return mention_index.to(self.device) | |||||
def sort_mention(self, mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_lens): | |||||
# 排序记录,高分段在前面 | |||||
mention_score, mention_ids = torch.sort(candidate_mention_score, descending=True) | |||||
preserve_mention_num = int(self.config.mention_ratio * sum(seq_lens)) | |||||
mention_ids = mention_ids[0:preserve_mention_num] | |||||
mention_score = mention_score[0:preserve_mention_num] | |||||
mention_start_tensor = torch.from_numpy(mention_start).to(self.device).index_select(dim=0, | |||||
index=mention_ids) # [lamda*word_num] | |||||
mention_end_tensor = torch.from_numpy(mention_end).to(self.device).index_select(dim=0, | |||||
index=mention_ids) # [lamda*word_num] | |||||
mention_emb = candidate_mention_emb.index_select(index=mention_ids, dim=0) # [lamda*word_num,emb] | |||||
assert mention_score.shape[0] == preserve_mention_num | |||||
assert mention_start_tensor.shape[0] == preserve_mention_num | |||||
assert mention_end_tensor.shape[0] == preserve_mention_num | |||||
assert mention_emb.shape[0] == preserve_mention_num | |||||
# TODO 不交叉没做处理 | |||||
# 对start进行再排序,实际位置在前面 | |||||
# TODO 这里只考虑了start没有考虑end | |||||
mention_start_tensor, temp_index = torch.sort(mention_start_tensor) | |||||
mention_end_tensor = mention_end_tensor.index_select(dim=0, index=temp_index) | |||||
mention_emb = mention_emb.index_select(dim=0, index=temp_index) | |||||
mention_score = mention_score.index_select(dim=0, index=temp_index) | |||||
return mention_start_tensor, mention_end_tensor, mention_score, mention_emb | |||||
def get_antecedents(self, mention_starts, max_antecedents): | |||||
num_mention = mention_starts.shape[0] | |||||
max_antecedents = min(max_antecedents, num_mention) | |||||
# mention和它是第几个mention之间的对应关系 | |||||
antecedents = np.zeros((num_mention, max_antecedents), dtype=int) # [num_mention,max_an] | |||||
# 记录长度 | |||||
antecedents_len = [0] * num_mention | |||||
for i in range(num_mention): | |||||
ante_count = 0 | |||||
for j in range(max(0, i - max_antecedents), i): | |||||
antecedents[i, ante_count] = j | |||||
ante_count += 1 | |||||
# 补位操作 | |||||
for j in range(ante_count, max_antecedents): | |||||
antecedents[i, j] = 0 | |||||
antecedents_len[i] = ante_count | |||||
assert antecedents.shape[1] == max_antecedents | |||||
return antecedents, antecedents_len | |||||
def get_antecedents_score(self, span_represent, mention_score, antecedents, antecedents_len, mention_speakers_ids, | |||||
genre): | |||||
num_mention = mention_score.shape[0] | |||||
max_antecedent = antecedents.shape[1] | |||||
pair_emb = self.get_pair_emb(span_represent, antecedents, mention_speakers_ids, genre) # [span_num,max_ant,emb] | |||||
antecedent_scores = self.sa(pair_emb) | |||||
mask01 = self.sequence_mask(antecedents_len, max_antecedent) | |||||
maskinf = torch.log(mask01).to(self.device) | |||||
assert maskinf.shape[1] <= max_antecedent | |||||
assert antecedent_scores.shape[0] == num_mention | |||||
antecedent_scores = antecedent_scores + maskinf | |||||
antecedents = torch.from_numpy(antecedents).to(self.device) | |||||
mention_scoreij = mention_score.unsqueeze(1) + torch.gather( | |||||
mention_score.unsqueeze(0).expand(num_mention, num_mention), dim=1, index=antecedents) | |||||
antecedent_scores += mention_scoreij | |||||
antecedent_scores = torch.cat([torch.zeros([mention_score.shape[0], 1]).to(self.device), antecedent_scores], | |||||
1) # [num_mentions, max_ant + 1] | |||||
return antecedent_scores | |||||
############################## | |||||
def distance_bin(self, mention_distance): | |||||
bins = torch.zeros(mention_distance.size()).byte().to(self.device) | |||||
rg = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 7], [8, 15], [16, 31], [32, 63], [64, 300]] | |||||
for t, k in enumerate(rg): | |||||
i, j = k[0], k[1] | |||||
b = torch.LongTensor([i]).unsqueeze(-1).expand(mention_distance.size()).to(self.device) | |||||
m1 = torch.ge(mention_distance, b) | |||||
e = torch.LongTensor([j]).unsqueeze(-1).expand(mention_distance.size()).to(self.device) | |||||
m2 = torch.le(mention_distance, e) | |||||
bins = bins + (t + 1) * (m1 & m2) | |||||
return bins.long() | |||||
def get_distance_emb(self, antecedents_tensor): | |||||
num_mention = antecedents_tensor.shape[0] | |||||
max_ant = antecedents_tensor.shape[1] | |||||
assert max_ant <= self.config.max_antecedents | |||||
source = torch.arange(0, num_mention).expand(max_ant, num_mention).transpose(0,1).to(self.device) # [num_mention,max_ant] | |||||
mention_distance = source - antecedents_tensor | |||||
mention_distance_bin = self.distance_bin(mention_distance) | |||||
distance_emb = self.mention_distance_emb(mention_distance_bin) | |||||
distance_emb = self.distance_drop(distance_emb) | |||||
return distance_emb | |||||
def get_pair_emb(self, span_emb, antecedents, mention_speakers_ids, genre): | |||||
emb_dim = span_emb.shape[1] | |||||
num_span = span_emb.shape[0] | |||||
max_ant = antecedents.shape[1] | |||||
assert span_emb.shape[0] == antecedents.shape[0] | |||||
antecedents = torch.from_numpy(antecedents).to(self.device) | |||||
# [num_span,max_ant,emb] | |||||
antecedent_emb = torch.gather(span_emb.unsqueeze(0).expand(num_span, num_span, emb_dim), dim=1, | |||||
index=antecedents.unsqueeze(2).expand(num_span, max_ant, emb_dim)) | |||||
# [num_span,max_ant,emb] | |||||
target_emb_tiled = span_emb.expand((max_ant, num_span, emb_dim)) | |||||
target_emb_tiled = target_emb_tiled.transpose(0, 1) | |||||
similarity_emb = antecedent_emb * target_emb_tiled | |||||
pair_emb_list = [target_emb_tiled, antecedent_emb, similarity_emb] | |||||
# get speakers and genre | |||||
if self.config.use_metadata: | |||||
antecedent_speaker_ids = mention_speakers_ids.unsqueeze(0).expand(num_span, num_span).gather(dim=1, | |||||
index=antecedents) | |||||
same_speaker = torch.eq(mention_speakers_ids.unsqueeze(1).expand(num_span, max_ant), | |||||
antecedent_speaker_ids) # [num_mention,max_ant] | |||||
speaker_embedding = self.speaker_emb(same_speaker.long().to(self.device)) # [mention_num.max_ant,emb] | |||||
genre_embedding = self.genre_emb( | |||||
torch.LongTensor([genre]).expand(num_span, max_ant).to(self.device)) # [mention_num,max_ant,emb] | |||||
pair_emb_list.append(speaker_embedding) | |||||
pair_emb_list.append(genre_embedding) | |||||
# get distance emb | |||||
if self.config.use_distance: | |||||
distance_emb = self.get_distance_emb(antecedents) | |||||
pair_emb_list.append(distance_emb) | |||||
pair_emb = torch.cat(pair_emb_list, 2) | |||||
return pair_emb | |||||
def sequence_mask(self, len_list, max_len): | |||||
x = np.zeros((len(len_list), max_len)) | |||||
for i in range(len(len_list)): | |||||
l = len_list[i] | |||||
for j in range(l): | |||||
x[i][j] = 1 | |||||
return torch.from_numpy(x).float() | |||||
def logsumexp(self, value, dim=None, keepdim=False): | |||||
"""Numerically stable implementation of the operation | |||||
value.exp().sum(dim, keepdim).log() | |||||
""" | |||||
# TODO: torch.max(value, dim=None) threw an error at time of writing | |||||
if dim is not None: | |||||
m, _ = torch.max(value, dim=dim, keepdim=True) | |||||
value0 = value - m | |||||
if keepdim is False: | |||||
m = m.squeeze(dim) | |||||
return m + torch.log(torch.sum(torch.exp(value0), | |||||
dim=dim, keepdim=keepdim)) | |||||
else: | |||||
m = torch.max(value) | |||||
sum_exp = torch.sum(torch.exp(value - m)) | |||||
return m + torch.log(sum_exp) | |||||
def softmax_loss(self, antecedent_scores, antecedent_labels): | |||||
antecedent_labels = torch.from_numpy(antecedent_labels * 1).to(self.device) | |||||
gold_scores = antecedent_scores + torch.log(antecedent_labels.float()) # [num_mentions, max_ant + 1] | |||||
marginalized_gold_scores = self.logsumexp(gold_scores, 1) # [num_mentions] | |||||
log_norm = self.logsumexp(antecedent_scores, 1) # [num_mentions] | |||||
return torch.sum(log_norm - marginalized_gold_scores) # [num_mentions]reduce_logsumexp | |||||
def get_predicted_antecedents(self, antecedents, antecedent_scores): | |||||
predicted_antecedents = [] | |||||
for i, index in enumerate(np.argmax(antecedent_scores.detach(), axis=1) - 1): | |||||
if index < 0: | |||||
predicted_antecedents.append(-1) | |||||
else: | |||||
predicted_antecedents.append(antecedents[i, index]) | |||||
return predicted_antecedents | |||||
def get_predicted_clusters(self, mention_starts, mention_ends, predicted_antecedents): | |||||
mention_to_predicted = {} | |||||
predicted_clusters = [] | |||||
for i, predicted_index in enumerate(predicted_antecedents): | |||||
if predicted_index < 0: | |||||
continue | |||||
assert i > predicted_index | |||||
predicted_antecedent = (int(mention_starts[predicted_index]), int(mention_ends[predicted_index])) | |||||
if predicted_antecedent in mention_to_predicted: | |||||
predicted_cluster = mention_to_predicted[predicted_antecedent] | |||||
else: | |||||
predicted_cluster = len(predicted_clusters) | |||||
predicted_clusters.append([predicted_antecedent]) | |||||
mention_to_predicted[predicted_antecedent] = predicted_cluster | |||||
mention = (int(mention_starts[i]), int(mention_ends[i])) | |||||
predicted_clusters[predicted_cluster].append(mention) | |||||
mention_to_predicted[mention] = predicted_cluster | |||||
predicted_clusters = [tuple(pc) for pc in predicted_clusters] | |||||
mention_to_predicted = {m: predicted_clusters[i] for m, i in mention_to_predicted.items()} | |||||
return predicted_clusters, mention_to_predicted | |||||
def evaluate_coref(self, mention_starts, mention_ends, predicted_antecedents, gold_clusters, evaluator): | |||||
gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters] | |||||
mention_to_gold = {} | |||||
for gc in gold_clusters: | |||||
for mention in gc: | |||||
mention_to_gold[mention] = gc | |||||
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(mention_starts, mention_ends, | |||||
predicted_antecedents) | |||||
evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) | |||||
return predicted_clusters | |||||
def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): | |||||
""" | |||||
实际输入都是tensor | |||||
:param sentences: 句子,被fastNLP转化成了numpy, | |||||
:param doc_np: 被fastNLP转化成了Tensor | |||||
:param speaker_ids_np: 被fastNLP转化成了Tensor | |||||
:param genre: 被fastNLP转化成了Tensor | |||||
:param char_index: 被fastNLP转化成了Tensor | |||||
:param seq_len: 被fastNLP转化成了Tensor | |||||
:return: | |||||
""" | |||||
# change for fastNLP | |||||
sentences = sentences[0].tolist() | |||||
doc_tensor = doc_np[0] | |||||
speakers_tensor = speaker_ids_np[0] | |||||
genre = genre[0].item() | |||||
char_index = char_index[0] | |||||
seq_len = seq_len[0].cpu().numpy() | |||||
# 类型 | |||||
# doc_tensor = torch.from_numpy(doc_np).to(self.device) | |||||
# speakers_tensor = torch.from_numpy(speaker_ids_np).to(self.device) | |||||
mention_emb_list = [] | |||||
word_emb = self.emb(doc_tensor) | |||||
word_emb_list = [word_emb] | |||||
if self.config.use_CNN: | |||||
# [batch, length, char_length, char_dim] | |||||
char = self.char_emb(char_index) | |||||
char_size = char.size() | |||||
# first transform to [batch *length, char_length, char_dim] | |||||
# then transpose to [batch * length, char_dim, char_length] | |||||
char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2) | |||||
# put into cnn [batch*length, char_filters, char_length] | |||||
# then put into maxpooling [batch * length, char_filters] | |||||
char_over_cnn, _ = self.conv1(char).max(dim=2) | |||||
# reshape to [batch, length, char_filters] | |||||
char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) | |||||
word_emb_list.append(char_over_cnn) | |||||
char_over_cnn, _ = self.conv2(char).max(dim=2) | |||||
char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) | |||||
word_emb_list.append(char_over_cnn) | |||||
char_over_cnn, _ = self.conv3(char).max(dim=2) | |||||
char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) | |||||
word_emb_list.append(char_over_cnn) | |||||
# word_emb = torch.cat(word_emb_list, dim=2) | |||||
# use elmo or not | |||||
if self.config.use_elmo: | |||||
# 如果确实被截断了 | |||||
if doc_tensor.shape[0] == 50 and len(sentences) > 50: | |||||
sentences = sentences[0:50] | |||||
elmo_embedding, elmo_mask = self.elmo.batch_to_embeddings(sentences) | |||||
elmo_embedding = elmo_embedding.to( | |||||
self.device) # [sentence_num,max_sent_len,3,1024]--[sentence_num,max_sent,1024] | |||||
elmo_embedding = elmo_embedding[:, 0, :, :] * self.elmo_args[0] + elmo_embedding[:, 1, :, :] * \ | |||||
self.elmo_args[1] + elmo_embedding[:, 2, :, :] * self.elmo_args[2] | |||||
word_emb_list.append(elmo_embedding) | |||||
# print(word_emb_list[0].shape) | |||||
# print(word_emb_list[1].shape) | |||||
# print(word_emb_list[2].shape) | |||||
# print(word_emb_list[3].shape) | |||||
# print(word_emb_list[4].shape) | |||||
word_emb = torch.cat(word_emb_list, dim=2) | |||||
word_emb = self.emb_dropout(word_emb) | |||||
# word_emb_elmo = self.emb_dropout(word_emb_elmo) | |||||
lstm_out = self._reorder_lstm(word_emb, seq_len) | |||||
flatten_lstm = self.flat_lstm(lstm_out, seq_len) # [word_num,emb] | |||||
flatten_lstm = self.bilstm_drop(flatten_lstm) | |||||
# TODO 没有按照论文写 | |||||
flatten_word_emb = self.flat_lstm(word_emb, seq_len) # [word_num,emb] | |||||
mention_start, mention_end = self.get_mention_start_end(seq_len) # [mention_num] | |||||
self.mention_start_np = mention_start # [mention_num] np | |||||
self.mention_end_np = mention_end | |||||
mention_num = mention_start.shape[0] | |||||
emb_start, emb_end = self.get_mention_emb(flatten_lstm, mention_start, mention_end) # [mention_num,emb] | |||||
# list | |||||
mention_emb_list.append(emb_start) | |||||
mention_emb_list.append(emb_end) | |||||
if self.config.use_width: | |||||
mention_width_index = mention_end - mention_start | |||||
mention_width_tensor = torch.from_numpy(mention_width_index).to(self.device) # [mention_num] | |||||
mention_width_emb = self.feature_emb(mention_width_tensor) | |||||
mention_width_emb = self.feature_emb_dropout(mention_width_emb) | |||||
mention_emb_list.append(mention_width_emb) | |||||
if self.config.model_heads: | |||||
mention_index = self.get_mention_index(mention_start, self.config.span_width) # [mention_num,max_mention] | |||||
log_mask_tensor = self.get_mask(mention_start, mention_end).float().to( | |||||
self.device) # [mention_num,max_mention] | |||||
alpha = self.atten(flatten_lstm).to(self.device) # [word_num] | |||||
# 得到attention | |||||
mention_head_score = torch.gather(alpha.expand(mention_num, -1), 1, | |||||
mention_index).float().to(self.device) # [mention_num,max_mention] | |||||
mention_attention = F.softmax(mention_head_score + log_mask_tensor, dim=1) # [mention_num,max_mention] | |||||
# TODO flatte lstm | |||||
word_num = flatten_lstm.shape[0] | |||||
lstm_emb = flatten_lstm.shape[1] | |||||
emb_num = flatten_word_emb.shape[1] | |||||
# [num_mentions, max_mention_width, emb] | |||||
mention_text_emb = torch.gather( | |||||
flatten_word_emb.unsqueeze(1).expand(word_num, self.config.span_width, emb_num), | |||||
0, mention_index.unsqueeze(2).expand(mention_num, self.config.span_width, | |||||
emb_num)) | |||||
# [mention_num,emb] | |||||
mention_head_emb = torch.sum( | |||||
mention_attention.unsqueeze(2).expand(mention_num, self.config.span_width, emb_num) * mention_text_emb, | |||||
dim=1) | |||||
mention_emb_list.append(mention_head_emb) | |||||
candidate_mention_emb = torch.cat(mention_emb_list, 1) # [candidate_mention_num,emb] | |||||
candidate_mention_score = self.mention_score(candidate_mention_emb) # [candidate_mention_num] | |||||
antecedent_scores, antecedents, mention_start_tensor, mention_end_tensor = (None, None, None, None) | |||||
mention_start_tensor, mention_end_tensor, mention_score, mention_emb = \ | |||||
self.sort_mention(mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_len) | |||||
mention_speakers_ids = speakers_tensor.index_select(dim=0, index=mention_start_tensor) # num_mention | |||||
antecedents, antecedents_len = self.get_antecedents(mention_start_tensor, self.config.max_antecedents) | |||||
antecedent_scores = self.get_antecedents_score(mention_emb, mention_score, antecedents, antecedents_len, | |||||
mention_speakers_ids, genre) | |||||
ans = {"candidate_mention_score": candidate_mention_score, "antecedent_scores": antecedent_scores, | |||||
"antecedents": antecedents, "mention_start_tensor": mention_start_tensor, | |||||
"mention_end_tensor": mention_end_tensor} | |||||
return ans | |||||
def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): | |||||
ans = self(sentences, | |||||
doc_np, | |||||
speaker_ids_np, | |||||
genre, | |||||
char_index, | |||||
seq_len) | |||||
predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"]) | |||||
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"], | |||||
ans["mention_end_tensor"], | |||||
predicted_antecedents) | |||||
return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted} | |||||
if __name__ == '__main__': | |||||
pass |
@@ -0,0 +1,225 @@ | |||||
import json | |||||
import numpy as np | |||||
from . import util | |||||
import collections | |||||
def load(path): | |||||
""" | |||||
load the file from jsonline | |||||
:param path: | |||||
:return: examples with many example(dict): {"clusters":[[[mention],[mention]],[another cluster]], | |||||
"doc_key":"str","speakers":[[,,,],[]...],"sentence":[[][]]} | |||||
""" | |||||
with open(path) as f: | |||||
train_examples = [json.loads(jsonline) for jsonline in f.readlines()] | |||||
return train_examples | |||||
def get_vocab(): | |||||
""" | |||||
从所有的句子中得到最终的字典,被main调用,不止是train,还有dev和test | |||||
:param examples: | |||||
:return: word2id & id2word | |||||
""" | |||||
word2id = {'PAD':0,'UNK':1} | |||||
id2word = {0:'PAD',1:'UNK'} | |||||
index = 2 | |||||
data = [load("../data/train.english.jsonlines"),load("../data/dev.english.jsonlines"),load("../data/test.english.jsonlines")] | |||||
for examples in data: | |||||
for example in examples: | |||||
for sent in example["sentences"]: | |||||
for word in sent: | |||||
if(word not in word2id): | |||||
word2id[word]=index | |||||
id2word[index] = word | |||||
index += 1 | |||||
return word2id,id2word | |||||
def normalize(v): | |||||
norm = np.linalg.norm(v) | |||||
if norm > 0: | |||||
return v / norm | |||||
else: | |||||
return v | |||||
# 加载glove得到embedding | |||||
def get_emb(id2word,embedding_size): | |||||
glove_oov = 0 | |||||
turian_oov = 0 | |||||
both = 0 | |||||
glove_emb_path = "../data/glove.840B.300d.txt.filtered" | |||||
turian_emb_path = "../data/turian.50d.txt" | |||||
word_num = len(id2word) | |||||
emb = np.zeros((word_num,embedding_size)) | |||||
glove_emb_dict = util.load_embedding_dict(glove_emb_path,300,"txt") | |||||
turian_emb_dict = util.load_embedding_dict(turian_emb_path,50,"txt") | |||||
for i in range(word_num): | |||||
if id2word[i] in glove_emb_dict: | |||||
word_embedding = glove_emb_dict.get(id2word[i]) | |||||
emb[i][0:300] = np.array(word_embedding) | |||||
else: | |||||
# print(id2word[i]) | |||||
glove_oov += 1 | |||||
if id2word[i] in turian_emb_dict: | |||||
word_embedding = turian_emb_dict.get(id2word[i]) | |||||
emb[i][300:350] = np.array(word_embedding) | |||||
else: | |||||
# print(id2word[i]) | |||||
turian_oov += 1 | |||||
if id2word[i] not in glove_emb_dict and id2word[i] not in turian_emb_dict: | |||||
both += 1 | |||||
emb[i] = normalize(emb[i]) | |||||
print("embedding num:"+str(word_num)) | |||||
print("glove num:"+str(glove_oov)) | |||||
print("glove oov rate:"+str(glove_oov/word_num)) | |||||
print("turian num:"+str(turian_oov)) | |||||
print("turian oov rate:"+str(turian_oov/word_num)) | |||||
print("both num:"+str(both)) | |||||
return emb | |||||
def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train): | |||||
max_len = 0 | |||||
max_word_length = 0 | |||||
docvex = [] | |||||
length = [] | |||||
if is_train: | |||||
sent_num = min(max_sentences,len(doc)) | |||||
else: | |||||
sent_num = len(doc) | |||||
for i in range(sent_num): | |||||
sent = doc[i] | |||||
length.append(len(sent)) | |||||
if (len(sent) > max_len): | |||||
max_len = len(sent) | |||||
sent_vec =[] | |||||
for j,word in enumerate(sent): | |||||
if len(word)>max_word_length: | |||||
max_word_length = len(word) | |||||
if word in word2id: | |||||
sent_vec.append(word2id[word]) | |||||
else: | |||||
sent_vec.append(word2id["UNK"]) | |||||
docvex.append(sent_vec) | |||||
char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int) | |||||
for i in range(sent_num): | |||||
sent = doc[i] | |||||
for j,word in enumerate(sent): | |||||
char_index[i, j, :len(word)] = [char_dict[c] for c in word] | |||||
return docvex,char_index,length,max_len | |||||
# TODO 修改了接口,确认所有该修改的地方都修改好 | |||||
def doc2numpy(doc,word2id,chardict,max_filter,max_sentences,is_train): | |||||
docvec, char_index, length, max_len = _doc2vec(doc,word2id,chardict,max_filter,max_sentences,is_train) | |||||
assert max(length) == max_len | |||||
assert char_index.shape[0]==len(length) | |||||
assert char_index.shape[1]==max_len | |||||
doc_np = np.zeros((len(docvec), max_len), int) | |||||
for i in range(len(docvec)): | |||||
for j in range(len(docvec[i])): | |||||
doc_np[i][j] = docvec[i][j] | |||||
return doc_np,char_index,length | |||||
# TODO 没有测试 | |||||
def speaker2numpy(speakers_raw,max_sentences,is_train): | |||||
if is_train and len(speakers_raw)> max_sentences: | |||||
speakers_raw = speakers_raw[0:max_sentences] | |||||
speakers = flatten(speakers_raw) | |||||
speaker_dict = {s: i for i, s in enumerate(set(speakers))} | |||||
speaker_ids = np.array([speaker_dict[s] for s in speakers]) | |||||
return speaker_ids | |||||
def flat_cluster(clusters): | |||||
flatted = [] | |||||
for cluster in clusters: | |||||
for item in cluster: | |||||
flatted.append(item) | |||||
return flatted | |||||
def get_right_mention(clusters,mention_start_np,mention_end_np): | |||||
flatted = flat_cluster(clusters) | |||||
cluster_num = len(flatted) | |||||
mention_num = mention_start_np.shape[0] | |||||
right_mention = np.zeros(mention_num,dtype=int) | |||||
for i in range(mention_num): | |||||
if [mention_start_np[i],mention_end_np[i]] in flatted: | |||||
right_mention[i]=1 | |||||
return right_mention,cluster_num | |||||
def handle_cluster(clusters): | |||||
gold_mentions = sorted(tuple(m) for m in flatten(clusters)) | |||||
gold_mention_map = {m: i for i, m in enumerate(gold_mentions)} | |||||
cluster_ids = np.zeros(len(gold_mentions), dtype=int) | |||||
for cluster_id, cluster in enumerate(clusters): | |||||
for mention in cluster: | |||||
cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id | |||||
gold_starts, gold_ends = tensorize_mentions(gold_mentions) | |||||
return cluster_ids, gold_starts, gold_ends | |||||
# 展平 | |||||
def flatten(l): | |||||
return [item for sublist in l for item in sublist] | |||||
# 把mention分成start end | |||||
def tensorize_mentions(mentions): | |||||
if len(mentions) > 0: | |||||
starts, ends = zip(*mentions) | |||||
else: | |||||
starts, ends = [], [] | |||||
return np.array(starts), np.array(ends) | |||||
def get_char_dict(path): | |||||
vocab = ["<UNK>"] | |||||
with open(path) as f: | |||||
vocab.extend(c.strip() for c in f.readlines()) | |||||
char_dict = collections.defaultdict(int) | |||||
char_dict.update({c: i for i, c in enumerate(vocab)}) | |||||
return char_dict | |||||
def get_labels(clusters,mention_starts,mention_ends,max_antecedents): | |||||
cluster_ids, gold_starts, gold_ends = handle_cluster(clusters) | |||||
num_mention = mention_starts.shape[0] | |||||
num_gold = gold_starts.shape[0] | |||||
max_antecedents = min(max_antecedents, num_mention) | |||||
mention_indices = {} | |||||
for i in range(num_mention): | |||||
mention_indices[(mention_starts[i].detach().item(), mention_ends[i].detach().item())] = i | |||||
# 用来记录哪些mention是对的,-1表示错误,正数代表这个mention实际上对应哪个gold cluster的id | |||||
mention_cluster_ids = [-1] * num_mention | |||||
# test | |||||
right_mention_count = 0 | |||||
for i in range(num_gold): | |||||
right_mention = mention_indices.get((gold_starts[i], gold_ends[i])) | |||||
if (right_mention != None): | |||||
right_mention_count += 1 | |||||
mention_cluster_ids[right_mention] = cluster_ids[i] | |||||
# i j 是否属于同一个cluster | |||||
labels = np.zeros((num_mention, max_antecedents + 1), dtype=bool) # [num_mention,max_an+1] | |||||
for i in range(num_mention): | |||||
ante_count = 0 | |||||
null_label = True | |||||
for j in range(max(0, i - max_antecedents), i): | |||||
if (mention_cluster_ids[i] >= 0 and mention_cluster_ids[i] == mention_cluster_ids[j]): | |||||
labels[i, ante_count + 1] = True | |||||
null_label = False | |||||
else: | |||||
labels[i, ante_count + 1] = False | |||||
ante_count += 1 | |||||
for j in range(ante_count, max_antecedents): | |||||
labels[i, j + 1] = False | |||||
labels[i, 0] = null_label | |||||
return labels | |||||
# test=========================== | |||||
if __name__=="__main__": | |||||
word2id,id2word = get_vocab() | |||||
get_emb(id2word,350) | |||||
@@ -0,0 +1,32 @@ | |||||
from fastNLP.core.losses import LossBase | |||||
from reproduction.coreference_resolution.model.preprocess import get_labels | |||||
from reproduction.coreference_resolution.model.config import Config | |||||
import torch | |||||
class SoftmaxLoss(LossBase): | |||||
""" | |||||
交叉熵loss | |||||
允许多标签分类 | |||||
""" | |||||
def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None): | |||||
""" | |||||
:param pred: | |||||
:param target: | |||||
""" | |||||
super().__init__() | |||||
self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters, | |||||
mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor) | |||||
def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor): | |||||
antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor, | |||||
Config().max_antecedents) | |||||
antecedent_labels = torch.from_numpy(antecedent_labels*1).to(torch.device("cuda:" + Config().cuda)) | |||||
gold_scores = antecedent_scores + torch.log(antecedent_labels.float()).to(torch.device("cuda:" + Config().cuda)) # [num_mentions, max_ant + 1] | |||||
marginalized_gold_scores = gold_scores.logsumexp(dim=1) # [num_mentions] | |||||
log_norm = antecedent_scores.logsumexp(dim=1) # [num_mentions] | |||||
return torch.sum(log_norm - marginalized_gold_scores) |
@@ -0,0 +1,101 @@ | |||||
import os | |||||
import errno | |||||
import collections | |||||
import torch | |||||
import numpy as np | |||||
import pyhocon | |||||
# flatten the list | |||||
def flatten(l): | |||||
return [item for sublist in l for item in sublist] | |||||
def get_config(filename): | |||||
return pyhocon.ConfigFactory.parse_file(filename) | |||||
# safe make directions | |||||
def mkdirs(path): | |||||
try: | |||||
os.makedirs(path) | |||||
except OSError as exception: | |||||
if exception.errno != errno.EEXIST: | |||||
raise | |||||
return path | |||||
def load_char_dict(char_vocab_path): | |||||
vocab = ["<unk>"] | |||||
with open(char_vocab_path) as f: | |||||
vocab.extend(c.strip() for c in f.readlines()) | |||||
char_dict = collections.defaultdict(int) | |||||
char_dict.update({c: i for i, c in enumerate(vocab)}) | |||||
return char_dict | |||||
# 加载embedding | |||||
def load_embedding_dict(embedding_path, embedding_size, embedding_format): | |||||
print("Loading word embeddings from {}...".format(embedding_path)) | |||||
default_embedding = np.zeros(embedding_size) | |||||
embedding_dict = collections.defaultdict(lambda: default_embedding) | |||||
skip_first = embedding_format == "vec" | |||||
with open(embedding_path) as f: | |||||
for i, line in enumerate(f.readlines()): | |||||
if skip_first and i == 0: | |||||
continue | |||||
splits = line.split() | |||||
assert len(splits) == embedding_size + 1 | |||||
word = splits[0] | |||||
embedding = np.array([float(s) for s in splits[1:]]) | |||||
embedding_dict[word] = embedding | |||||
print("Done loading word embeddings.") | |||||
return embedding_dict | |||||
# safe devide | |||||
def maybe_divide(x, y): | |||||
return 0 if y == 0 else x / float(y) | |||||
def shape(x, dim): | |||||
return x.get_shape()[dim].value or torch.shape(x)[dim] | |||||
def normalize(v): | |||||
norm = np.linalg.norm(v) | |||||
if norm > 0: | |||||
return v / norm | |||||
else: | |||||
return v | |||||
class RetrievalEvaluator(object): | |||||
def __init__(self): | |||||
self._num_correct = 0 | |||||
self._num_gold = 0 | |||||
self._num_predicted = 0 | |||||
def update(self, gold_set, predicted_set): | |||||
self._num_correct += len(gold_set & predicted_set) | |||||
self._num_gold += len(gold_set) | |||||
self._num_predicted += len(predicted_set) | |||||
def recall(self): | |||||
return maybe_divide(self._num_correct, self._num_gold) | |||||
def precision(self): | |||||
return maybe_divide(self._num_correct, self._num_predicted) | |||||
def metrics(self): | |||||
recall = self.recall() | |||||
precision = self.precision() | |||||
f1 = maybe_divide(2 * recall * precision, precision + recall) | |||||
return recall, precision, f1 | |||||
if __name__=="__main__": | |||||
print(load_char_dict("../data/char_vocab.english.txt")) | |||||
embedding_dict = load_embedding_dict("../data/glove.840B.300d.txt.filtered",300,"txt") | |||||
print("hello") |
@@ -0,0 +1,49 @@ | |||||
# 共指消解复现 | |||||
## 介绍 | |||||
Coreference resolution是查找文本中指向同一现实实体的所有表达式的任务。 | |||||
对于涉及自然语言理解的许多更高级别的NLP任务来说, | |||||
这是一个重要的步骤,例如文档摘要,问题回答和信息提取。 | |||||
代码的实现主要基于[ End-to-End Coreference Resolution (Lee et al, 2017)](https://arxiv.org/pdf/1707.07045). | |||||
## 数据获取与预处理 | |||||
论文在[OntoNote5.0](https://allennlp.org/models)数据集上取得了当时的sota结果。 | |||||
由于版权问题,本文无法提供数据集的下载,请自行下载。 | |||||
原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。 | |||||
代码实现采用了论文作者Lee的预处理方法,具体细节参加[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 | |||||
处理之后的数据集为json格式,例子: | |||||
``` | |||||
{ | |||||
"clusters": [], | |||||
"doc_key": "nw", | |||||
"sentences": [["This", "is", "the", "first", "sentence", "."], ["This", "is", "the", "second", "."]], | |||||
"speakers": [["spk1", "spk1", "spk1", "spk1", "spk1", "spk1"], ["spk2", "spk2", "spk2", "spk2", "spk2"]] | |||||
} | |||||
``` | |||||
### embedding 数据集下载 | |||||
[turian emdedding](https://lil.cs.washington.edu/coref/turian.50d.txt) | |||||
[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip) | |||||
## 运行 | |||||
```python | |||||
# 训练代码 | |||||
CUDA_VISIBLE_DEVICES=0 python train.py | |||||
# 测试代码 | |||||
CUDA_VISIBLE_DEVICES=0 python valid.py | |||||
``` | |||||
## 结果 | |||||
原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。 | |||||
其中allenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 | |||||
在与allenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 | |||||
## 问题 | |||||
如果您有什么问题或者反馈,请提issue或者邮件联系我: | |||||
yexu_i@qq.com |
@@ -0,0 +1,14 @@ | |||||
import unittest | |||||
from ..data_load.cr_loader import CRLoader | |||||
class Test_CRLoader(unittest.TestCase): | |||||
def test_cr_loader(self): | |||||
train_path = 'data/train.english.jsonlines.mini' | |||||
dev_path = 'data/dev.english.jsonlines.minid' | |||||
test_path = 'data/test.english.jsonlines' | |||||
cr = CRLoader() | |||||
data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path}) | |||||
print(data_info.datasets['train'][0]) | |||||
print(data_info.datasets['dev'][0]) | |||||
print(data_info.datasets['test'][0]) |
@@ -0,0 +1,69 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
import torch | |||||
from torch.optim import Adam | |||||
from fastNLP.core.callback import Callback, GradientClipCallback | |||||
from fastNLP.core.trainer import Trainer | |||||
from reproduction.coreference_resolution.data_load.cr_loader import CRLoader | |||||
from reproduction.coreference_resolution.model.config import Config | |||||
from reproduction.coreference_resolution.model.model_re import Model | |||||
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss | |||||
from reproduction.coreference_resolution.model.metric import CRMetric | |||||
from fastNLP import SequentialSampler | |||||
from fastNLP import cache_results | |||||
# torch.backends.cudnn.benchmark = False | |||||
# torch.backends.cudnn.deterministic = True | |||||
class LRCallback(Callback): | |||||
def __init__(self, parameters, decay_rate=1e-3): | |||||
super().__init__() | |||||
self.paras = parameters | |||||
self.decay_rate = decay_rate | |||||
def on_step_end(self): | |||||
if self.step % 100 == 0: | |||||
for para in self.paras: | |||||
para['lr'] = para['lr'] * (1 - self.decay_rate) | |||||
if __name__ == "__main__": | |||||
config = Config() | |||||
print(config) | |||||
@cache_results('cache.pkl') | |||||
def cache(): | |||||
cr_train_dev_test = CRLoader() | |||||
data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path, | |||||
'test': config.test_path}) | |||||
return data_info | |||||
data_info = cache() | |||||
print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])), | |||||
"\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"]))) | |||||
# print(data_info) | |||||
model = Model(data_info.vocabs, config) | |||||
print(model) | |||||
loss = SoftmaxLoss() | |||||
metric = CRMetric() | |||||
optim = Adam(model.parameters(), lr=config.lr) | |||||
lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay) | |||||
trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"], | |||||
loss=loss, metrics=metric, check_code_level=-1,sampler=None, | |||||
batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, | |||||
optimizer=optim, | |||||
save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save', | |||||
callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) | |||||
print() | |||||
trainer.train() |
@@ -0,0 +1,24 @@ | |||||
import torch | |||||
from reproduction.coreference_resolution.model.config import Config | |||||
from reproduction.coreference_resolution.model.metric import CRMetric | |||||
from reproduction.coreference_resolution.data_load.cr_loader import CRLoader | |||||
from fastNLP import Tester | |||||
import argparse | |||||
if __name__=='__main__': | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--path') | |||||
args = parser.parse_args() | |||||
cr_loader = CRLoader() | |||||
config = Config() | |||||
data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path, | |||||
'test': config.test_path}) | |||||
metirc = CRMetric() | |||||
model = torch.load(args.path) | |||||
tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0") | |||||
tester.test() | |||||
print('test over') | |||||
@@ -0,0 +1,284 @@ | |||||
from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||||
from fastNLP.io.dataset_loader import ConllLoader | |||||
import numpy as np | |||||
from itertools import chain | |||||
from fastNLP import DataSet, Vocabulary | |||||
from functools import partial | |||||
import os | |||||
from typing import Union, Dict | |||||
from reproduction.utils import check_dataloader_paths | |||||
class CTBxJointLoader(DataSetLoader): | |||||
""" | |||||
文件夹下应该具有以下的文件结构 | |||||
-train.conllx | |||||
-dev.conllx | |||||
-test.conllx | |||||
每个文件中的内容如下(空格隔开不同的句子, 共有) | |||||
1 费孝通 _ NR NR _ 3 nsubjpass _ _ | |||||
2 被 _ SB SB _ 3 pass _ _ | |||||
3 授予 _ VV VV _ 0 root _ _ | |||||
4 麦格赛赛 _ NR NR _ 5 nn _ _ | |||||
5 奖 _ NN NN _ 3 dobj _ _ | |||||
1 新华社 _ NR NR _ 7 dep _ _ | |||||
2 马尼拉 _ NR NR _ 7 dep _ _ | |||||
3 8月 _ NT NT _ 7 dep _ _ | |||||
4 31日 _ NT NT _ 7 dep _ _ | |||||
... | |||||
""" | |||||
def __init__(self): | |||||
self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7]) | |||||
def load(self, path:str): | |||||
""" | |||||
给定一个文件路径,将数据读取为DataSet格式。DataSet中包含以下的内容 | |||||
words: list[str] | |||||
pos_tags: list[str] | |||||
heads: list[int] | |||||
labels: list[str] | |||||
:param path: | |||||
:return: | |||||
""" | |||||
dataset = self._loader.load(path) | |||||
dataset.heads.int() | |||||
return dataset | |||||
def process(self, paths): | |||||
""" | |||||
:param paths: | |||||
:return: | |||||
Dataset包含以下的field | |||||
chars: | |||||
bigrams: | |||||
trigrams: | |||||
pre_chars: | |||||
pre_bigrams: | |||||
pre_trigrams: | |||||
seg_targets: | |||||
seg_masks: | |||||
seq_lens: | |||||
char_labels: | |||||
char_heads: | |||||
gold_word_pairs: | |||||
seg_targets: | |||||
seg_masks: | |||||
char_labels: | |||||
char_heads: | |||||
pun_masks: | |||||
gold_label_word_pairs: | |||||
""" | |||||
paths = check_dataloader_paths(paths) | |||||
data = DataInfo() | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path) | |||||
data.datasets[name] = dataset | |||||
char_labels_vocab = Vocabulary(padding=None, unknown=None) | |||||
def process(dataset, char_label_vocab): | |||||
dataset.apply(add_word_lst, new_field_name='word_lst') | |||||
dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars') | |||||
dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams') | |||||
dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams') | |||||
dataset.apply(add_char_heads, new_field_name='char_heads') | |||||
dataset.apply(add_char_labels, new_field_name='char_labels') | |||||
dataset.apply(add_segs, new_field_name='seg_targets') | |||||
dataset.apply(add_mask, new_field_name='seg_masks') | |||||
dataset.add_seq_len('chars', new_field_name='seq_lens') | |||||
dataset.apply(add_pun_masks, new_field_name='pun_masks') | |||||
if len(char_label_vocab.word_count)==0: | |||||
char_label_vocab.from_dataset(dataset, field_name='char_labels') | |||||
char_label_vocab.index_dataset(dataset, field_name='char_labels') | |||||
new_dataset = add_root(dataset) | |||||
new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True) | |||||
global add_label_word_pairs | |||||
add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab) | |||||
new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True) | |||||
new_dataset.set_pad_val('char_labels', -1) | |||||
new_dataset.set_pad_val('char_heads', -1) | |||||
return new_dataset | |||||
for name in list(paths.keys()): | |||||
dataset = data.datasets[name] | |||||
dataset = process(dataset, char_labels_vocab) | |||||
data.datasets[name] = dataset | |||||
data.vocabs['char_labels'] = char_labels_vocab | |||||
char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars') | |||||
bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams') | |||||
trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams') | |||||
for name in ['chars', 'bigrams', 'trigrams']: | |||||
vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values())) | |||||
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name) | |||||
data.vocabs['pre_{}'.format(name)] = vocab | |||||
for name, vocab in zip(['chars', 'bigrams', 'trigrams'], | |||||
[char_vocab, bigram_vocab, trigram_vocab]): | |||||
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name) | |||||
data.vocabs[name] = vocab | |||||
for name, dataset in data.datasets.items(): | |||||
dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars', | |||||
'pre_bigrams', 'pre_trigrams') | |||||
dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels', | |||||
'char_heads', | |||||
'pun_masks', 'gold_label_word_pairs') | |||||
return data | |||||
def add_label_word_pairs(instance, label_vocab): | |||||
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]] | |||||
word_end_indexes = np.array(list(map(len, instance['word_lst']))) | |||||
word_end_indexes = np.cumsum(word_end_indexes).tolist() | |||||
word_end_indexes.insert(0, 0) | |||||
word_pairs = [] | |||||
labels = instance['labels'] | |||||
pos_tags = instance['pos_tags'] | |||||
for idx, head in enumerate(instance['heads']): | |||||
if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 | |||||
continue | |||||
label = label_vocab.to_index(labels[idx]) | |||||
if head==0: | |||||
word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1])))) | |||||
else: | |||||
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label, | |||||
(word_end_indexes[idx], word_end_indexes[idx + 1]))) | |||||
return word_pairs | |||||
def add_word_pairs(instance): | |||||
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]] | |||||
word_end_indexes = np.array(list(map(len, instance['word_lst']))) | |||||
word_end_indexes = np.cumsum(word_end_indexes).tolist() | |||||
word_end_indexes.insert(0, 0) | |||||
word_pairs = [] | |||||
pos_tags = instance['pos_tags'] | |||||
for idx, head in enumerate(instance['heads']): | |||||
if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 | |||||
continue | |||||
if head==0: | |||||
word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1])))) | |||||
else: | |||||
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), | |||||
(word_end_indexes[idx], word_end_indexes[idx + 1]))) | |||||
return word_pairs | |||||
def add_root(dataset): | |||||
new_dataset = DataSet() | |||||
for sample in dataset: | |||||
chars = ['char_root'] + sample['chars'] | |||||
bigrams = ['bigram_root'] + sample['bigrams'] | |||||
trigrams = ['trigram_root'] + sample['trigrams'] | |||||
seq_lens = sample['seq_lens']+1 | |||||
char_labels = [0] + sample['char_labels'] | |||||
char_heads = [0] + sample['char_heads'] | |||||
sample['chars'] = chars | |||||
sample['bigrams'] = bigrams | |||||
sample['trigrams'] = trigrams | |||||
sample['seq_lens'] = seq_lens | |||||
sample['char_labels'] = char_labels | |||||
sample['char_heads'] = char_heads | |||||
new_dataset.append(sample) | |||||
return new_dataset | |||||
def add_pun_masks(instance): | |||||
tags = instance['pos_tags'] | |||||
pun_masks = [] | |||||
for word, tag in zip(instance['words'], tags): | |||||
if tag=='PU': | |||||
pun_masks.extend([1]*len(word)) | |||||
else: | |||||
pun_masks.extend([0]*len(word)) | |||||
return pun_masks | |||||
def add_word_lst(instance): | |||||
words = instance['words'] | |||||
word_lst = [list(word) for word in words] | |||||
return word_lst | |||||
def add_bigram(instance): | |||||
chars = instance['chars'] | |||||
length = len(chars) | |||||
chars = chars + ['<eos>'] | |||||
bigrams = [] | |||||
for i in range(length): | |||||
bigrams.append(''.join(chars[i:i + 2])) | |||||
return bigrams | |||||
def add_trigram(instance): | |||||
chars = instance['chars'] | |||||
length = len(chars) | |||||
chars = chars + ['<eos>'] * 2 | |||||
trigrams = [] | |||||
for i in range(length): | |||||
trigrams.append(''.join(chars[i:i + 3])) | |||||
return trigrams | |||||
def add_char_heads(instance): | |||||
words = instance['word_lst'] | |||||
heads = instance['heads'] | |||||
char_heads = [] | |||||
char_index = 1 # 因此存在root节点所以需要从1开始 | |||||
head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0] # 因为root是0,0-1=-1 | |||||
for word, head in zip(words, heads): | |||||
char_head = [] | |||||
if len(word)>1: | |||||
char_head.append(char_index+1) | |||||
char_index += 1 | |||||
for _ in range(len(word)-2): | |||||
char_index += 1 | |||||
char_head.append(char_index) | |||||
char_index += 1 | |||||
char_head.append(head_end_indexes[head-1]) | |||||
char_heads.extend(char_head) | |||||
return char_heads | |||||
def add_char_labels(instance): | |||||
""" | |||||
将word_lst中的数据按照下面的方式设置label | |||||
比如"复旦大学 位于 ", 对应的分词是"B M M E B E", 则对应的dependency是"复(dep)->旦(head)", "旦(dep)->大(head)".. | |||||
对应的label是'app', 'app', 'app', , 而学的label就是复旦大学这个词的dependency label | |||||
:param instance: | |||||
:return: | |||||
""" | |||||
words = instance['word_lst'] | |||||
labels = instance['labels'] | |||||
char_labels = [] | |||||
for word, label in zip(words, labels): | |||||
for _ in range(len(word)-1): | |||||
char_labels.append('APP') | |||||
char_labels.append(label) | |||||
return char_labels | |||||
# add seg_targets | |||||
def add_segs(instance): | |||||
words = instance['word_lst'] | |||||
segs = [0]*len(instance['chars']) | |||||
index = 0 | |||||
for word in words: | |||||
index = index + len(word) - 1 | |||||
segs[index] = len(word)-1 | |||||
index = index + 1 | |||||
return segs | |||||
# add target_masks | |||||
def add_mask(instance): | |||||
words = instance['word_lst'] | |||||
mask = [] | |||||
for word in words: | |||||
mask.extend([0] * (len(word) - 1)) | |||||
mask.append(1) | |||||
return mask |
@@ -0,0 +1,311 @@ | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | |||||
from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear | |||||
import numpy as np | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import functional as F | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
from fastNLP import seq_len_to_mask | |||||
from fastNLP.modules import Embedding | |||||
def drop_input_independent(word_embeddings, dropout_emb): | |||||
batch_size, seq_length, _ = word_embeddings.size() | |||||
word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb) | |||||
word_masks = torch.bernoulli(word_masks) | |||||
word_masks = word_masks.unsqueeze(dim=2) | |||||
word_embeddings = word_embeddings * word_masks | |||||
return word_embeddings | |||||
class CharBiaffineParser(BiaffineParser): | |||||
def __init__(self, char_vocab_size, | |||||
emb_dim, | |||||
bigram_vocab_size, | |||||
trigram_vocab_size, | |||||
num_label, | |||||
rnn_layers=3, | |||||
rnn_hidden_size=800, #单向的数量 | |||||
arc_mlp_size=500, | |||||
label_mlp_size=100, | |||||
dropout=0.3, | |||||
encoder='lstm', | |||||
use_greedy_infer=False, | |||||
app_index = 0, | |||||
pre_chars_embed=None, | |||||
pre_bigrams_embed=None, | |||||
pre_trigrams_embed=None): | |||||
super(BiaffineParser, self).__init__() | |||||
rnn_out_size = 2 * rnn_hidden_size | |||||
self.char_embed = Embedding((char_vocab_size, emb_dim)) | |||||
self.bigram_embed = Embedding((bigram_vocab_size, emb_dim)) | |||||
self.trigram_embed = Embedding((trigram_vocab_size, emb_dim)) | |||||
if pre_chars_embed: | |||||
self.pre_char_embed = Embedding(pre_chars_embed) | |||||
self.pre_char_embed.requires_grad = False | |||||
if pre_bigrams_embed: | |||||
self.pre_bigram_embed = Embedding(pre_bigrams_embed) | |||||
self.pre_bigram_embed.requires_grad = False | |||||
if pre_trigrams_embed: | |||||
self.pre_trigram_embed = Embedding(pre_trigrams_embed) | |||||
self.pre_trigram_embed.requires_grad = False | |||||
self.timestep_drop = TimestepDropout(dropout) | |||||
self.encoder_name = encoder | |||||
if encoder == 'var-lstm': | |||||
self.encoder = VarLSTM(input_size=emb_dim*3, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
input_dropout=dropout, | |||||
hidden_dropout=dropout, | |||||
bidirectional=True) | |||||
elif encoder == 'lstm': | |||||
self.encoder = nn.LSTM(input_size=emb_dim*3, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
dropout=dropout, | |||||
bidirectional=True) | |||||
else: | |||||
raise ValueError('unsupported encoder type: {}'.format(encoder)) | |||||
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), | |||||
nn.LeakyReLU(0.1), | |||||
TimestepDropout(p=dropout),) | |||||
self.arc_mlp_size = arc_mlp_size | |||||
self.label_mlp_size = label_mlp_size | |||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||||
self.use_greedy_infer = use_greedy_infer | |||||
self.reset_parameters() | |||||
self.dropout = dropout | |||||
self.app_index = app_index | |||||
self.num_label = num_label | |||||
if self.app_index != 0: | |||||
raise ValueError("现在app_index必须等于0") | |||||
def reset_parameters(self): | |||||
for name, m in self.named_modules(): | |||||
if 'embed' in name: | |||||
pass | |||||
elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'): | |||||
pass | |||||
else: | |||||
for p in m.parameters(): | |||||
if len(p.size())>1: | |||||
nn.init.xavier_normal_(p, gain=0.1) | |||||
else: | |||||
nn.init.uniform_(p, -0.1, 0.1) | |||||
def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None, | |||||
pre_trigrams=None): | |||||
""" | |||||
max_len是包含root的 | |||||
:param chars: batch_size x max_len | |||||
:param ngrams: batch_size x max_len*ngram_per_char | |||||
:param seq_lens: batch_size | |||||
:param gold_heads: batch_size x max_len | |||||
:param pre_chars: batch_size x max_len | |||||
:param pre_ngrams: batch_size x max_len*ngram_per_char | |||||
:return dict: parsing results | |||||
arc_pred: [batch_size, seq_len, seq_len] | |||||
label_pred: [batch_size, seq_len, seq_len] | |||||
mask: [batch_size, seq_len] | |||||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | |||||
""" | |||||
# prepare embeddings | |||||
batch_size, seq_len = chars.shape | |||||
# print('forward {} {}'.format(batch_size, seq_len)) | |||||
# get sequence mask | |||||
mask = seq_len_to_mask(seq_lens).long() | |||||
chars = self.char_embed(chars) # [N,L] -> [N,L,C_0] | |||||
bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1] | |||||
trigrams = self.trigram_embed(trigrams) | |||||
if pre_chars is not None: | |||||
pre_chars = self.pre_char_embed(pre_chars) | |||||
# pre_chars = self.pre_char_fc(pre_chars) | |||||
chars = pre_chars + chars | |||||
if pre_bigrams is not None: | |||||
pre_bigrams = self.pre_bigram_embed(pre_bigrams) | |||||
# pre_bigrams = self.pre_bigram_fc(pre_bigrams) | |||||
bigrams = bigrams + pre_bigrams | |||||
if pre_trigrams is not None: | |||||
pre_trigrams = self.pre_trigram_embed(pre_trigrams) | |||||
# pre_trigrams = self.pre_trigram_fc(pre_trigrams) | |||||
trigrams = trigrams + pre_trigrams | |||||
x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C] | |||||
# encoder, extract features | |||||
if self.training: | |||||
x = drop_input_independent(x, self.dropout) | |||||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||||
x = x[sort_idx] | |||||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||||
feat, _ = self.encoder(x) # -> [N,L,C] | |||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
feat = feat[unsort_idx] | |||||
feat = self.timestep_drop(feat) | |||||
# for arc biaffine | |||||
# mlp, reduce dim | |||||
feat = self.mlp(feat) | |||||
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size | |||||
arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] | |||||
label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] | |||||
# biaffine arc classifier | |||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||||
# use gold or predicted arc to predict label | |||||
if gold_heads is None or not self.training: | |||||
# use greedy decoding in training | |||||
if self.training or self.use_greedy_infer: | |||||
heads = self.greedy_decoder(arc_pred, mask) | |||||
else: | |||||
heads = self.mst_decoder(arc_pred, mask) | |||||
head_pred = heads | |||||
else: | |||||
assert self.training # must be training mode | |||||
if gold_heads is None: | |||||
heads = self.greedy_decoder(arc_pred, mask) | |||||
head_pred = heads | |||||
else: | |||||
head_pred = None | |||||
heads = gold_heads | |||||
# heads: batch_size x max_len | |||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1) | |||||
label_head = label_head[batch_range, heads].contiguous() | |||||
label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label] | |||||
# 这里限制一下,只有当head为下一个时,才能预测app这个label | |||||
arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\ | |||||
.repeat(batch_size, 1) # batch_size x max_len | |||||
app_masks = heads.ne(arange_index) # batch_size x max_len, 为1的位置不可以预测app | |||||
app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label) | |||||
app_masks[:, :, 1:] = 0 | |||||
label_pred = label_pred.masked_fill(app_masks, -np.inf) | |||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} | |||||
if head_pred is not None: | |||||
res_dict['head_pred'] = head_pred | |||||
return res_dict | |||||
@staticmethod | |||||
def loss(arc_pred, label_pred, arc_true, label_true, mask): | |||||
""" | |||||
Compute loss. | |||||
:param arc_pred: [batch_size, seq_len, seq_len] | |||||
:param label_pred: [batch_size, seq_len, n_tags] | |||||
:param arc_true: [batch_size, seq_len] | |||||
:param label_true: [batch_size, seq_len] | |||||
:param mask: [batch_size, seq_len] | |||||
:return: loss value | |||||
""" | |||||
batch_size, seq_len, _ = arc_pred.shape | |||||
flip_mask = (mask == 0) | |||||
_arc_pred = arc_pred.clone() | |||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||||
arc_true[:, 0].fill_(-1) | |||||
label_true[:, 0].fill_(-1) | |||||
arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) | |||||
label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) | |||||
return arc_nll + label_nll | |||||
def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams): | |||||
""" | |||||
max_len是包含root的 | |||||
:param chars: batch_size x max_len | |||||
:param ngrams: batch_size x max_len*ngram_per_char | |||||
:param seq_lens: batch_size | |||||
:param pre_chars: batch_size x max_len | |||||
:param pre_ngrams: batch_size x max_len*ngram_per_cha | |||||
:return: | |||||
""" | |||||
res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams, | |||||
pre_trigrams=pre_trigrams, gold_heads=None) | |||||
output = {} | |||||
output['arc_pred'] = res.pop('head_pred') | |||||
_, label_pred = res.pop('label_pred').max(2) | |||||
output['label_pred'] = label_pred | |||||
return output | |||||
class CharParser(nn.Module): | |||||
def __init__(self, char_vocab_size, | |||||
emb_dim, | |||||
bigram_vocab_size, | |||||
trigram_vocab_size, | |||||
num_label, | |||||
rnn_layers=3, | |||||
rnn_hidden_size=400, #单向的数量 | |||||
arc_mlp_size=500, | |||||
label_mlp_size=100, | |||||
dropout=0.3, | |||||
encoder='var-lstm', | |||||
use_greedy_infer=False, | |||||
app_index = 0, | |||||
pre_chars_embed=None, | |||||
pre_bigrams_embed=None, | |||||
pre_trigrams_embed=None): | |||||
super().__init__() | |||||
self.parser = CharBiaffineParser(char_vocab_size, | |||||
emb_dim, | |||||
bigram_vocab_size, | |||||
trigram_vocab_size, | |||||
num_label, | |||||
rnn_layers, | |||||
rnn_hidden_size, #单向的数量 | |||||
arc_mlp_size, | |||||
label_mlp_size, | |||||
dropout, | |||||
encoder, | |||||
use_greedy_infer, | |||||
app_index, | |||||
pre_chars_embed=pre_chars_embed, | |||||
pre_bigrams_embed=pre_bigrams_embed, | |||||
pre_trigrams_embed=pre_trigrams_embed) | |||||
def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None, | |||||
pre_trigrams=None): | |||||
res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars, | |||||
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) | |||||
arc_pred = res_dict['arc_pred'] | |||||
label_pred = res_dict['label_pred'] | |||||
masks = res_dict['mask'] | |||||
loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks) | |||||
return {'loss': loss} | |||||
def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None): | |||||
res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars, | |||||
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) | |||||
output = {} | |||||
output['head_preds'] = res.pop('head_pred') | |||||
_, label_pred = res.pop('label_pred').max(2) | |||||
output['label_preds'] = label_pred | |||||
return output |
@@ -0,0 +1,65 @@ | |||||
from fastNLP.core.callback import Callback | |||||
import torch | |||||
from torch import nn | |||||
class OptimizerCallback(Callback): | |||||
def __init__(self, optimizer, scheduler, update_every=4): | |||||
super().__init__() | |||||
self._optimizer = optimizer | |||||
self.scheduler = scheduler | |||||
self._update_every = update_every | |||||
def on_backward_end(self): | |||||
if self.step % self._update_every==0: | |||||
# nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5) | |||||
# self._optimizer.step() | |||||
self.scheduler.step() | |||||
# self.model.zero_grad() | |||||
class DevCallback(Callback): | |||||
def __init__(self, tester, metric_key='u_f1'): | |||||
super().__init__() | |||||
self.tester = tester | |||||
setattr(tester, 'verbose', 0) | |||||
self.metric_key = metric_key | |||||
self.record_best = False | |||||
self.best_eval_value = 0 | |||||
self.best_eval_res = None | |||||
self.best_dev_res = None # 存取dev的表现 | |||||
def on_valid_begin(self): | |||||
eval_res = self.tester.test() | |||||
metric_name = self.tester.metrics[0].__class__.__name__ | |||||
metric_value = eval_res[metric_name][self.metric_key] | |||||
if metric_value>self.best_eval_value: | |||||
self.best_eval_value = metric_value | |||||
self.best_epoch = self.trainer.epoch | |||||
self.record_best = True | |||||
self.best_eval_res = eval_res | |||||
self.test_eval_res = eval_res | |||||
eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
self.pbar.write(eval_str) | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
if self.record_best: | |||||
self.best_dev_res = eval_result | |||||
self.record_best = False | |||||
if is_better_eval: | |||||
self.best_dev_res_on_dev = eval_result | |||||
self.best_test_res_on_dev = self.test_eval_res | |||||
self.dev_epoch = self.epoch | |||||
def on_train_end(self): | |||||
print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch, | |||||
self.tester._format_eval_results(self.best_eval_res), | |||||
self.tester._format_eval_results(self.best_dev_res))) | |||||
print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch, | |||||
self.tester._format_eval_results(self.best_test_res_on_dev), | |||||
self.tester._format_eval_results(self.best_dev_res_on_dev))) |
@@ -0,0 +1,184 @@ | |||||
from fastNLP.core.metrics import MetricBase | |||||
from fastNLP.core.utils import seq_len_to_mask | |||||
import torch | |||||
class SegAppCharParseF1Metric(MetricBase): | |||||
# | |||||
def __init__(self, app_index): | |||||
super().__init__() | |||||
self.app_index = app_index | |||||
self.parse_head_tp = 0 | |||||
self.parse_label_tp = 0 | |||||
self.rec_tol = 0 | |||||
self.pre_tol = 0 | |||||
def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens, | |||||
pun_masks): | |||||
""" | |||||
max_len是不包含root的character的长度 | |||||
:param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size | |||||
:param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size | |||||
:param head_preds: batch_size x max_len | |||||
:param label_preds: batch_size x max_len | |||||
:param seq_lens: | |||||
:param pun_masks: batch_size x | |||||
:return: | |||||
""" | |||||
# 去掉root | |||||
head_preds = head_preds[:, 1:].tolist() | |||||
label_preds = label_preds[:, 1:].tolist() | |||||
seq_lens = (seq_lens - 1).tolist() | |||||
# 先解码出words,POS,heads, labels, 对应的character范围 | |||||
for b in range(len(head_preds)): | |||||
seq_len = seq_lens[b] | |||||
head_pred = head_preds[b][:seq_len] | |||||
label_pred = label_preds[b][:seq_len] | |||||
words = [] # 存放[word_start, word_end),相对起始位置,不考虑root | |||||
heads = [] | |||||
labels = [] | |||||
ranges = [] # 对应该char是第几个word,长度是seq_len+1 | |||||
word_idx = 0 | |||||
word_start_idx = 0 | |||||
for idx, (label, head) in enumerate(zip(label_pred, head_pred)): | |||||
ranges.append(word_idx) | |||||
if label == self.app_index: | |||||
pass | |||||
else: | |||||
labels.append(label) | |||||
heads.append(head) | |||||
words.append((word_start_idx, idx+1)) | |||||
word_start_idx = idx+1 | |||||
word_idx += 1 | |||||
head_dep_tuple = [] # head在前面 | |||||
head_label_dep_tuple = [] | |||||
for idx, head in enumerate(heads): | |||||
span = words[idx] | |||||
if span[0]==span[1]-1 and pun_masks[b, span[0]]: | |||||
continue # exclude punctuations | |||||
if head == 0: | |||||
head_dep_tuple.append((('root', words[idx]))) | |||||
head_label_dep_tuple.append(('root', labels[idx], words[idx])) | |||||
else: | |||||
head_word_idx = ranges[head-1] | |||||
head_word_span = words[head_word_idx] | |||||
head_dep_tuple.append(((head_word_span, words[idx]))) | |||||
head_label_dep_tuple.append((head_word_span, labels[idx], words[idx])) | |||||
gold_head_dep_tuple = set(gold_word_pairs[b]) | |||||
gold_head_label_dep_tuple = set(gold_label_word_pairs[b]) | |||||
for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple): | |||||
if head_dep in gold_head_dep_tuple: | |||||
self.parse_head_tp += 1 | |||||
if head_label_dep in gold_head_label_dep_tuple: | |||||
self.parse_label_tp += 1 | |||||
self.pre_tol += len(head_dep_tuple) | |||||
self.rec_tol += len(gold_head_dep_tuple) | |||||
def get_metric(self, reset=True): | |||||
u_p = self.parse_head_tp / self.pre_tol | |||||
u_r = self.parse_head_tp / self.rec_tol | |||||
u_f = 2*u_p*u_r/(1e-6 + u_p + u_r) | |||||
l_p = self.parse_label_tp / self.pre_tol | |||||
l_r = self.parse_label_tp / self.rec_tol | |||||
l_f = 2*l_p*l_r/(1e-6 + l_p + l_r) | |||||
if reset: | |||||
self.parse_head_tp = 0 | |||||
self.parse_label_tp = 0 | |||||
self.rec_tol = 0 | |||||
self.pre_tol = 0 | |||||
return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4), | |||||
'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)} | |||||
class CWSMetric(MetricBase): | |||||
def __init__(self, app_index): | |||||
super().__init__() | |||||
self.app_index = app_index | |||||
self.pre = 0 | |||||
self.rec = 0 | |||||
self.tp = 0 | |||||
def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens): | |||||
""" | |||||
:param seg_targets: batch_size x max_len, 每个位置预测的是该word的长度-1,在word结束的地方。 | |||||
:param seg_masks: batch_size x max_len,只有在word结束的地方为1 | |||||
:param label_preds: batch_size x max_len | |||||
:param seq_lens: batch_size | |||||
:return: | |||||
""" | |||||
pred_masks = torch.zeros_like(seg_masks) | |||||
pred_segs = torch.zeros_like(seg_targets) | |||||
seq_lens = (seq_lens - 1).tolist() | |||||
for idx, label_pred in enumerate(label_preds[:, 1:].tolist()): | |||||
seq_len = seq_lens[idx] | |||||
label_pred = label_pred[:seq_len] | |||||
word_len = 0 | |||||
for l_i, label in enumerate(label_pred): | |||||
if label==self.app_index and l_i!=len(label_pred)-1: | |||||
word_len += 1 | |||||
else: | |||||
pred_segs[idx, l_i] = word_len # 这个词的长度为word_len | |||||
pred_masks[idx, l_i] = 1 | |||||
word_len = 0 | |||||
right_mask = seg_targets.eq(pred_segs) # 对长度的预测一致 | |||||
self.rec += seg_masks.sum().item() | |||||
self.pre += pred_masks.sum().item() | |||||
# 且pred和target在同一个地方有值 | |||||
self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item() | |||||
def get_metric(self, reset=True): | |||||
res = {} | |||||
res['rec'] = round(self.tp/(self.rec+1e-6), 4) | |||||
res['pre'] = round(self.tp/(self.pre+1e-6), 4) | |||||
res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4) | |||||
if reset: | |||||
self.pre = 0 | |||||
self.rec = 0 | |||||
self.tp = 0 | |||||
return res | |||||
class ParserMetric(MetricBase): | |||||
def __init__(self, ): | |||||
super().__init__() | |||||
self.num_arc = 0 | |||||
self.num_label = 0 | |||||
self.num_sample = 0 | |||||
def get_metric(self, reset=True): | |||||
res = {'UAS': round(self.num_arc*1.0 / self.num_sample, 4), | |||||
'LAS': round(self.num_label*1.0 / self.num_sample, 4)} | |||||
if reset: | |||||
self.num_sample = self.num_label = self.num_arc = 0 | |||||
return res | |||||
def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None): | |||||
"""Evaluate the performance of prediction. | |||||
""" | |||||
if seq_lens is None: | |||||
seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.byte) | |||||
else: | |||||
seq_mask = seq_len_to_mask(seq_lens.long(), float=False) | |||||
# mask out <root> tag | |||||
seq_mask[:, 0] = 0 | |||||
head_pred_correct = (head_preds == heads).__and__(seq_mask) | |||||
label_pred_correct = (label_preds == labels).__and__(head_pred_correct) | |||||
self.num_arc += head_pred_correct.float().sum().item() | |||||
self.num_label += label_pred_correct.float().sum().item() | |||||
self.num_sample += seq_mask.sum().item() | |||||
@@ -0,0 +1,16 @@ | |||||
Code for paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697) | |||||
### 准备数据 | |||||
1. 数据应该为conll格式,1, 3, 6, 7列应该对应为'words', 'pos_tags', 'heads', 'labels'. | |||||
2. 将train, dev, test放在同一个folder下,并将该folder路径填入train.py中的data_folder变量里。 | |||||
3. 从[百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA)(提取:ua53)下载预训练vector,放到同一个folder下,并将train.py中vector_folder变量正确设置。 | |||||
### 运行代码 | |||||
``` | |||||
python train.py | |||||
``` | |||||
### 其它 | |||||
ctb5上跑出论文中报道的结果使用以上的默认参数应该就可以了(应该会更高一些); ctb7上使用默认参数会低0.1%左右,需要调节 | |||||
learning rate scheduler. |
@@ -0,0 +1,124 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader | |||||
from fastNLP.modules.encoder.embedding import StaticEmbedding | |||||
from torch import nn | |||||
from functools import partial | |||||
from reproduction.joint_cws_parse.models.CharParser import CharParser | |||||
from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric | |||||
from fastNLP import cache_results, BucketSampler, Trainer | |||||
from torch import optim | |||||
from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback | |||||
from torch.optim.lr_scheduler import LambdaLR, StepLR | |||||
from fastNLP import Tester | |||||
from fastNLP import GradientClipCallback, LRScheduler | |||||
import os | |||||
def set_random_seed(random_seed=666): | |||||
import random, numpy, torch | |||||
random.seed(random_seed) | |||||
numpy.random.seed(random_seed) | |||||
torch.cuda.manual_seed(random_seed) | |||||
torch.random.manual_seed(random_seed) | |||||
uniform_init = partial(nn.init.normal_, std=0.02) | |||||
################################################### | |||||
# 需要变动的超参放到这里 | |||||
lr = 0.002 # 0.01~0.001 | |||||
dropout = 0.33 # 0.3~0.6 | |||||
weight_decay = 0 # 1e-5, 1e-6, 0 | |||||
arc_mlp_size = 500 # 200, 300 | |||||
rnn_hidden_size = 400 # 200, 300, 400 | |||||
rnn_layers = 3 # 2, 3 | |||||
encoder = 'var-lstm' # var-lstm, lstm | |||||
emb_size = 100 # 64 , 100 | |||||
label_mlp_size = 100 | |||||
batch_size = 32 | |||||
update_every = 4 | |||||
n_epochs = 100 | |||||
data_folder = '' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件 | |||||
vector_folder = '' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt | |||||
#################################################### | |||||
set_random_seed(1234) | |||||
device = 0 | |||||
# @cache_results('caches/{}.pkl'.format(data_name)) | |||||
# def get_data(): | |||||
data = CTBxJointLoader().process(data_folder) | |||||
char_labels_vocab = data.vocabs['char_labels'] | |||||
pre_chars_vocab = data.vocabs['pre_chars'] | |||||
pre_bigrams_vocab = data.vocabs['pre_bigrams'] | |||||
pre_trigrams_vocab = data.vocabs['pre_trigrams'] | |||||
chars_vocab = data.vocabs['chars'] | |||||
bigrams_vocab = data.vocabs['bigrams'] | |||||
trigrams_vocab = data.vocabs['trigrams'] | |||||
pre_chars_embed = StaticEmbedding(pre_chars_vocab, | |||||
model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), | |||||
init_method=uniform_init, normalize=False) | |||||
pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std() | |||||
pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, | |||||
model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), | |||||
init_method=uniform_init, normalize=False) | |||||
pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std() | |||||
pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, | |||||
model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), | |||||
init_method=uniform_init, normalize=False) | |||||
pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std() | |||||
# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data | |||||
# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() | |||||
print(data) | |||||
model = CharParser(char_vocab_size=len(chars_vocab), | |||||
emb_dim=emb_size, | |||||
bigram_vocab_size=len(bigrams_vocab), | |||||
trigram_vocab_size=len(trigrams_vocab), | |||||
num_label=len(char_labels_vocab), | |||||
rnn_layers=rnn_layers, | |||||
rnn_hidden_size=rnn_hidden_size, | |||||
arc_mlp_size=arc_mlp_size, | |||||
label_mlp_size=label_mlp_size, | |||||
dropout=dropout, | |||||
encoder=encoder, | |||||
use_greedy_infer=False, | |||||
app_index=char_labels_vocab['APP'], | |||||
pre_chars_embed=pre_chars_embed, | |||||
pre_bigrams_embed=pre_bigrams_embed, | |||||
pre_trigrams_embed=pre_trigrams_embed) | |||||
metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP']) | |||||
metric2 = CWSMetric(char_labels_vocab['APP']) | |||||
metrics = [metric1, metric2] | |||||
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr, | |||||
weight_decay=weight_decay, betas=[0.9, 0.9]) | |||||
sampler = BucketSampler(seq_len_field_name='seq_lens') | |||||
callbacks = [] | |||||
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) | |||||
scheduler = StepLR(optimizer, step_size=18, gamma=0.75) | |||||
# optim_callback = OptimizerCallback(optimizer, scheduler, update_every) | |||||
# callbacks.append(optim_callback) | |||||
scheduler_callback = LRScheduler(scheduler) | |||||
callbacks.append(scheduler_callback) | |||||
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) | |||||
tester = Tester(data=data.datasets['test'], model=model, metrics=metrics, | |||||
batch_size=64, device=device, verbose=0) | |||||
dev_callback = DevCallback(tester) | |||||
callbacks.append(dev_callback) | |||||
trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3, | |||||
validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, | |||||
check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True, | |||||
device=device, callbacks=callbacks, update_every=update_every) | |||||
trainer.train() |
@@ -0,0 +1,100 @@ | |||||
# Matching任务模型复现 | |||||
这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). | |||||
复现的模型有(按论文发表时间顺序排序): | |||||
- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||||
论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | |||||
- ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). | |||||
论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | |||||
- DIIN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||||
论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf). | |||||
- MwAN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||||
论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf). | |||||
- BERT:[模型代码](model/bert.py); [训练代码](matching_bert.py). | |||||
论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). | |||||
# 数据集及复现结果汇总 | |||||
使用fastNLP复现的结果vs论文汇报结果,在前面的表示使用fastNLP复现的结果 | |||||
'\-'表示我们仍未复现或者论文原文没有汇报 | |||||
model name | SNLI | MNLI | RTE | QNLI | Quora | |||||
:---: | :---: | :---: | :---: | :---: | :---: | |||||
CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - | | |||||
ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | | |||||
DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | |||||
MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | | |||||
BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - | | |||||
*ESIM模型由MNLI官方复现的结果为72.4/72.1,ESIM原论文当中没有汇报MNLI数据集的结果。 | |||||
# 数据集复现结果及其他主要模型对比 | |||||
## SNLI | |||||
[Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) | |||||
Performance on Test set: | |||||
model name | ESIM | DIIN | MwAN | [GPT1.0](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) | [BERT-Large+SRL](https://arxiv.org/pdf/1809.02794.pdf) | [MT-DNN](https://arxiv.org/pdf/1901.11504.pdf) | |||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | |||||
__performance__ | 88.0 | 88.0 | 88.3 | 89.9 | 91.3 | 91.6 | | |||||
### 基于fastNLP复现的结果 | |||||
Performance on Test set: | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | |||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | |||||
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16 | |||||
## MNLI | |||||
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | |||||
Performance on Test set(matched/mismatched): | |||||
model name | ESIM | DIIN | MwAN | GPT1.0 | BERT-Base | MT-DNN | |||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | |||||
__performance__ | 72.4/72.1 | 78.8/77.8 | 78.5/77.7 | 82.1/81.4 | 84.6/83.4 | 87.9/87.4 | | |||||
### 基于fastNLP复现的结果 | |||||
Performance on Test set(matched/mismatched): | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | |||||
:---: | :---: | :---: | :---: | :---: | :---: | | |||||
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||||
## RTE | |||||
Still in progress. | |||||
## QNLI | |||||
### From GLUE baselines | |||||
[Link to GLUE leaderboard](https://gluebenchmark.com/leaderboard) | |||||
Performance on Test set: | |||||
#### LSTM-based | |||||
model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo | |||||
:---: | :---: | :---: | :---: | :---: | | |||||
__performance__ | 74.6 | 74.3 | 75.5 | 79.8 | | |||||
*这些LSTM-based的baseline是由QNLI官方实现并测试的。 | |||||
#### Transformer-based | |||||
model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN | |||||
:---: | :---: | :---: | :---: | :---: | | |||||
__performance__ | 87.4 | 90.5 | 92.7 | 96.0 | | |||||
### 基于fastNLP复现的结果 | |||||
Performance on __Dev__ set: | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT | |||||
:---: | :---: | :---: | :---: | :---: | :---: | |||||
__performance__ | - | 76.97 | - | 74.6 | - | |||||
## Quora | |||||
Still in progress. | |||||
@@ -5,8 +5,8 @@ from typing import Union, Dict | |||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.io.base_loader import DataInfo | |||||
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader | |||||
from fastNLP.io.base_loader import DataInfo, DataSetLoader | |||||
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | |||||
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | ||||
from fastNLP.modules.encoder._bert import BertTokenizer | from fastNLP.modules.encoder._bert import BertTokenizer | ||||
@@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader): | |||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | ||||
读取Matching任务的数据集 | 读取Matching任务的数据集 | ||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | """ | ||||
def __init__(self, paths: dict=None): | def __init__(self, paths: dict=None): | ||||
""" | |||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | |||||
self.paths = paths | self.paths = paths | ||||
def _load(self, path): | def _load(self, path): | ||||
@@ -34,7 +33,8 @@ class MatchingLoader(DataSetLoader): | |||||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | ||||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | ||||
cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, | |||||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||||
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | ||||
""" | """ | ||||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | ||||
@@ -49,6 +49,8 @@ class MatchingLoader(DataSetLoader): | |||||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | ||||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | ||||
:param bool get_index: 是否需要根据词表将文本转为index | :param bool get_index: 是否需要根据词表将文本转为index | ||||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||||
:param str auto_pad_token: 自动pad的内容 | |||||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | ||||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | ||||
于此同时其他field不会被设置为input。默认值为True。 | 于此同时其他field不会被设置为input。默认值为True。 | ||||
@@ -169,6 +171,9 @@ class MatchingLoader(DataSetLoader): | |||||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | ||||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | ||||
if auto_pad_length is not None: | |||||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||||
if cut_text is not None: | if cut_text is not None: | ||||
for data_name, data_set in data_info.datasets.items(): | for data_name, data_set in data_info.datasets.items(): | ||||
for fields in data_set.get_field_names(): | for fields in data_set.get_field_names(): | ||||
@@ -180,7 +185,7 @@ class MatchingLoader(DataSetLoader): | |||||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | assert len(data_set_list) > 0, f'There are NO data sets in data info!' | ||||
if bert_tokenizer is None: | if bert_tokenizer is None: | ||||
words_vocab = Vocabulary() | |||||
words_vocab = Vocabulary(padding=auto_pad_token) | |||||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | ||||
field_name=[n for n in data_set_list[0].get_field_names() | field_name=[n for n in data_set_list[0].get_field_names() | ||||
if (Const.INPUT in n)], | if (Const.INPUT in n)], | ||||
@@ -202,6 +207,20 @@ class MatchingLoader(DataSetLoader): | |||||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | ||||
is_input=auto_set_input, is_target=auto_set_target) | is_input=auto_set_input, is_target=auto_set_target) | ||||
if auto_pad_length is not None: | |||||
if seq_len_type == 'seq_len': | |||||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||||
new_field_name=fields, is_input=auto_set_input) | |||||
for data_name, data_set in data_info.datasets.items(): | for data_name, data_set in data_info.datasets.items(): | ||||
if isinstance(set_input, list): | if isinstance(set_input, list): | ||||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | ||||
@@ -267,7 +286,7 @@ class RTELoader(MatchingLoader, CSVLoader): | |||||
paths = paths if paths is not None else { | paths = paths if paths is not None else { | ||||
'train': 'train.tsv', | 'train': 'train.tsv', | ||||
'dev': 'dev.tsv', | 'dev': 'dev.tsv', | ||||
# 'test': 'test.tsv' # test set has not label | |||||
'test': 'test.tsv' # test set has not label | |||||
} | } | ||||
MatchingLoader.__init__(self, paths=paths) | MatchingLoader.__init__(self, paths=paths) | ||||
self.fields = { | self.fields = { | ||||
@@ -281,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader): | |||||
ds = CSVLoader._load(self, path) | ds = CSVLoader._load(self, path) | ||||
for k, v in self.fields.items(): | for k, v in self.fields.items(): | ||||
ds.rename_field(k, v) | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | for fields in ds.get_all_fields(): | ||||
if Const.INPUT in fields: | if Const.INPUT in fields: | ||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ||||
@@ -306,7 +326,7 @@ class QNLILoader(MatchingLoader, CSVLoader): | |||||
paths = paths if paths is not None else { | paths = paths if paths is not None else { | ||||
'train': 'train.tsv', | 'train': 'train.tsv', | ||||
'dev': 'dev.tsv', | 'dev': 'dev.tsv', | ||||
# 'test': 'test.tsv' # test set has not label | |||||
'test': 'test.tsv' # test set has not label | |||||
} | } | ||||
MatchingLoader.__init__(self, paths=paths) | MatchingLoader.__init__(self, paths=paths) | ||||
self.fields = { | self.fields = { | ||||
@@ -320,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader): | |||||
ds = CSVLoader._load(self, path) | ds = CSVLoader._load(self, path) | ||||
for k, v in self.fields.items(): | for k, v in self.fields.items(): | ||||
ds.rename_field(k, v) | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | for fields in ds.get_all_fields(): | ||||
if Const.INPUT in fields: | if Const.INPUT in fields: | ||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ||||
@@ -332,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
""" | """ | ||||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | ||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | words1: list(str),第一句文本, premise | ||||
words2: list(str), 第二句文本, hypothesis | words2: list(str), 第二句文本, hypothesis | ||||
@@ -348,6 +369,10 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
'dev_mismatched': 'dev_mismatched.tsv', | 'dev_mismatched': 'dev_mismatched.tsv', | ||||
'test_matched': 'test_matched.tsv', | 'test_matched': 'test_matched.tsv', | ||||
'test_mismatched': 'test_mismatched.tsv', | 'test_mismatched': 'test_mismatched.tsv', | ||||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||||
} | } | ||||
MatchingLoader.__init__(self, paths=paths) | MatchingLoader.__init__(self, paths=paths) | ||||
CSVLoader.__init__(self, sep='\t') | CSVLoader.__init__(self, sep='\t') | ||||
@@ -364,6 +389,10 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
if k in ds.get_field_names(): | if k in ds.get_field_names(): | ||||
ds.rename_field(k, v) | ds.rename_field(k, v) | ||||
if Const.TARGET in ds.get_field_names(): | |||||
if ds[0][Const.TARGET] == 'hidden': | |||||
ds.delete_field(Const.TARGET) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | parentheses_table = str.maketrans({'(': None, ')': None}) | ||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | ||||
@@ -376,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
class QuoraLoader(MatchingLoader, CSVLoader): | class QuoraLoader(MatchingLoader, CSVLoader): | ||||
""" | |||||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | def __init__(self, paths: dict=None): | ||||
paths = paths if paths is not None else { | paths = paths if paths is not None else { | ||||
@@ -0,0 +1,102 @@ | |||||
import random | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam | |||||
from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ | |||||
MNLILoader, QNLILoader, QuoraLoader | |||||
from reproduction.matching.model.bert import BertForNLI | |||||
# define hyper-parameters | |||||
class BERTConfig: | |||||
task = 'snli' | |||||
batch_size_per_gpu = 6 | |||||
n_epochs = 6 | |||||
lr = 2e-5 | |||||
seq_len_type = 'bert' | |||||
seed = 42 | |||||
train_dataset_name = 'train' | |||||
dev_dataset_name = 'dev' | |||||
test_dataset_name = 'test' | |||||
save_path = None # 模型存储的位置,None表示不存储模型。 | |||||
bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹 | |||||
arg = BERTConfig() | |||||
# set random seed | |||||
random.seed(arg.seed) | |||||
np.random.seed(arg.seed) | |||||
torch.manual_seed(arg.seed) | |||||
n_gpu = torch.cuda.device_count() | |||||
if n_gpu > 0: | |||||
torch.cuda.manual_seed_all(arg.seed) | |||||
# load data set | |||||
if arg.task == 'snli': | |||||
data_info = SNLILoader().process( | |||||
paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||||
get_index=True, concat='bert', | |||||
) | |||||
elif arg.task == 'rte': | |||||
data_info = RTELoader().process( | |||||
paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||||
get_index=True, concat='bert', | |||||
) | |||||
elif arg.task == 'qnli': | |||||
data_info = QNLILoader().process( | |||||
paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||||
get_index=True, concat='bert', | |||||
) | |||||
elif arg.task == 'mnli': | |||||
data_info = MNLILoader().process( | |||||
paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||||
get_index=True, concat='bert', | |||||
) | |||||
elif arg.task == 'quora': | |||||
data_info = QuoraLoader().process( | |||||
paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||||
get_index=True, concat='bert', | |||||
) | |||||
else: | |||||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||||
# define model | |||||
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir) | |||||
# define trainer | |||||
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||||
n_epochs=arg.n_epochs, print_every=-1, | |||||
dev_data=data_info.datasets[arg.dev_dataset_name], | |||||
metrics=AccuracyMetric(), metric_key='acc', | |||||
device=[i for i in range(torch.cuda.device_count())], | |||||
check_code_level=-1, | |||||
save_path=arg.save_path) | |||||
# train model | |||||
trainer.train(load_best_model=True) | |||||
# define tester | |||||
tester = Tester( | |||||
data=data_info.datasets[arg.test_dataset_name], | |||||
model=model, | |||||
metrics=AccuracyMetric(), | |||||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||||
device=[i for i in range(torch.cuda.device_count())], | |||||
) | |||||
# test model | |||||
tester.test() | |||||
@@ -1,47 +1,103 @@ | |||||
import argparse | |||||
import random | |||||
import numpy as np | |||||
import torch | import torch | ||||
from torch.optim import Adamax | |||||
from torch.optim.lr_scheduler import StepLR | |||||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const | |||||
from fastNLP.core.callback import GradientClipCallback, LRScheduler | |||||
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding | from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding | ||||
from reproduction.matching.data.MatchingDataLoader import SNLILoader | |||||
from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ | |||||
MNLILoader, QNLILoader, QuoraLoader | |||||
from reproduction.matching.model.esim import ESIMModel | from reproduction.matching.model.esim import ESIMModel | ||||
argument = argparse.ArgumentParser() | |||||
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove') | |||||
argument.add_argument('--batch-size-per-gpu', type=int, default=128) | |||||
argument.add_argument('--n-epochs', type=int, default=100) | |||||
argument.add_argument('--lr', type=float, default=1e-4) | |||||
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len') | |||||
argument.add_argument('--save-dir', type=str, default=None) | |||||
arg = argument.parse_args() | |||||
bert_dirs = 'path/to/bert/dir' | |||||
# define hyper-parameters | |||||
class ESIMConfig: | |||||
task = 'snli' | |||||
embedding = 'glove' | |||||
batch_size_per_gpu = 196 | |||||
n_epochs = 30 | |||||
lr = 2e-3 | |||||
seq_len_type = 'seq_len' | |||||
# seq_len表示在process的时候用len(words)来表示长度信息; | |||||
# mask表示用0/1掩码矩阵来表示长度信息; | |||||
seed = 42 | |||||
train_dataset_name = 'train' | |||||
dev_dataset_name = 'dev' | |||||
test_dataset_name = 'test' | |||||
save_path = None # 模型存储的位置,None表示不存储模型。 | |||||
arg = ESIMConfig() | |||||
# set random seed | |||||
random.seed(arg.seed) | |||||
np.random.seed(arg.seed) | |||||
torch.manual_seed(arg.seed) | |||||
n_gpu = torch.cuda.device_count() | |||||
if n_gpu > 0: | |||||
torch.cuda.manual_seed_all(arg.seed) | |||||
# load data set | # load data set | ||||
data_info = SNLILoader().process( | |||||
paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||||
get_index=True, concat=False, | |||||
) | |||||
if arg.task == 'snli': | |||||
data_info = SNLILoader().process( | |||||
paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'rte': | |||||
data_info = RTELoader().process( | |||||
paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'qnli': | |||||
data_info = QNLILoader().process( | |||||
paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'mnli': | |||||
data_info = MNLILoader().process( | |||||
paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'quora': | |||||
data_info = QuoraLoader().process( | |||||
paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
else: | |||||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||||
# load embedding | # load embedding | ||||
if arg.embedding == 'elmo': | if arg.embedding == 'elmo': | ||||
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | ||||
elif arg.embedding == 'glove': | elif arg.embedding == 'glove': | ||||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False) | |||||
else: | else: | ||||
raise ValueError(f'now we only support elmo or glove embedding for esim model!') | |||||
raise RuntimeError(f'NOT support {arg.embedding} embedding yet!') | |||||
# define model | # define model | ||||
model = ESIMModel(embedding) | |||||
model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET])) | |||||
# define optimizer and callback | |||||
optimizer = Adamax(lr=arg.lr, params=model.parameters()) | |||||
scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍 | |||||
callbacks = [ | |||||
GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10) | |||||
LRScheduler(scheduler), | |||||
] | |||||
# define trainer | # define trainer | ||||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||||
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||||
optimizer=optimizer, | |||||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | ||||
n_epochs=arg.n_epochs, print_every=-1, | n_epochs=arg.n_epochs, print_every=-1, | ||||
dev_data=data_info.datasets['dev'], | |||||
dev_data=data_info.datasets[arg.dev_dataset_name], | |||||
metrics=AccuracyMetric(), metric_key='acc', | metrics=AccuracyMetric(), metric_key='acc', | ||||
device=[i for i in range(torch.cuda.device_count())], | device=[i for i in range(torch.cuda.device_count())], | ||||
check_code_level=-1, | check_code_level=-1, | ||||
@@ -52,7 +108,7 @@ trainer.train(load_best_model=True) | |||||
# define tester | # define tester | ||||
tester = Tester( | tester = Tester( | ||||
data=data_info.datasets['test'], | |||||
data=data_info.datasets[arg.test_dataset_name], | |||||
model=model, | model=model, | ||||
metrics=AccuracyMetric(), | metrics=AccuracyMetric(), | ||||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | ||||
@@ -81,6 +81,7 @@ class ESIMModel(BaseModel): | |||||
out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | ||||
logits = torch.tanh(self.classifier(out)) | logits = torch.tanh(self.classifier(out)) | ||||
# logits = self.classifier(out) | |||||
if target is not None: | if target is not None: | ||||
loss_fct = CrossEntropyLoss() | loss_fct = CrossEntropyLoss() | ||||
@@ -91,7 +92,8 @@ class ESIMModel(BaseModel): | |||||
return {Const.OUTPUT: logits} | return {Const.OUTPUT: logits} | ||||
def predict(self, **kwargs): | def predict(self, **kwargs): | ||||
return self.forward(**kwargs) | |||||
pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) | |||||
return {Const.OUTPUT: pred} | |||||
# input [batch_size, len , hidden] | # input [batch_size, len , hidden] | ||||
# mask [batch_size, len] (111...00) | # mask [batch_size, len] (111...00) | ||||
@@ -127,7 +129,7 @@ class BiRNN(nn.Module): | |||||
def forward(self, x, x_mask): | def forward(self, x, x_mask): | ||||
# Sort x | # Sort x | ||||
lengths = x_mask.data.eq(1).long().sum(1).squeeze() | |||||
lengths = x_mask.data.eq(1).long().sum(1) | |||||
_, idx_sort = torch.sort(lengths, dim=0, descending=True) | _, idx_sort = torch.sort(lengths, dim=0, descending=True) | ||||
_, idx_unsort = torch.sort(idx_sort, dim=0) | _, idx_unsort = torch.sort(idx_sort, dim=0) | ||||
lengths = list(lengths[idx_sort]) | lengths = list(lengths[idx_sort]) | ||||
@@ -0,0 +1,13 @@ | |||||
# NER任务模型复现 | |||||
这里使用fastNLP复现经典的BiLSTM-CNN的NER任务的模型,旨在达到与论文中相符的性能。 | |||||
论文链接[Named Entity Recognition with Bidirectional LSTM-CNNs](https://arxiv.org/pdf/1511.08308.pdf) | |||||
# 数据集及复现结果汇总 | |||||
使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道) | |||||
model name | Conll2003 | Ontonotes | |||||
:---: | :---: | :---: | |||||
BiLSTM-CNN | 91.17/90.91 | 86.47/86.35 | | |||||
@@ -13,22 +13,30 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
} | } | ||||
如果paths为不合法的,将直接进行raise相应的错误 | 如果paths为不合法的,将直接进行raise相应的错误 | ||||
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt, | |||||
test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||||
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 | |||||
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(paths, str): | if isinstance(paths, str): | ||||
if os.path.isfile(paths): | if os.path.isfile(paths): | ||||
return {'train': paths} | return {'train': paths} | ||||
elif os.path.isdir(paths): | elif os.path.isdir(paths): | ||||
train_fp = os.path.join(paths, 'train.txt') | |||||
if not os.path.isfile(train_fp): | |||||
raise FileNotFoundError(f"train.txt is not found in folder {paths}.") | |||||
files = {'train': train_fp} | |||||
for filename in ['dev.txt', 'test.txt']: | |||||
fp = os.path.join(paths, filename) | |||||
if os.path.isfile(fp): | |||||
files[filename.split('.')[0]] = fp | |||||
filenames = os.listdir(paths) | |||||
files = {} | |||||
for filename in filenames: | |||||
path_pair = None | |||||
if 'train' in filename: | |||||
path_pair = ('train', filename) | |||||
if 'dev' in filename: | |||||
if path_pair: | |||||
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) | |||||
path_pair = ('dev', filename) | |||||
if 'test' in filename: | |||||
if path_pair: | |||||
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) | |||||
path_pair = ('test', filename) | |||||
if path_pair: | |||||
files[path_pair[0]] = os.path.join(paths, path_pair[1]) | |||||
return files | return files | ||||
else: | else: | ||||
raise FileNotFoundError(f"{paths} is not a valid file path.") | raise FileNotFoundError(f"{paths} is not a valid file path.") | ||||
@@ -1,5 +1,5 @@ | |||||
numpy | |||||
torch>=0.4.0 | |||||
tqdm | |||||
nltk | |||||
numpy>=1.14.2 | |||||
torch>=1.0.0 | |||||
tqdm>=4.28.1 | |||||
nltk>=3.4.1 | |||||
requests | requests |
@@ -1,7 +1,7 @@ | |||||
import unittest | import unittest | ||||
import os | import os | ||||
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader | |||||
from fastNLP.io.dataset_loader import SSTLoader | |||||
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader | |||||
from fastNLP.io.data_loader import SSTLoader, SNLILoader | |||||
from reproduction.text_classification.data.yelpLoader import yelpLoader | from reproduction.text_classification.data.yelpLoader import yelpLoader | ||||
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase): | |||||
print(info.vocabs) | print(info.vocabs) | ||||
print(info.datasets) | print(info.datasets) | ||||
os.remove(train), os.remove(test) | os.remove(train), os.remove(test) | ||||
def test_import(self): | |||||
import fastNLP | |||||
from fastNLP.io import SNLILoader | |||||
ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True, | |||||
get_index=True, seq_len_type='seq_len') | |||||
assert 'train' in ds.datasets | |||||
assert len(ds.datasets) == 1 | |||||
assert len(ds.datasets['train']) == 3 |
@@ -8,8 +8,9 @@ from fastNLP.models.bert import * | |||||
class TestBert(unittest.TestCase): | class TestBert(unittest.TestCase): | ||||
def test_bert_1(self): | def test_bert_1(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForSequenceClassification(2) | |||||
model = BertForSequenceClassification(2, BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -22,8 +23,9 @@ class TestBert(unittest.TestCase): | |||||
def test_bert_2(self): | def test_bert_2(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForMultipleChoice(2) | |||||
model = BertForMultipleChoice(2, BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -36,8 +38,9 @@ class TestBert(unittest.TestCase): | |||||
def test_bert_3(self): | def test_bert_3(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForTokenClassification(7) | |||||
model = BertForTokenClassification(7, BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -50,8 +53,9 @@ class TestBert(unittest.TestCase): | |||||
def test_bert_4(self): | def test_bert_4(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForQuestionAnswering() | |||||
model = BertForQuestionAnswering(BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -8,8 +8,9 @@ from fastNLP.models.bert import BertModel | |||||
class TestBert(unittest.TestCase): | class TestBert(unittest.TestCase): | ||||
def test_bert_1(self): | def test_bert_1(self): | ||||
model = BertModel(vocab_size=32000, hidden_size=768, | |||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
config = BertConfig(32000) | |||||
model = BertModel(config) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -18,4 +19,4 @@ class TestBert(unittest.TestCase): | |||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | ||||
for layer in all_encoder_layers: | for layer in all_encoder_layers: | ||||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | self.assertEqual(tuple(layer.shape), (2, 3, 768)) | ||||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) | |||||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) |
@@ -79,7 +79,7 @@ class TestTutorial(unittest.TestCase): | |||||
train_data.rename_field('label', 'label_seq') | train_data.rename_field('label', 'label_seq') | ||||
test_data.rename_field('label', 'label_seq') | test_data.rename_field('label', 'label_seq') | ||||
loss = CrossEntropyLoss(pred="output", target="label_seq") | |||||
loss = CrossEntropyLoss(target="label_seq") | |||||
metric = AccuracyMetric(target="label_seq") | metric = AccuracyMetric(target="label_seq") | ||||
# 实例化Trainer,传入模型和数据,进行训练 | # 实例化Trainer,传入模型和数据,进行训练 | ||||
@@ -91,7 +91,7 @@ class TestTutorial(unittest.TestCase): | |||||
# 用train_data训练,在test_data验证 | # 用train_data训练,在test_data验证 | ||||
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | ||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
loss=CrossEntropyLoss(target="label_seq"), | |||||
metrics=AccuracyMetric(target="label_seq"), | metrics=AccuracyMetric(target="label_seq"), | ||||
save_path=None, | save_path=None, | ||||
batch_size=32, | batch_size=32, | ||||