@@ -6,13 +6,14 @@ | |||
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg) | |||
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest) | |||
fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个命名实体识别(NER)、中文分词或文本分类任务; 也可以使用他构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||
fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner/)、POS-Tagging等)、中文分词、文本分类、[Matching](reproduction/matching/)、指代消解、摘要等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码。 | |||
- 各种方便的NLP工具,例如预处理embedding加载; 中间数据cache等; | |||
- 详尽的中文文档以供查阅; | |||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码; | |||
- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等; | |||
- 各种方便的NLP工具,例如预处理embedding加载(包括EMLo和BERT); 中间数据cache等; | |||
- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、教程以供查阅; | |||
- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等; | |||
- 封装CNNText,Biaffine等模型可供直接使用; | |||
- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用; [详细链接](reproduction/) | |||
- 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。 | |||
@@ -20,13 +21,14 @@ fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地 | |||
fastNLP 依赖如下包: | |||
+ numpy | |||
+ torch>=0.4.0 | |||
+ tqdm | |||
+ nltk | |||
+ numpy>=1.14.2 | |||
+ torch>=1.0.0 | |||
+ tqdm>=4.28.1 | |||
+ nltk>=3.4.1 | |||
+ requests | |||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 PyTorch 官网 。 | |||
在依赖包安装完成的情况,您可以在命令行执行如下指令完成安装 | |||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。 | |||
在依赖包安装完成后,您可以在命令行执行如下指令完成安装 | |||
```shell | |||
pip install fastNLP | |||
@@ -77,8 +79,8 @@ fastNLP 在 modules 模块中内置了三种模块的诸多组件,可以帮助 | |||
fastNLP 为不同的 NLP 任务实现了许多完整的模型,它们都经过了训练和测试。 | |||
你可以在以下两个地方查看相关信息 | |||
- [介绍](reproduction/) | |||
- [源码](fastNLP/models/) | |||
- [模型介绍](reproduction/) | |||
- [模型源码](fastNLP/models/) | |||
## 项目结构 | |||
@@ -93,7 +95,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下: | |||
</tr> | |||
<tr> | |||
<td><b> fastNLP.core </b></td> | |||
<td> 实现了核心功能,包括数据处理组件、训练器、测速器等 </td> | |||
<td> 实现了核心功能,包括数据处理组件、训练器、测试器等 </td> | |||
</tr> | |||
<tr> | |||
<td><b> fastNLP.models </b></td> | |||
@@ -0,0 +1,88 @@ | |||
import threading | |||
import torch | |||
from torch.nn.parallel.parallel_apply import get_a_var | |||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather | |||
from torch.nn.parallel.replicate import replicate | |||
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): | |||
r"""Applies each `module` in :attr:`modules` in parallel on arguments | |||
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) | |||
on each of :attr:`devices`. | |||
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and | |||
:attr:`devices` (if given) should all have same length. Moreover, each | |||
element of :attr:`inputs` can either be a single object as the only argument | |||
to a module, or a collection of positional arguments. | |||
""" | |||
assert len(modules) == len(inputs) | |||
if kwargs_tup is not None: | |||
assert len(modules) == len(kwargs_tup) | |||
else: | |||
kwargs_tup = ({},) * len(modules) | |||
if devices is not None: | |||
assert len(modules) == len(devices) | |||
else: | |||
devices = [None] * len(modules) | |||
lock = threading.Lock() | |||
results = {} | |||
grad_enabled = torch.is_grad_enabled() | |||
def _worker(i, module, input, kwargs, device=None): | |||
torch.set_grad_enabled(grad_enabled) | |||
if device is None: | |||
device = get_a_var(input).get_device() | |||
try: | |||
with torch.cuda.device(device): | |||
# this also avoids accidental slicing of `input` if it is a Tensor | |||
if not isinstance(input, (list, tuple)): | |||
input = (input,) | |||
output = getattr(module, func_name)(*input, **kwargs) | |||
with lock: | |||
results[i] = output | |||
except Exception as e: | |||
with lock: | |||
results[i] = e | |||
if len(modules) > 1: | |||
threads = [threading.Thread(target=_worker, | |||
args=(i, module, input, kwargs, device)) | |||
for i, (module, input, kwargs, device) in | |||
enumerate(zip(modules, inputs, kwargs_tup, devices))] | |||
for thread in threads: | |||
thread.start() | |||
for thread in threads: | |||
thread.join() | |||
else: | |||
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) | |||
outputs = [] | |||
for i in range(len(inputs)): | |||
output = results[i] | |||
if isinstance(output, Exception): | |||
raise output | |||
outputs.append(output) | |||
return outputs | |||
def _data_parallel_wrapper(func_name, device_ids, output_device): | |||
""" | |||
这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 | |||
:param str, func_name: 对network中的这个函数进行多卡运行 | |||
:param device_ids: nn.DataParallel中的device_ids | |||
:param output_device: nn.DataParallel中的output_device | |||
:return: | |||
""" | |||
def wrapper(network, *inputs, **kwargs): | |||
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) | |||
if len(device_ids) == 1: | |||
return getattr(network, func_name)(*inputs[0], **kwargs[0]) | |||
replicas = replicate(network, device_ids[:len(inputs)]) | |||
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) | |||
return gather(outputs, output_device) | |||
return wrapper |
@@ -113,7 +113,7 @@ class Callback(object): | |||
@property | |||
def n_steps(self): | |||
"""Trainer一共会运行多少步""" | |||
"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数""" | |||
return self._trainer.n_steps | |||
@property | |||
@@ -181,7 +181,7 @@ class Callback(object): | |||
:param dict batch_x: DataSet中被设置为input的field的batch。 | |||
:param dict batch_y: DataSet中被设置为target的field的batch。 | |||
:param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 | |||
情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。 | |||
情况下可以帮助定位是哪个Sample导致了错误。仅当num_workers=0时有效。 | |||
:return: | |||
""" | |||
pass | |||
@@ -399,10 +399,11 @@ class GradientClipCallback(Callback): | |||
self.clip_value = clip_value | |||
def on_backward_end(self): | |||
if self.parameters is None: | |||
self.clip_fun(self.model.parameters(), self.clip_value) | |||
else: | |||
self.clip_fun(self.parameters, self.clip_value) | |||
if self.step%self.update_every==0: | |||
if self.parameters is None: | |||
self.clip_fun(self.model.parameters(), self.clip_value) | |||
else: | |||
self.clip_fun(self.parameters, self.clip_value) | |||
class EarlyStopCallback(Callback): | |||
@@ -20,6 +20,7 @@ from collections import defaultdict | |||
import torch | |||
import torch.nn.functional as F | |||
from ..core.const import Const | |||
from .utils import _CheckError | |||
from .utils import _CheckRes | |||
from .utils import _build_args | |||
@@ -28,6 +29,7 @@ from .utils import _check_function_or_method | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
class LossBase(object): | |||
""" | |||
所有loss的基类。如果想了解其中的原理,请查看源码。 | |||
@@ -95,22 +97,7 @@ class LossBase(object): | |||
# if func_spect.varargs: | |||
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | |||
# f"positional argument.).") | |||
def _fast_param_map(self, pred_dict, target_dict): | |||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||
such as pred_dict has one element, target_dict has one element | |||
:param pred_dict: | |||
:param target_dict: | |||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||
""" | |||
fast_param = {} | |||
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||
fast_param['pred'] = list(pred_dict.values())[0] | |||
fast_param['target'] = list(target_dict.values())[0] | |||
return fast_param | |||
return fast_param | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
""" | |||
:param dict pred_dict: 模型的forward函数返回的dict | |||
@@ -118,11 +105,7 @@ class LossBase(object): | |||
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | |||
:return: | |||
""" | |||
fast_param = self._fast_param_map(pred_dict, target_dict) | |||
if fast_param: | |||
loss = self.get_loss(**fast_param) | |||
return loss | |||
if not self._checked: | |||
# 1. check consistence between signature and _param_map | |||
func_spect = inspect.getfullargspec(self.get_loss) | |||
@@ -212,7 +195,6 @@ class LossFunc(LossBase): | |||
if not isinstance(key_map, dict): | |||
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | |||
self._init_param_map(key_map, **kwargs) | |||
class CrossEntropyLoss(LossBase): | |||
@@ -226,6 +208,7 @@ class CrossEntropyLoss(LossBase): | |||
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | |||
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
:param str reduction: 支持'mean','sum'和'none'. | |||
Example:: | |||
@@ -233,21 +216,25 @@ class CrossEntropyLoss(LossBase): | |||
""" | |||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): | |||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'): | |||
super(CrossEntropyLoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
self.padding_idx = padding_idx | |||
assert reduction in ('mean', 'sum', 'none') | |||
self.reduction = reduction | |||
def get_loss(self, pred, target, seq_len=None): | |||
if pred.dim()>2: | |||
pred = pred.view(-1, pred.size(-1)) | |||
target = target.view(-1) | |||
if pred.dim() > 2: | |||
if pred.size(1) != target.size(1): | |||
pred = pred.transpose(1, 2) | |||
pred = pred.reshape(-1, pred.size(-1)) | |||
target = target.reshape(-1) | |||
if seq_len is not None: | |||
mask = seq_len_to_mask(seq_len).view(-1).eq(0) | |||
mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) | |||
target = target.masked_fill(mask, self.padding_idx) | |||
return F.cross_entropy(input=pred, target=target, | |||
ignore_index=self.padding_idx) | |||
ignore_index=self.padding_idx, reduction=self.reduction) | |||
class L1Loss(LossBase): | |||
@@ -258,15 +245,18 @@ class L1Loss(LossBase): | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target` | |||
:param str reduction: 支持'mean','sum'和'none'. | |||
""" | |||
def __init__(self, pred=None, target=None): | |||
def __init__(self, pred=None, target=None, reduction='mean'): | |||
super(L1Loss, self).__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
assert reduction in ('mean', 'sum', 'none') | |||
self.reduction = reduction | |||
def get_loss(self, pred, target): | |||
return F.l1_loss(input=pred, target=target) | |||
return F.l1_loss(input=pred, target=target, reduction=self.reduction) | |||
class BCELoss(LossBase): | |||
@@ -277,14 +267,17 @@ class BCELoss(LossBase): | |||
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | |||
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | |||
:param str reduction: 支持'mean','sum'和'none'. | |||
""" | |||
def __init__(self, pred=None, target=None): | |||
def __init__(self, pred=None, target=None, reduction='mean'): | |||
super(BCELoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
assert reduction in ('mean', 'sum', 'none') | |||
self.reduction = reduction | |||
def get_loss(self, pred, target): | |||
return F.binary_cross_entropy(input=pred, target=target) | |||
return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction) | |||
class NLLLoss(LossBase): | |||
@@ -295,14 +288,20 @@ class NLLLoss(LossBase): | |||
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | |||
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | |||
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
:param str reduction: 支持'mean','sum'和'none'. | |||
""" | |||
def __init__(self, pred=None, target=None): | |||
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | |||
super(NLLLoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
assert reduction in ('mean', 'sum', 'none') | |||
self.reduction = reduction | |||
self.ignore_idx = ignore_idx | |||
def get_loss(self, pred, target): | |||
return F.nll_loss(input=pred, target=target) | |||
return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction) | |||
class LossInForward(LossBase): | |||
@@ -314,7 +313,7 @@ class LossInForward(LossBase): | |||
:param str loss_key: 在forward函数中loss的键名,默认为loss | |||
""" | |||
def __init__(self, loss_key='loss'): | |||
def __init__(self, loss_key=Const.LOSS): | |||
super().__init__() | |||
if not isinstance(loss_key, str): | |||
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | |||
@@ -9,6 +9,9 @@ __all__ = [ | |||
] | |||
import torch | |||
import math | |||
import torch | |||
from torch.optim.optimizer import Optimizer as TorchOptimizer | |||
class Optimizer(object): | |||
@@ -97,3 +100,110 @@ class Adam(Optimizer): | |||
return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings) | |||
else: | |||
return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings) | |||
class AdamW(TorchOptimizer): | |||
r"""对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入 | |||
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. | |||
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. | |||
Arguments: | |||
params (iterable): iterable of parameters to optimize or dicts defining | |||
parameter groups | |||
lr (float, optional): learning rate (default: 1e-3) | |||
betas (Tuple[float, float], optional): coefficients used for computing | |||
running averages of gradient and its square (default: (0.9, 0.99)) | |||
eps (float, optional): term added to the denominator to improve | |||
numerical stability (default: 1e-8) | |||
weight_decay (float, optional): weight decay coefficient (default: 1e-2) | |||
amsgrad (boolean, optional): whether to use the AMSGrad variant of this | |||
algorithm from the paper `On the Convergence of Adam and Beyond`_ | |||
(default: False) | |||
.. _Adam\: A Method for Stochastic Optimization: | |||
https://arxiv.org/abs/1412.6980 | |||
.. _Decoupled Weight Decay Regularization: | |||
https://arxiv.org/abs/1711.05101 | |||
.. _On the Convergence of Adam and Beyond: | |||
https://openreview.net/forum?id=ryQu7f-RZ | |||
""" | |||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, | |||
weight_decay=1e-2, amsgrad=False): | |||
if not 0.0 <= lr: | |||
raise ValueError("Invalid learning rate: {}".format(lr)) | |||
if not 0.0 <= eps: | |||
raise ValueError("Invalid epsilon value: {}".format(eps)) | |||
if not 0.0 <= betas[0] < 1.0: | |||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | |||
if not 0.0 <= betas[1] < 1.0: | |||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | |||
defaults = dict(lr=lr, betas=betas, eps=eps, | |||
weight_decay=weight_decay, amsgrad=amsgrad) | |||
super(AdamW, self).__init__(params, defaults) | |||
def __setstate__(self, state): | |||
super(AdamW, self).__setstate__(state) | |||
for group in self.param_groups: | |||
group.setdefault('amsgrad', False) | |||
def step(self, closure=None): | |||
"""Performs a single optimization step. | |||
Arguments: | |||
closure (callable, optional): A closure that reevaluates the model | |||
and returns the loss. | |||
""" | |||
loss = None | |||
if closure is not None: | |||
loss = closure() | |||
for group in self.param_groups: | |||
for p in group['params']: | |||
if p.grad is None: | |||
continue | |||
# Perform stepweight decay | |||
p.data.mul_(1 - group['lr'] * group['weight_decay']) | |||
# Perform optimization step | |||
grad = p.grad.data | |||
if grad.is_sparse: | |||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') | |||
amsgrad = group['amsgrad'] | |||
state = self.state[p] | |||
# State initialization | |||
if len(state) == 0: | |||
state['step'] = 0 | |||
# Exponential moving average of gradient values | |||
state['exp_avg'] = torch.zeros_like(p.data) | |||
# Exponential moving average of squared gradient values | |||
state['exp_avg_sq'] = torch.zeros_like(p.data) | |||
if amsgrad: | |||
# Maintains max of all exp. moving avg. of sq. grad. values | |||
state['max_exp_avg_sq'] = torch.zeros_like(p.data) | |||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] | |||
if amsgrad: | |||
max_exp_avg_sq = state['max_exp_avg_sq'] | |||
beta1, beta2 = group['betas'] | |||
state['step'] += 1 | |||
# Decay the first and second moment running average coefficient | |||
exp_avg.mul_(beta1).add_(1 - beta1, grad) | |||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) | |||
if amsgrad: | |||
# Maintains the maximum of all 2nd moment running avg. till now | |||
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) | |||
# Use the max. for normalizing running avg. of gradient | |||
denom = max_exp_avg_sq.sqrt().add_(group['eps']) | |||
else: | |||
denom = exp_avg_sq.sqrt().add_(group['eps']) | |||
bias_correction1 = 1 - beta1 ** state['step'] | |||
bias_correction2 = 1 - beta2 ** state['step'] | |||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 | |||
p.data.addcdiv_(-step_size, exp_avg, denom) | |||
return loss |
@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||
class Predictor(object): | |||
""" | |||
An interface for predicting outputs based on trained models. | |||
一个根据训练模型预测输出的预测器(Predictor) | |||
It does not care about evaluations of the model, which is different from Tester. | |||
This is a high-level model wrapper to be called by FastNLP. | |||
This class does not share any operations with Trainer and Tester. | |||
Currently, Predictor does not support GPU. | |||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | |||
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 | |||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||
""" | |||
def __init__(self, network): | |||
@@ -30,18 +30,19 @@ class Predictor(object): | |||
self.batch_size = 1 | |||
self.batch_output = [] | |||
def predict(self, data, seq_len_field_name=None): | |||
"""Perform inference using the trained model. | |||
def predict(self, data: DataSet, seq_len_field_name=None): | |||
"""用已经训练好的模型进行inference. | |||
:param data: a DataSet object. | |||
:param str seq_len_field_name: field name indicating sequence lengths | |||
:return: list of batch outputs | |||
:param fastNLP.DataSet data: 待预测的数据集 | |||
:param str seq_len_field_name: 表示序列长度信息的field名字 | |||
:return: dict dict里面的内容为模型预测的结果 | |||
""" | |||
if not isinstance(data, DataSet): | |||
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | |||
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | |||
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | |||
prev_training = self.network.training | |||
self.network.eval() | |||
network_device = _get_model_device(self.network) | |||
batch_output = defaultdict(list) | |||
@@ -74,4 +75,5 @@ class Predictor(object): | |||
else: | |||
batch_output[key].append(value) | |||
self.network.train(prev_training) | |||
return batch_output |
@@ -32,8 +32,6 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation | |||
""" | |||
import warnings | |||
import torch | |||
import torch.nn as nn | |||
@@ -48,6 +46,8 @@ from .utils import _move_dict_value_to_device | |||
from .utils import _get_func_signature | |||
from .utils import _get_model_device | |||
from .utils import _move_model_to_device | |||
from ._parallel_utils import _data_parallel_wrapper | |||
from functools import partial | |||
__all__ = [ | |||
"Tester" | |||
@@ -104,26 +104,27 @@ class Tester(object): | |||
self.data_iterator = data | |||
else: | |||
raise TypeError("data type {} not support".format(type(data))) | |||
# 如果是DataParallel将没有办法使用predict方法 | |||
if isinstance(self._model, nn.DataParallel): | |||
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): | |||
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," | |||
" while DataParallel has no predict() function.") | |||
self._model = self._model.module | |||
# check predict | |||
if hasattr(self._model, 'predict'): | |||
self._predict_func = self._model.predict | |||
if not callable(self._predict_func): | |||
_model_name = model.__class__.__name__ | |||
raise TypeError(f"`{_model_name}.predict` must be callable to be used " | |||
f"for evaluation, not `{type(self._predict_func)}`.") | |||
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ | |||
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and | |||
callable(self._model.module.predict)): | |||
if isinstance(self._model, nn.DataParallel): | |||
self._predict_func_wrapper = partial(_data_parallel_wrapper('predict', | |||
self._model.device_ids, | |||
self._model.output_device), | |||
network=self._model.module) | |||
self._predict_func = self._model.module.predict | |||
else: | |||
self._predict_func = self._model.predict | |||
self._predict_func_wrapper = self._model.predict | |||
else: | |||
if isinstance(model, nn.DataParallel): | |||
if isinstance(self._model, nn.DataParallel): | |||
self._predict_func_wrapper = self._model.forward | |||
self._predict_func = self._model.module.forward | |||
else: | |||
self._predict_func = self._model.forward | |||
self._predict_func_wrapper = self._model.forward | |||
def test(self): | |||
"""开始进行验证,并返回验证结果。 | |||
@@ -180,7 +181,7 @@ class Tester(object): | |||
def _data_forward(self, func, x): | |||
"""A forward pass of the model. """ | |||
x = _build_args(func, **x) | |||
y = func(**x) | |||
y = self._predict_func_wrapper(**x) | |||
return y | |||
def _format_eval_results(self, results): | |||
@@ -454,7 +454,7 @@ class Trainer(object): | |||
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=metric_key, check_level=check_code_level, | |||
metric_key=self.metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | |||
self.model = _move_model_to_device(model, device=device) | |||
@@ -473,7 +473,7 @@ class Trainer(object): | |||
self.best_dev_step = None | |||
self.best_dev_perf = None | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
@@ -17,7 +17,6 @@ import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
'varargs']) | |||
@@ -278,6 +277,7 @@ def _move_model_to_device(model, device): | |||
return model | |||
def _get_model_device(model): | |||
""" | |||
传入一个nn.Module的模型,获取它所在的device | |||
@@ -11,21 +11,35 @@ | |||
""" | |||
__all__ = [ | |||
'EmbedLoader', | |||
'DataInfo', | |||
'DataSetLoader', | |||
'CSVLoader', | |||
'JsonLoader', | |||
'ConllLoader', | |||
'SNLILoader', | |||
'SSTLoader', | |||
'PeopleDailyCorpusLoader', | |||
'Conll2003Loader', | |||
'ModelLoader', | |||
'ModelSaver', | |||
'SSTLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'MNLILoader', | |||
'QNLILoader', | |||
'QuoraLoader', | |||
'RTELoader', | |||
] | |||
from .embed_loader import EmbedLoader | |||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \ | |||
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader | |||
from .base_loader import DataInfo, DataSetLoader | |||
from .dataset_loader import CSVLoader, JsonLoader, ConllLoader, \ | |||
PeopleDailyCorpusLoader, Conll2003Loader | |||
from .model_io import ModelLoader, ModelSaver | |||
from .data_loader.sst import SSTLoader | |||
from .data_loader.matching import MatchingLoader, SNLILoader, \ | |||
MNLILoader, QNLILoader, QuoraLoader, RTELoader |
@@ -10,6 +10,7 @@ from typing import Union, Dict | |||
import os | |||
from ..core.dataset import DataSet | |||
class BaseLoader(object): | |||
""" | |||
各个 Loader 的基类,提供了 API 的参考。 | |||
@@ -55,8 +56,6 @@ class BaseLoader(object): | |||
return obj | |||
def _download_from_url(url, path): | |||
try: | |||
from tqdm.auto import tqdm | |||
@@ -115,13 +114,11 @@ class DataInfo: | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||
""" | |||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): | |||
def __init__(self, vocabs: dict = None, datasets: dict = None): | |||
self.vocabs = vocabs or {} | |||
self.embeddings = embeddings or {} | |||
self.datasets = datasets or {} | |||
def __repr__(self): | |||
@@ -133,6 +130,7 @@ class DataInfo: | |||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||
return _str | |||
class DataSetLoader: | |||
""" | |||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||
@@ -213,7 +211,6 @@ class DataSetLoader: | |||
返回的 :class:`DataInfo` 对象有如下属性: | |||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | |||
- embeddings: (可选) 数据集对应的词嵌入 | |||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||
:param paths: 原始数据读取的路径 | |||
@@ -0,0 +1,19 @@ | |||
""" | |||
用于读数据集的模块, 具体包括: | |||
这些模块的使用方法如下: | |||
""" | |||
__all__ = [ | |||
'SSTLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'MNLILoader', | |||
'QNLILoader', | |||
'QuoraLoader', | |||
'RTELoader', | |||
] | |||
from .sst import SSTLoader | |||
from .matching import MatchingLoader, SNLILoader, \ | |||
MNLILoader, QNLILoader, QuoraLoader, RTELoader |
@@ -0,0 +1,430 @@ | |||
import os | |||
from typing import Union, Dict | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ..base_loader import DataInfo, DataSetLoader | |||
from ..dataset_loader import JsonLoader, CSVLoader | |||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ...modules.encoder._bert import BertTokenizer | |||
class MatchingLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||
读取Matching任务的数据集 | |||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||
""" | |||
def __init__(self, paths: dict=None): | |||
self.paths = paths | |||
def _load(self, path): | |||
""" | |||
:param str path: 待读取数据集的路径名 | |||
:return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 | |||
的原始字符串文本,第三个为标签 | |||
""" | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | |||
""" | |||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||
对应的全路径文件名。 | |||
:param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 | |||
这个数据集的名字,如果不定义则默认为train。 | |||
:param bool to_lower: 是否将文本自动转为小写。默认值为False。 | |||
:param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : | |||
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 | |||
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len | |||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||
:param bool get_index: 是否需要根据词表将文本转为index | |||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||
:param str auto_pad_token: 自动pad的内容 | |||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||
于此同时其他field不会被设置为input。默认值为True。 | |||
:param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 | |||
:param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。 | |||
如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 | |||
传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. | |||
:return: | |||
""" | |||
if isinstance(set_input, str): | |||
set_input = [set_input] | |||
if isinstance(set_target, str): | |||
set_target = [set_target] | |||
if isinstance(set_input, bool): | |||
auto_set_input = set_input | |||
else: | |||
auto_set_input = False | |||
if isinstance(set_target, bool): | |||
auto_set_target = set_target | |||
else: | |||
auto_set_target = False | |||
if isinstance(paths, str): | |||
if os.path.isdir(paths): | |||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||
else: | |||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||
else: | |||
path = paths | |||
data_info = DataInfo() | |||
for data_name in path.keys(): | |||
data_info.datasets[data_name] = self._load(path[data_name]) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if auto_set_input: | |||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||
if auto_set_target: | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.set_target(Const.TARGET) | |||
if to_lower: | |||
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||
is_input=auto_set_input) | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||
is_input=auto_set_input) | |||
if bert_tokenizer is not None: | |||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
# 检查是否存在 | |||
elif os.path.isdir(bert_tokenizer): | |||
model_dir = bert_tokenizer | |||
else: | |||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||
lines = f.readlines() | |||
lines = [line.strip() for line in lines] | |||
words_vocab.add_word_lst(lines) | |||
words_vocab.build_vocab() | |||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
if isinstance(concat, bool): | |||
concat = 'default' if concat else None | |||
if concat is not None: | |||
if isinstance(concat, str): | |||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||
'default': ['', '<sep>', '', '']} | |||
if concat.lower() in CONCAT_MAP: | |||
concat = CONCAT_MAP[concat] | |||
else: | |||
concat = 4 * [concat] | |||
assert len(concat) == 4, \ | |||
f'Please choose a list with 4 symbols which at the beginning of first sentence ' \ | |||
f'the end of first sentence, the begin of second sentence, and the end of second' \ | |||
f'sentence. Your input is {concat}' | |||
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||
is_input=auto_set_input) | |||
if seq_len_type is not None: | |||
if seq_len_type == 'seq_len': # | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'mask': | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [1] * len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'bert': | |||
for data_name, data_set in data_info.datasets.items(): | |||
if Const.INPUT not in data_set.get_field_names(): | |||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||
f'got {data_set.get_field_names()}') | |||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||
if auto_pad_length is not None: | |||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||
if cut_text is not None: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||
is_input=auto_set_input) | |||
data_set_list = [d for n, d in data_info.datasets.items()] | |||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||
if bert_tokenizer is None: | |||
words_vocab = Vocabulary(padding=auto_pad_token) | |||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=[n for n in data_set_list[0].get_field_names() | |||
if (Const.INPUT in n)], | |||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||
if 'train' not in n]) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=Const.TARGET) | |||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||
if get_index: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||
is_input=auto_set_input) | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
is_input=auto_set_input, is_target=auto_set_target) | |||
if auto_pad_length is not None: | |||
if seq_len_type == 'seq_len': | |||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||
new_field_name=fields, is_input=auto_set_input) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if isinstance(set_input, list): | |||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||
if isinstance(set_target, list): | |||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||
return data_info | |||
class SNLILoader(MatchingLoader, JsonLoader): | |||
""" | |||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||
读取SNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self, paths: dict=None): | |||
fields = { | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
paths = paths if paths is not None else { | |||
'train': 'snli_1.0_train.jsonl', | |||
'dev': 'snli_1.0_dev.jsonl', | |||
'test': 'snli_1.0_test.jsonl'} | |||
MatchingLoader.__init__(self, paths=paths) | |||
JsonLoader.__init__(self, fields=fields) | |||
def _load(self, path): | |||
ds = JsonLoader._load(self, path) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
class RTELoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` | |||
读取RTE数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'sentence1': Const.INPUTS(0), | |||
'sentence2': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if v in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds | |||
class QNLILoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` | |||
读取QNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'question': Const.INPUTS(0), | |||
'sentence': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if v in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds | |||
class MNLILoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | |||
读取MNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev_matched': 'dev_matched.tsv', | |||
'dev_mismatched': 'dev_mismatched.tsv', | |||
'test_matched': 'test_matched.tsv', | |||
'test_mismatched': 'test_mismatched.tsv', | |||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t') | |||
self.fields = { | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
if Const.TARGET in ds.get_field_names(): | |||
if ds[0][Const.TARGET] == 'hidden': | |||
ds.delete_field(Const.TARGET) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
if Const.TARGET in ds.get_field_names(): | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
class QuoraLoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` | |||
读取MNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv', | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
return ds |
@@ -1,11 +1,14 @@ | |||
from typing import Iterable | |||
from nltk import Tree | |||
import spacy | |||
from ..base_loader import DataInfo, DataSetLoader | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||
spacy.prefer_gpu() | |||
sptk = spacy.load('en') | |||
class SSTLoader(DataSetLoader): | |||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||
@@ -56,8 +59,8 @@ class SSTLoader(DataSetLoader): | |||
def _get_one(data, subtree): | |||
tree = Tree.fromstring(data) | |||
if subtree: | |||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||
return [(tree.leaves(), tree.label())] | |||
return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] | |||
return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())] | |||
def process(self, | |||
paths, | |||
@@ -16,8 +16,6 @@ __all__ = [ | |||
'CSVLoader', | |||
'JsonLoader', | |||
'ConllLoader', | |||
'SNLILoader', | |||
'SSTLoader', | |||
'PeopleDailyCorpusLoader', | |||
'Conll2003Loader', | |||
] | |||
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet | |||
from ..core.instance import Instance | |||
from .file_reader import _read_csv, _read_json, _read_conll | |||
from .base_loader import DataSetLoader, DataInfo | |||
from .data_loader.sst import SSTLoader | |||
from ..core.const import Const | |||
from ..modules.encoder._bert import BertTokenizer | |||
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
else: | |||
instance = Instance(words=sent_words) | |||
data_set.append(instance) | |||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") | |||
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) | |||
return data_set | |||
@@ -249,42 +246,6 @@ class JsonLoader(DataSetLoader): | |||
return ds | |||
class SNLILoader(JsonLoader): | |||
""" | |||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||
读取SNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self): | |||
fields = { | |||
'sentence1_parse': Const.INPUTS(0), | |||
'sentence2_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
super(SNLILoader, self).__init__(fields=fields) | |||
def _load(self, path): | |||
ds = super(SNLILoader, self)._load(path) | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
class CSVLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | |||
@@ -104,7 +104,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
except Exception as e: | |||
if dropna: | |||
continue | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||
raise ValueError('invalid instance ends at line: {}'.format(line_idx)) | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
@@ -117,5 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
except Exception as e: | |||
if dropna: | |||
return | |||
print('invalid instance at line: {}'.format(line_idx)) | |||
print('invalid instance ends at line: {}'.format(line_idx)) | |||
raise e |
@@ -8,35 +8,7 @@ from torch import nn | |||
from .base_model import BaseModel | |||
from ..core.const import Const | |||
from ..modules.encoder import BertModel | |||
class BertConfig: | |||
def __init__( | |||
self, | |||
vocab_size=30522, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act="gelu", | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02 | |||
): | |||
self.vocab_size = vocab_size | |||
self.hidden_size = hidden_size | |||
self.num_hidden_layers = num_hidden_layers | |||
self.num_attention_heads = num_attention_heads | |||
self.intermediate_size = intermediate_size | |||
self.hidden_act = hidden_act | |||
self.hidden_dropout_prob = hidden_dropout_prob | |||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
self.max_position_embeddings = max_position_embeddings | |||
self.type_vocab_size = type_vocab_size | |||
self.initializer_range = initializer_range | |||
from ..modules.encoder._bert import BertConfig | |||
class BertForSequenceClassification(BaseModel): | |||
@@ -84,11 +56,17 @@ class BertForSequenceClassification(BaseModel): | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
config = BertConfig(30522) | |||
self.bert = BertModel(config) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, num_labels) | |||
@classmethod | |||
def from_pretrained(cls, num_labels, pretrained_model_dir): | |||
config = BertConfig(pretrained_model_dir) | |||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
pooled_output = self.dropout(pooled_output) | |||
@@ -151,11 +129,17 @@ class BertForMultipleChoice(BaseModel): | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
config = BertConfig(30522) | |||
self.bert = BertModel(config) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, 1) | |||
@classmethod | |||
def from_pretrained(cls, num_choices, pretrained_model_dir): | |||
config = BertConfig(pretrained_model_dir) | |||
model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | |||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | |||
@@ -224,11 +208,17 @@ class BertForTokenClassification(BaseModel): | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
config = BertConfig(30522) | |||
self.bert = BertModel(config) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, num_labels) | |||
@classmethod | |||
def from_pretrained(cls, num_labels, pretrained_model_dir): | |||
config = BertConfig(pretrained_model_dir) | |||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
sequence_output = self.dropout(sequence_output) | |||
@@ -302,12 +292,18 @@ class BertForQuestionAnswering(BaseModel): | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
config = BertConfig(30522) | |||
self.bert = BertModel(config) | |||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | |||
# self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.qa_outputs = nn.Linear(config.hidden_size, 2) | |||
@classmethod | |||
def from_pretrained(cls, pretrained_model_dir): | |||
config = BertConfig(pretrained_model_dir) | |||
model = cls(config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | |||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
logits = self.qa_outputs(sequence_output) | |||
@@ -46,7 +46,7 @@ class StarTransEnc(nn.Module): | |||
super(StarTransEnc, self).__init__() | |||
self.embedding = get_embeddings(init_embed) | |||
emb_dim = self.embedding.embedding_dim | |||
self.emb_fc = nn.Linear(emb_dim, hidden_size) | |||
#self.emb_fc = nn.Linear(emb_dim, hidden_size) | |||
self.emb_drop = nn.Dropout(emb_dropout) | |||
self.encoder = StarTransformer(hidden_size=hidden_size, | |||
num_layers=num_layers, | |||
@@ -65,7 +65,7 @@ class StarTransEnc(nn.Module): | |||
[batch, hidden] 全局 relay 节点, 详见论文 | |||
""" | |||
x = self.embedding(x) | |||
x = self.emb_fc(self.emb_drop(x)) | |||
#x = self.emb_fc(self.emb_drop(x)) | |||
nodes, relay = self.encoder(x, mask) | |||
return nodes, relay | |||
@@ -205,7 +205,7 @@ class STSeqCls(nn.Module): | |||
max_len=max_len, | |||
emb_dropout=emb_dropout, | |||
dropout=dropout) | |||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | |||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout) | |||
def forward(self, words, seq_len): | |||
""" | |||
@@ -15,7 +15,8 @@ class MLP(nn.Module): | |||
多层感知器 | |||
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 | |||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和sigmoid,默认值为relu | |||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 | |||
sigmoid,默认值为relu | |||
:param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 | |||
:param str initial_method: 参数初始化方式 | |||
:param float dropout: dropout概率,默认值为0 | |||
@@ -2,7 +2,8 @@ | |||
""" | |||
这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码 | |||
这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 | |||
有用,也请引用一下他们。 | |||
""" | |||
@@ -11,7 +12,6 @@ from ...core.vocabulary import Vocabulary | |||
import collections | |||
import unicodedata | |||
from ...io.file_utils import _get_base_url, cached_path | |||
import numpy as np | |||
from itertools import chain | |||
import copy | |||
@@ -22,9 +22,106 @@ import os | |||
import torch | |||
from torch import nn | |||
import glob | |||
import sys | |||
CONFIG_FILE = 'bert_config.json' | |||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||
class BertConfig(object): | |||
"""Configuration class to store the configuration of a `BertModel`. | |||
""" | |||
def __init__(self, | |||
vocab_size_or_config_json_file, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act="gelu", | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02, | |||
layer_norm_eps=1e-12): | |||
"""Constructs BertConfig. | |||
Args: | |||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. | |||
hidden_size: Size of the encoder layers and the pooler layer. | |||
num_hidden_layers: Number of hidden layers in the Transformer encoder. | |||
num_attention_heads: Number of attention heads for each attention layer in | |||
the Transformer encoder. | |||
intermediate_size: The size of the "intermediate" (i.e., feed-forward) | |||
layer in the Transformer encoder. | |||
hidden_act: The non-linear activation function (function or string) in the | |||
encoder and pooler. If string, "gelu", "relu" and "swish" are supported. | |||
hidden_dropout_prob: The dropout probabilitiy for all fully connected | |||
layers in the embeddings, encoder, and pooler. | |||
attention_probs_dropout_prob: The dropout ratio for the attention | |||
probabilities. | |||
max_position_embeddings: The maximum sequence length that this model might | |||
ever be used with. Typically set this to something large just in case | |||
(e.g., 512 or 1024 or 2048). | |||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into | |||
`BertModel`. | |||
initializer_range: The sttdev of the truncated_normal_initializer for | |||
initializing all weight matrices. | |||
layer_norm_eps: The epsilon used by LayerNorm. | |||
""" | |||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 | |||
and isinstance(vocab_size_or_config_json_file, unicode)): | |||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: | |||
json_config = json.loads(reader.read()) | |||
for key, value in json_config.items(): | |||
self.__dict__[key] = value | |||
elif isinstance(vocab_size_or_config_json_file, int): | |||
self.vocab_size = vocab_size_or_config_json_file | |||
self.hidden_size = hidden_size | |||
self.num_hidden_layers = num_hidden_layers | |||
self.num_attention_heads = num_attention_heads | |||
self.hidden_act = hidden_act | |||
self.intermediate_size = intermediate_size | |||
self.hidden_dropout_prob = hidden_dropout_prob | |||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
self.max_position_embeddings = max_position_embeddings | |||
self.type_vocab_size = type_vocab_size | |||
self.initializer_range = initializer_range | |||
self.layer_norm_eps = layer_norm_eps | |||
else: | |||
raise ValueError("First argument must be either a vocabulary size (int)" | |||
"or the path to a pretrained model config file (str)") | |||
@classmethod | |||
def from_dict(cls, json_object): | |||
"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||
config = BertConfig(vocab_size_or_config_json_file=-1) | |||
for key, value in json_object.items(): | |||
config.__dict__[key] = value | |||
return config | |||
@classmethod | |||
def from_json_file(cls, json_file): | |||
"""Constructs a `BertConfig` from a json file of parameters.""" | |||
with open(json_file, "r", encoding='utf-8') as reader: | |||
text = reader.read() | |||
return cls.from_dict(json.loads(text)) | |||
def __repr__(self): | |||
return str(self.to_json_string()) | |||
def to_dict(self): | |||
"""Serializes this instance to a Python dictionary.""" | |||
output = copy.deepcopy(self.__dict__) | |||
return output | |||
def to_json_string(self): | |||
"""Serializes this instance to a JSON string.""" | |||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" | |||
def to_json_file(self, json_file_path): | |||
""" Save this instance to a json file.""" | |||
with open(json_file_path, "w", encoding='utf-8') as writer: | |||
writer.write(self.to_json_string()) | |||
def gelu(x): | |||
@@ -40,6 +137,8 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
class BertLayerNorm(nn.Module): | |||
def __init__(self, hidden_size, eps=1e-12): | |||
"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||
""" | |||
super(BertLayerNorm, self).__init__() | |||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||
@@ -53,16 +152,18 @@ class BertLayerNorm(nn.Module): | |||
class BertEmbeddings(nn.Module): | |||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||
"""Construct the embeddings from word, position and token_type embeddings. | |||
""" | |||
def __init__(self, config): | |||
super(BertEmbeddings, self).__init__() | |||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) | |||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) | |||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) | |||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||
# any TensorFlow checkpoint file | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, input_ids, token_type_ids=None): | |||
seq_length = input_ids.size(1) | |||
@@ -82,21 +183,21 @@ class BertEmbeddings(nn.Module): | |||
class BertSelfAttention(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||
def __init__(self, config): | |||
super(BertSelfAttention, self).__init__() | |||
if hidden_size % num_attention_heads != 0: | |||
if config.hidden_size % config.num_attention_heads != 0: | |||
raise ValueError( | |||
"The hidden size (%d) is not a multiple of the number of attention " | |||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||
self.num_attention_heads = num_attention_heads | |||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)) | |||
self.num_attention_heads = config.num_attention_heads | |||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||
self.query = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.key = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.value = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |||
def transpose_for_scores(self, x): | |||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||
@@ -133,11 +234,11 @@ class BertSelfAttention(nn.Module): | |||
class BertSelfOutput(nn.Module): | |||
def __init__(self, hidden_size, hidden_dropout_prob): | |||
def __init__(self, config): | |||
super(BertSelfOutput, self).__init__() | |||
self.dense = nn.Linear(hidden_size, hidden_size) | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
@@ -147,10 +248,10 @@ class BertSelfOutput(nn.Module): | |||
class BertAttention(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||
def __init__(self, config): | |||
super(BertAttention, self).__init__() | |||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||
self.self = BertSelfAttention(config) | |||
self.output = BertSelfOutput(config) | |||
def forward(self, input_tensor, attention_mask): | |||
self_output = self.self(input_tensor, attention_mask) | |||
@@ -159,11 +260,13 @@ class BertAttention(nn.Module): | |||
class BertIntermediate(nn.Module): | |||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||
def __init__(self, config): | |||
super(BertIntermediate, self).__init__() | |||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||
if isinstance(hidden_act, str) else hidden_act | |||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) | |||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): | |||
self.intermediate_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.intermediate_act_fn = config.hidden_act | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
@@ -172,11 +275,11 @@ class BertIntermediate(nn.Module): | |||
class BertOutput(nn.Module): | |||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||
def __init__(self, config): | |||
super(BertOutput, self).__init__() | |||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) | |||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
@@ -186,13 +289,11 @@ class BertOutput(nn.Module): | |||
class BertLayer(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
intermediate_size, hidden_act): | |||
def __init__(self, config): | |||
super(BertLayer, self).__init__() | |||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
hidden_dropout_prob) | |||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||
self.attention = BertAttention(config) | |||
self.intermediate = BertIntermediate(config) | |||
self.output = BertOutput(config) | |||
def forward(self, hidden_states, attention_mask): | |||
attention_output = self.attention(hidden_states, attention_mask) | |||
@@ -202,13 +303,10 @@ class BertLayer(nn.Module): | |||
class BertEncoder(nn.Module): | |||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
hidden_dropout_prob, | |||
intermediate_size, hidden_act): | |||
def __init__(self, config): | |||
super(BertEncoder, self).__init__() | |||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
intermediate_size, hidden_act) | |||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||
layer = BertLayer(config) | |||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) | |||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||
all_encoder_layers = [] | |||
@@ -222,9 +320,9 @@ class BertEncoder(nn.Module): | |||
class BertPooler(nn.Module): | |||
def __init__(self, hidden_size): | |||
def __init__(self, config): | |||
super(BertPooler, self).__init__() | |||
self.dense = nn.Linear(hidden_size, hidden_size) | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.activation = nn.Tanh() | |||
def forward(self, hidden_states): | |||
@@ -242,13 +340,19 @@ class BertModel(nn.Module): | |||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | |||
sources:: | |||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", | |||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", | |||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", | |||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", | |||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", | |||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", | |||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", | |||
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin", | |||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", | |||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", | |||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin" | |||
用预训练权重矩阵来建立BERT模型:: | |||
@@ -272,34 +376,30 @@ class BertModel(nn.Module): | |||
:param int initializer_range: 初始化权重范围,默认值为0.02 | |||
""" | |||
def __init__(self, vocab_size=30522, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act="gelu", | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02): | |||
def __init__(self, config, *inputs, **kwargs): | |||
super(BertModel, self).__init__() | |||
self.hidden_size = hidden_size | |||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||
type_vocab_size, hidden_dropout_prob) | |||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||
hidden_act) | |||
self.pooler = BertPooler(hidden_size) | |||
self.initializer_range = initializer_range | |||
if not isinstance(config, BertConfig): | |||
raise ValueError( | |||
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. " | |||
"To create a model from a Google pretrained model use " | |||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( | |||
self.__class__.__name__, self.__class__.__name__ | |||
)) | |||
super(BertModel, self).__init__() | |||
self.config = config | |||
self.hidden_size = self.config.hidden_size | |||
self.embeddings = BertEmbeddings(config) | |||
self.encoder = BertEncoder(config) | |||
self.pooler = BertPooler(config) | |||
self.apply(self.init_bert_weights) | |||
def init_bert_weights(self, module): | |||
""" Initialize the weights. | |||
""" | |||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) | |||
elif isinstance(module, BertLayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
@@ -338,14 +438,19 @@ class BertModel(nn.Module): | |||
return encoded_layers, pooled_output | |||
@classmethod | |||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||
def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs): | |||
state_dict = kwargs.get('state_dict', None) | |||
kwargs.pop('state_dict', None) | |||
cache_dir = kwargs.get('cache_dir', None) | |||
kwargs.pop('cache_dir', None) | |||
from_tf = kwargs.get('from_tf', False) | |||
kwargs.pop('from_tf', None) | |||
# Load config | |||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||
config = json.load(open(config_file, "r")) | |||
# config = BertConfig.from_json_file(config_file) | |||
config = BertConfig.from_json_file(config_file) | |||
# logger.info("Model config {}".format(config)) | |||
# Instantiate model. | |||
model = cls(*inputs, **config, **kwargs) | |||
model = cls(config, *inputs, **kwargs) | |||
if state_dict is None: | |||
files = glob.glob(os.path.join(pretrained_model_dir, '*.bin')) | |||
if len(files)==0: | |||
@@ -353,7 +458,7 @@ class BertModel(nn.Module): | |||
elif len(files)>1: | |||
raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}") | |||
weights_path = files[0] | |||
state_dict = torch.load(weights_path) | |||
state_dict = torch.load(weights_path, map_location='cpu') | |||
old_keys = [] | |||
new_keys = [] | |||
@@ -464,6 +569,7 @@ class WordpieceTokenizer(object): | |||
output_tokens.extend(sub_tokens) | |||
return output_tokens | |||
def load_vocab(vocab_file): | |||
"""Loads a vocabulary file into a dictionary.""" | |||
vocab = collections.OrderedDict() | |||
@@ -594,6 +700,7 @@ class BasicTokenizer(object): | |||
output.append(char) | |||
return "".join(output) | |||
def _is_whitespace(char): | |||
"""Checks whether `chars` is a whitespace character.""" | |||
# \t, \n, and \r are technically contorl characters but we treat them | |||
@@ -840,6 +947,7 @@ class _WordBertModel(nn.Module): | |||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) | |||
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) | |||
attn_masks[i, :len(word_pieces_i)+2].fill_(1) | |||
# TODO 截掉长度超过的部分。 | |||
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 | |||
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | |||
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, | |||
@@ -202,18 +202,12 @@ class StaticEmbedding(TokenEmbedding): | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
# 读取embedding | |||
embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, | |||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, | |||
normalize=normalize) | |||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | |||
padding_idx=vocab.padding_idx, | |||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | |||
sparse=False, _weight=embedding) | |||
if vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk | |||
words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) | |||
for word, idx in vocab: | |||
if vocab._is_word_no_create_entry(word) and not hit_flags[idx]: | |||
words_to_words[idx] = vocab.unknown_idx | |||
self.words_to_words = words_to_words | |||
self._embed_size = self.embedding.weight.size(1) | |||
self.requires_grad = requires_grad | |||
@@ -268,10 +262,8 @@ class StaticEmbedding(TokenEmbedding): | |||
else: | |||
dim = len(parts) - 1 | |||
f.seek(0) | |||
matrix = torch.zeros(len(vocab), dim) | |||
if init_method is not None: | |||
init_method(matrix) | |||
hit_flags = np.zeros(len(vocab), dtype=bool) | |||
matrix = {} | |||
found_count = 0 | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
parts = line.strip().split() | |||
@@ -285,28 +277,49 @@ class StaticEmbedding(TokenEmbedding): | |||
if word in vocab: | |||
index = vocab.to_index(word) | |||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | |||
hit_flags[index] = True | |||
found_count += 1 | |||
except Exception as e: | |||
if error == 'ignore': | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
else: | |||
print("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
found_count = sum(hit_flags) | |||
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
if init_method is None: | |||
if len(vocab)-found_count>0 and found_count>0: # 有的没找到 | |||
found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()] | |||
mean = found_vecs.mean(dim=0, keepdim=True) | |||
std = found_vecs.std(dim=0, keepdim=True) | |||
unfound_vec_num = np.sum(hit_flags==False) | |||
unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean | |||
matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs | |||
for word, index in vocab: | |||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||
if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化 | |||
matrix[index] = matrix[vocab.unknown_idx] | |||
else: | |||
matrix[index] = None | |||
vectors = torch.zeros(len(matrix), dim) | |||
if init_method: | |||
init_method(vectors) | |||
else: | |||
nn.init.uniform_(vectors, -np.sqrt(3/dim), np.sqrt(3/dim)) | |||
if vocab._no_create_word_length>0: | |||
if vocab.unknown is None: # 创建一个专门的unknown | |||
unknown_idx = len(matrix) | |||
vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous() | |||
else: | |||
unknown_idx = vocab.unknown_idx | |||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||
requires_grad=False) | |||
for order, (index, vec) in enumerate(matrix.items()): | |||
if vec is not None: | |||
vectors[order] = vec | |||
words_to_words[index] = order | |||
self.words_to_words = words_to_words | |||
else: | |||
for index, vec in matrix.items(): | |||
if vec is not None: | |||
vectors[index] = vec | |||
if normalize: | |||
matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12) | |||
vectors /= (torch.norm(vectors, dim=1, keepdim=True) + 1e-12) | |||
return matrix, hit_flags | |||
return vectors | |||
def forward(self, words): | |||
""" | |||
@@ -35,11 +35,13 @@ class StarTransformer(nn.Module): | |||
self.iters = num_layers | |||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | |||
self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) | |||
self.emb_drop = nn.Dropout(dropout) | |||
self.ring_att = nn.ModuleList( | |||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||
for _ in range(self.iters)]) | |||
self.star_att = nn.ModuleList( | |||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||
for _ in range(self.iters)]) | |||
if max_len is not None: | |||
@@ -66,18 +68,19 @@ class StarTransformer(nn.Module): | |||
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | |||
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | |||
if self.pos_emb: | |||
if self.pos_emb and False: | |||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | |||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||
embs = embs + P | |||
embs = norm_func(self.emb_drop, embs) | |||
nodes = embs | |||
relay = embs.mean(2, keepdim=True) | |||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | |||
r_embs = embs.view(B, H, 1, L) | |||
for i in range(self.iters): | |||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | |||
nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||
#nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) | |||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | |||
nodes = nodes.masked_fill_(ex_mask, 0) | |||
@@ -3,6 +3,8 @@ | |||
复现的模型有: | |||
- [Star-Transformer](Star_transformer/) | |||
- [Biaffine](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/biaffine_parser.py#L239) | |||
- [CNNText](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/cnn_text_classification.py#L12) | |||
- ... | |||
# 任务复现 | |||
@@ -11,11 +13,11 @@ | |||
## Matching (自然语言推理/句子匹配) | |||
- still in progress | |||
- [Matching 任务复现](matching) | |||
## Sequence Labeling (序列标注) | |||
- still in progress | |||
- [NER](seqence_labelling/ner) | |||
## Coreference resolution (指代消解) | |||
@@ -2,7 +2,8 @@ import torch | |||
import json | |||
import os | |||
from fastNLP import Vocabulary | |||
from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader | |||
from fastNLP.io.dataset_loader import ConllLoader | |||
from fastNLP.io.data_loader import SSTLoader, SNLILoader | |||
from fastNLP.core import Const as C | |||
import numpy as np | |||
@@ -50,13 +51,15 @@ def load_sst(path, files): | |||
for sub in [True, False, False]] | |||
ds_list = [loader.load(os.path.join(path, fn)) | |||
for fn, loader in zip(files, loaders)] | |||
word_v = Vocabulary(min_freq=2) | |||
word_v = Vocabulary(min_freq=0) | |||
tag_v = Vocabulary(unknown=None, padding=None) | |||
for ds in ds_list: | |||
ds.apply(lambda x: [w.lower() | |||
for w in x['words']], new_field_name='words') | |||
ds_list[0].drop(lambda x: len(x['words']) < 3) | |||
#ds_list[0].drop(lambda x: len(x['words']) < 3) | |||
update_v(word_v, ds_list[0], 'words') | |||
update_v(word_v, ds_list[1], 'words') | |||
update_v(word_v, ds_list[2], 'words') | |||
ds_list[0].apply(lambda x: tag_v.add_word( | |||
x['target']), new_field_name=None) | |||
@@ -151,7 +154,10 @@ class EmbedLoader: | |||
# some words from vocab are missing in pre-trained embedding | |||
# we normally sample each dimension | |||
vocab_embed = embedding_matrix[np.where(hit_flags)] | |||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||
#sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||
# size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||
sampled_vectors = np.random.uniform(-0.01, 0.01, | |||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | |||
return embedding_matrix |
@@ -1,5 +1,5 @@ | |||
#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 & | |||
#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 & | |||
#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log & | |||
python -u train.py --task cls --ds sst --mode train --gpu 0 --lr 1e-4 --w_decay 5e-5 --lr_decay 1.0 --drop 0.4 --ep 20 --bsz 64 > sst_cls.log & | |||
#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log & | |||
python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & | |||
#python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & |
@@ -1,4 +1,6 @@ | |||
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args | |||
seed = set_rng_seeds(15360) | |||
print('RNG SEED {}'.format(seed)) | |||
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN | |||
import torch.nn as nn | |||
import torch | |||
@@ -7,8 +9,8 @@ import fastNLP as FN | |||
from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls | |||
from fastNLP.core.const import Const as C | |||
import sys | |||
sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') | |||
#sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') | |||
pre_dir = '/home/ec2-user/fast_data/' | |||
g_model_select = { | |||
'pos': STSeqLabel, | |||
@@ -17,8 +19,8 @@ g_model_select = { | |||
'nli': STNLICls, | |||
} | |||
g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt', | |||
'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'} | |||
g_emb_file_path = {'en': pre_dir + 'glove.840B.300d.txt', | |||
'zh': pre_dir + 'cc.zh.300.vec'} | |||
g_args = None | |||
g_model_cfg = None | |||
@@ -53,7 +55,7 @@ def get_conll2012_ner(): | |||
def get_sst(): | |||
path = '/remote-home/yfshao/workdir/datasets/SST' | |||
path = pre_dir + 'sst' | |||
files = ['train.txt', 'dev.txt', 'test.txt'] | |||
return load_sst(path, files) | |||
@@ -94,6 +96,7 @@ class MyCallback(FN.core.callback.Callback): | |||
nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0) | |||
def on_step_end(self): | |||
return | |||
warm_steps = 6000 | |||
# learning rate warm-up & decay | |||
if self.step <= warm_steps: | |||
@@ -108,12 +111,11 @@ class MyCallback(FN.core.callback.Callback): | |||
def train(): | |||
seed = set_rng_seeds(1234) | |||
print('RNG SEED {}'.format(seed)) | |||
print('loading data') | |||
ds_list, word_v, tag_v = g_datasets['{}-{}'.format( | |||
g_args.ds, g_args.task)]() | |||
print(ds_list[0][:2]) | |||
print(len(ds_list[0]), len(ds_list[1]), len(ds_list[2])) | |||
embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en') | |||
g_model_cfg['num_cls'] = len(tag_v) | |||
print(g_model_cfg) | |||
@@ -123,11 +125,14 @@ def train(): | |||
def init_model(model): | |||
for p in model.parameters(): | |||
if p.size(0) != len(word_v): | |||
nn.init.normal_(p, 0.0, 0.05) | |||
if len(p.size())<2: | |||
nn.init.constant_(p, 0.0) | |||
else: | |||
nn.init.normal_(p, 0.0, 0.05) | |||
init_model(model) | |||
train_data = ds_list[0] | |||
dev_data = ds_list[2] | |||
test_data = ds_list[1] | |||
dev_data = ds_list[1] | |||
test_data = ds_list[2] | |||
print(tag_v.word2idx) | |||
if g_args.task in ['pos', 'ner']: | |||
@@ -145,14 +150,26 @@ def train(): | |||
} | |||
metric_key, metric = metrics[g_args.task] | |||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
ex_param = [x for x in model.parameters( | |||
) if x.requires_grad and x.size(0) != len(word_v)] | |||
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | |||
trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss, | |||
batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric, | |||
metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False, | |||
device=device, callbacks=[MyCallback()]) | |||
params = [(x,y) for x,y in list(model.named_parameters()) if y.requires_grad and y.size(0) != len(word_v)] | |||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] | |||
print([n for n,p in params]) | |||
optim_cfg = [ | |||
#{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||
{'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 1.0*g_args.w_decay}, | |||
{'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 0.0*g_args.w_decay} | |||
] | |||
print(model) | |||
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||
loss=loss, metrics=metric, metric_key=metric_key, | |||
optimizer=torch.optim.Adam(optim_cfg), | |||
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=100, validate_every=1000, | |||
device=device, | |||
use_tqdm=False, prefetch=False, | |||
save_path=g_args.log, | |||
sampler=FN.BucketSampler(100, g_args.bsz, C.INPUT_LEN), | |||
callbacks=[MyCallback()]) | |||
trainer.train() | |||
tester = FN.Tester(data=test_data, model=model, metrics=metric, | |||
@@ -195,12 +212,12 @@ def main(): | |||
'init_embed': (None, 300), | |||
'num_cls': None, | |||
'hidden_size': g_args.hidden, | |||
'num_layers': 4, | |||
'num_layers': 2, | |||
'num_head': g_args.nhead, | |||
'head_dim': g_args.hdim, | |||
'max_len': MAX_LEN, | |||
'cls_hidden_size': 600, | |||
'emb_dropout': 0.3, | |||
'cls_hidden_size': 200, | |||
'emb_dropout': g_args.drop, | |||
'dropout': g_args.drop, | |||
} | |||
run_select[g_args.mode.lower()]() | |||
@@ -0,0 +1,68 @@ | |||
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance | |||
from fastNLP.io.file_reader import _read_json | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.base_loader import DataInfo | |||
from reproduction.coreference_resolution.model.config import Config | |||
import reproduction.coreference_resolution.model.preprocess as preprocess | |||
class CRLoader(JsonLoader): | |||
def __init__(self, fields=None, dropna=False): | |||
super().__init__(fields, dropna) | |||
def _load(self, path): | |||
""" | |||
加载数据 | |||
:param path: | |||
:return: | |||
""" | |||
dataset = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
if self.fields: | |||
ins = {self.fields[k]: v for k, v in d.items()} | |||
else: | |||
ins = d | |||
dataset.append(Instance(**ins)) | |||
return dataset | |||
def process(self, paths, **kwargs): | |||
data_info = DataInfo() | |||
for name in ['train', 'test', 'dev']: | |||
data_info.datasets[name] = self.load(paths[name]) | |||
config = Config() | |||
vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences') | |||
vocab.build_vocab() | |||
word2id = vocab.word2idx | |||
char_dict = preprocess.get_char_dict(config.char_path) | |||
data_info.vocabs = vocab | |||
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} | |||
for name, ds in data_info.datasets.items(): | |||
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||
config.max_sentences, is_train=name=='train')[0], | |||
new_field_name='doc_np') | |||
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||
config.max_sentences, is_train=name=='train')[1], | |||
new_field_name='char_index') | |||
ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||
config.max_sentences, is_train=name=='train')[2], | |||
new_field_name='seq_len') | |||
ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'), | |||
new_field_name='speaker_ids_np') | |||
ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre') | |||
ds.set_ignore_type('clusters') | |||
ds.set_padder('clusters', None) | |||
ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len") | |||
ds.set_target("clusters") | |||
# train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False) | |||
# train, dev = train_dev.split(343 / (2802 + 343), shuffle=False) | |||
return data_info | |||
@@ -0,0 +1,54 @@ | |||
class Config(): | |||
def __init__(self): | |||
self.is_training = True | |||
# path | |||
self.glove = 'data/glove.840B.300d.txt.filtered' | |||
self.turian = 'data/turian.50d.txt' | |||
self.train_path = "data/train.english.jsonlines" | |||
self.dev_path = "data/dev.english.jsonlines" | |||
self.test_path = "data/test.english.jsonlines" | |||
self.char_path = "data/char_vocab.english.txt" | |||
self.cuda = "0" | |||
self.max_word = 1500 | |||
self.epoch = 200 | |||
# config | |||
# self.use_glove = True | |||
# self.use_turian = True #No | |||
self.use_elmo = False | |||
self.use_CNN = True | |||
self.model_heads = True #Yes | |||
self.use_width = True # Yes | |||
self.use_distance = True #Yes | |||
self.use_metadata = True #Yes | |||
self.mention_ratio = 0.4 | |||
self.max_sentences = 50 | |||
self.span_width = 10 | |||
self.feature_size = 20 #宽度信息emb的size | |||
self.lr = 0.001 | |||
self.lr_decay = 1e-3 | |||
self.max_antecedents = 100 # 这个参数在mention detection中没有用 | |||
self.atten_hidden_size = 150 | |||
self.mention_hidden_size = 150 | |||
self.sa_hidden_size = 150 | |||
self.char_emb_size = 8 | |||
self.filter = [3,4,5] | |||
# decay = 1e-5 | |||
def __str__(self): | |||
d = self.__dict__ | |||
out = 'config==============\n' | |||
for i in list(d): | |||
out += i+":" | |||
out += str(d[i])+"\n" | |||
out+="config==============\n" | |||
return out | |||
if __name__=="__main__": | |||
config = Config() | |||
print(config) |
@@ -0,0 +1,163 @@ | |||
from fastNLP.core.metrics import MetricBase | |||
import numpy as np | |||
from collections import Counter | |||
from sklearn.utils.linear_assignment_ import linear_assignment | |||
""" | |||
Mostly borrowed from https://github.com/clarkkev/deep-coref/blob/master/evaluation.py | |||
""" | |||
class CRMetric(MetricBase): | |||
def __init__(self): | |||
super().__init__() | |||
self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] | |||
# TODO 改名为evaluate,输入也 | |||
def evaluate(self, predicted, mention_to_predicted,clusters): | |||
for e in self.evaluators: | |||
e.update(predicted,mention_to_predicted, clusters) | |||
def get_f1(self): | |||
return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) | |||
def get_recall(self): | |||
return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) | |||
def get_precision(self): | |||
return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) | |||
# TODO 原本的getprf | |||
def get_metric(self,reset=False): | |||
res = {"pre":self.get_precision(), "rec":self.get_recall(), "f":self.get_f1()} | |||
self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] | |||
return res | |||
class Evaluator(): | |||
def __init__(self, metric, beta=1): | |||
self.p_num = 0 | |||
self.p_den = 0 | |||
self.r_num = 0 | |||
self.r_den = 0 | |||
self.metric = metric | |||
self.beta = beta | |||
def update(self, predicted,mention_to_predicted,gold): | |||
gold = gold[0].tolist() | |||
gold = [tuple(tuple(m) for m in gc) for gc in gold] | |||
mention_to_gold = {} | |||
for gc in gold: | |||
for mention in gc: | |||
mention_to_gold[mention] = gc | |||
if self.metric == ceafe: | |||
pn, pd, rn, rd = self.metric(predicted, gold) | |||
else: | |||
pn, pd = self.metric(predicted, mention_to_gold) | |||
rn, rd = self.metric(gold, mention_to_predicted) | |||
self.p_num += pn | |||
self.p_den += pd | |||
self.r_num += rn | |||
self.r_den += rd | |||
def get_f1(self): | |||
return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) | |||
def get_recall(self): | |||
return 0 if self.r_num == 0 else self.r_num / float(self.r_den) | |||
def get_precision(self): | |||
return 0 if self.p_num == 0 else self.p_num / float(self.p_den) | |||
def get_prf(self): | |||
return self.get_precision(), self.get_recall(), self.get_f1() | |||
def get_counts(self): | |||
return self.p_num, self.p_den, self.r_num, self.r_den | |||
def b_cubed(clusters, mention_to_gold): | |||
num, dem = 0, 0 | |||
for c in clusters: | |||
if len(c) == 1: | |||
continue | |||
gold_counts = Counter() | |||
correct = 0 | |||
for m in c: | |||
if m in mention_to_gold: | |||
gold_counts[tuple(mention_to_gold[m])] += 1 | |||
for c2, count in gold_counts.items(): | |||
if len(c2) != 1: | |||
correct += count * count | |||
num += correct / float(len(c)) | |||
dem += len(c) | |||
return num, dem | |||
def muc(clusters, mention_to_gold): | |||
tp, p = 0, 0 | |||
for c in clusters: | |||
p += len(c) - 1 | |||
tp += len(c) | |||
linked = set() | |||
for m in c: | |||
if m in mention_to_gold: | |||
linked.add(mention_to_gold[m]) | |||
else: | |||
tp -= 1 | |||
tp -= len(linked) | |||
return tp, p | |||
def phi4(c1, c2): | |||
return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) | |||
def ceafe(clusters, gold_clusters): | |||
clusters = [c for c in clusters if len(c) != 1] | |||
scores = np.zeros((len(gold_clusters), len(clusters))) | |||
for i in range(len(gold_clusters)): | |||
for j in range(len(clusters)): | |||
scores[i, j] = phi4(gold_clusters[i], clusters[j]) | |||
matching = linear_assignment(-scores) | |||
similarity = sum(scores[matching[:, 0], matching[:, 1]]) | |||
return similarity, len(clusters), similarity, len(gold_clusters) | |||
def lea(clusters, mention_to_gold): | |||
num, dem = 0, 0 | |||
for c in clusters: | |||
if len(c) == 1: | |||
continue | |||
common_links = 0 | |||
all_links = len(c) * (len(c) - 1) / 2.0 | |||
for i, m in enumerate(c): | |||
if m in mention_to_gold: | |||
for m2 in c[i + 1:]: | |||
if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: | |||
common_links += 1 | |||
num += len(c) * common_links / float(all_links) | |||
dem += len(c) | |||
return num, dem | |||
def f1(p_num, p_den, r_num, r_den, beta=1): | |||
p = 0 if p_den == 0 else p_num / float(p_den) | |||
r = 0 if r_den == 0 else r_num / float(r_den) | |||
return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) |
@@ -0,0 +1,576 @@ | |||
import torch | |||
import numpy as np | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from allennlp.commands.elmo import ElmoEmbedder | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||
from reproduction.coreference_resolution.model import preprocess | |||
from fastNLP.io.embed_loader import EmbedLoader | |||
import random | |||
# 设置seed | |||
torch.manual_seed(0) # cpu | |||
torch.cuda.manual_seed(0) # gpu | |||
np.random.seed(0) # numpy | |||
random.seed(0) | |||
class ffnn(nn.Module): | |||
def __init__(self, input_size, hidden_size, output_size): | |||
super(ffnn, self).__init__() | |||
self.f = nn.Sequential( | |||
# 多少层数 | |||
nn.Linear(input_size, hidden_size), | |||
nn.ReLU(inplace=True), | |||
nn.Dropout(p=0.2), | |||
nn.Linear(hidden_size, hidden_size), | |||
nn.ReLU(inplace=True), | |||
nn.Dropout(p=0.2), | |||
nn.Linear(hidden_size, output_size) | |||
) | |||
self.reset_param() | |||
def reset_param(self): | |||
for name, param in self.named_parameters(): | |||
if param.dim() > 1: | |||
nn.init.xavier_normal_(param) | |||
# param.data = torch.tensor(np.random.randn(*param.shape)).float() | |||
else: | |||
nn.init.zeros_(param) | |||
def forward(self, input): | |||
return self.f(input).squeeze() | |||
class Model(BaseModel): | |||
def __init__(self, vocab, config): | |||
word2id = vocab.word2idx | |||
super(Model, self).__init__() | |||
vocab_num = len(word2id) | |||
self.word2id = word2id | |||
self.config = config | |||
self.char_dict = preprocess.get_char_dict('data/char_vocab.english.txt') | |||
self.genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} | |||
self.device = torch.device("cuda:" + config.cuda) | |||
self.emb = nn.Embedding(vocab_num, 350) | |||
emb1 = EmbedLoader().load_with_vocab(config.glove, vocab,normalize=False) | |||
emb2 = EmbedLoader().load_with_vocab(config.turian, vocab ,normalize=False) | |||
pre_emb = np.concatenate((emb1, emb2), axis=1) | |||
pre_emb /= (np.linalg.norm(pre_emb, axis=1, keepdims=True) + 1e-12) | |||
if pre_emb is not None: | |||
self.emb.weight = nn.Parameter(torch.from_numpy(pre_emb).float()) | |||
for param in self.emb.parameters(): | |||
param.requires_grad = False | |||
self.emb_dropout = nn.Dropout(inplace=True) | |||
if config.use_elmo: | |||
self.elmo = ElmoEmbedder(options_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_options.json', | |||
weight_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5', | |||
cuda_device=int(config.cuda)) | |||
print("elmo load over.") | |||
self.elmo_args = torch.randn((3), requires_grad=True).to(self.device) | |||
self.char_emb = nn.Embedding(len(self.char_dict), config.char_emb_size) | |||
self.conv1 = nn.Conv1d(config.char_emb_size, 50, 3) | |||
self.conv2 = nn.Conv1d(config.char_emb_size, 50, 4) | |||
self.conv3 = nn.Conv1d(config.char_emb_size, 50, 5) | |||
self.feature_emb = nn.Embedding(config.span_width, config.feature_size) | |||
self.feature_emb_dropout = nn.Dropout(p=0.2, inplace=True) | |||
self.mention_distance_emb = nn.Embedding(10, config.feature_size) | |||
self.distance_drop = nn.Dropout(p=0.2, inplace=True) | |||
self.genre_emb = nn.Embedding(7, config.feature_size) | |||
self.speaker_emb = nn.Embedding(2, config.feature_size) | |||
self.bilstm = VarLSTM(input_size=350+150*config.use_CNN+config.use_elmo*1024,hidden_size=200,bidirectional=True,batch_first=True,hidden_dropout=0.2) | |||
# self.bilstm = nn.LSTM(input_size=500, hidden_size=200, bidirectional=True, batch_first=True) | |||
self.h0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device) | |||
self.c0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device) | |||
self.bilstm_drop = nn.Dropout(p=0.2, inplace=True) | |||
self.atten = ffnn(input_size=400, hidden_size=config.atten_hidden_size, output_size=1) | |||
self.mention_score = ffnn(input_size=1320, hidden_size=config.mention_hidden_size, output_size=1) | |||
self.sa = ffnn(input_size=3980+40*config.use_metadata, hidden_size=config.sa_hidden_size, output_size=1) | |||
self.mention_start_np = None | |||
self.mention_end_np = None | |||
def _reorder_lstm(self, word_emb, seq_lens): | |||
sort_ind = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True) | |||
seq_lens_re = [seq_lens[i] for i in sort_ind] | |||
emb_seq = self.reorder_sequence(word_emb, sort_ind, batch_first=True) | |||
packed_seq = nn.utils.rnn.pack_padded_sequence(emb_seq, seq_lens_re, batch_first=True) | |||
h0 = self.h0.repeat(1, len(seq_lens), 1) | |||
c0 = self.c0.repeat(1, len(seq_lens), 1) | |||
packed_out, final_states = self.bilstm(packed_seq, (h0, c0)) | |||
lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) | |||
back_map = {ind: i for i, ind in enumerate(sort_ind)} | |||
reorder_ind = [back_map[i] for i in range(len(seq_lens_re))] | |||
lstm_out = self.reorder_sequence(lstm_out, reorder_ind, batch_first=True) | |||
return lstm_out | |||
def reorder_sequence(self, sequence_emb, order, batch_first=True): | |||
""" | |||
sequence_emb: [T, B, D] if not batch_first | |||
order: list of sequence length | |||
""" | |||
batch_dim = 0 if batch_first else 1 | |||
assert len(order) == sequence_emb.size()[batch_dim] | |||
order = torch.LongTensor(order) | |||
order = order.to(sequence_emb).long() | |||
sorted_ = sequence_emb.index_select(index=order, dim=batch_dim) | |||
del order | |||
return sorted_ | |||
def flat_lstm(self, lstm_out, seq_lens): | |||
batch = lstm_out.shape[0] | |||
seq = lstm_out.shape[1] | |||
dim = lstm_out.shape[2] | |||
l = [j + i * seq for i, seq_len in enumerate(seq_lens) for j in range(seq_len)] | |||
flatted = torch.index_select(lstm_out.view(batch * seq, dim), 0, torch.LongTensor(l).to(self.device)) | |||
return flatted | |||
def potential_mention_index(self, word_index, max_sent_len): | |||
# get mention index [3,2]:the first sentence is 3 and secend 2 | |||
# [0,0,0,1,1] --> [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [3, 3], [3, 4], [4, 4]] (max =2) | |||
potential_mention = [] | |||
for i in range(len(word_index)): | |||
for j in range(i, i + max_sent_len): | |||
if (j < len(word_index) and word_index[i] == word_index[j]): | |||
potential_mention.append([i, j]) | |||
return potential_mention | |||
def get_mention_start_end(self, seq_lens): | |||
# 序列长度转换成mention | |||
# [3,2] --> [0,0,0,1,1] | |||
word_index = [0] * sum(seq_lens) | |||
sent_index = 0 | |||
index = 0 | |||
for length in seq_lens: | |||
for l in range(length): | |||
word_index[index] = sent_index | |||
index += 1 | |||
sent_index += 1 | |||
# [0,0,0,1,1]-->[[0,0],[0,1],[0,2]....] | |||
mention_id = self.potential_mention_index(word_index, self.config.span_width) | |||
mention_start = np.array(mention_id, dtype=int)[:, 0] | |||
mention_end = np.array(mention_id, dtype=int)[:, 1] | |||
return mention_start, mention_end | |||
def get_mention_emb(self, flatten_lstm, mention_start, mention_end): | |||
mention_start_tensor = torch.from_numpy(mention_start).to(self.device) | |||
mention_end_tensor = torch.from_numpy(mention_end).to(self.device) | |||
emb_start = flatten_lstm.index_select(dim=0, index=mention_start_tensor) # [mention_num,embed] | |||
emb_end = flatten_lstm.index_select(dim=0, index=mention_end_tensor) # [mention_num,embed] | |||
return emb_start, emb_end | |||
def get_mask(self, mention_start, mention_end): | |||
# big mask for attention | |||
mention_num = mention_start.shape[0] | |||
mask = np.zeros((mention_num, self.config.span_width)) # [mention_num,span_width] | |||
for i in range(mention_num): | |||
start = mention_start[i] | |||
end = mention_end[i] | |||
# 实际上是宽度 | |||
for j in range(end - start + 1): | |||
mask[i][j] = 1 | |||
mask = torch.from_numpy(mask) # [mention_num,max_mention] | |||
# 0-->-inf 1-->0 | |||
log_mask = torch.log(mask) | |||
return log_mask | |||
def get_mention_index(self, mention_start, max_mention): | |||
# TODO 后面可能要改 | |||
assert len(mention_start.shape) == 1 | |||
mention_start_tensor = torch.from_numpy(mention_start) | |||
num_mention = mention_start_tensor.shape[0] | |||
mention_index = mention_start_tensor.expand(max_mention, num_mention).transpose(0, | |||
1) # [num_mention,max_mention] | |||
assert mention_index.shape[0] == num_mention | |||
assert mention_index.shape[1] == max_mention | |||
range_add = torch.arange(0, max_mention).expand(num_mention, max_mention).long() # [num_mention,max_mention] | |||
mention_index = mention_index + range_add | |||
mention_index = torch.min(mention_index, torch.LongTensor([mention_start[-1]]).expand(num_mention, max_mention)) | |||
return mention_index.to(self.device) | |||
def sort_mention(self, mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_lens): | |||
# 排序记录,高分段在前面 | |||
mention_score, mention_ids = torch.sort(candidate_mention_score, descending=True) | |||
preserve_mention_num = int(self.config.mention_ratio * sum(seq_lens)) | |||
mention_ids = mention_ids[0:preserve_mention_num] | |||
mention_score = mention_score[0:preserve_mention_num] | |||
mention_start_tensor = torch.from_numpy(mention_start).to(self.device).index_select(dim=0, | |||
index=mention_ids) # [lamda*word_num] | |||
mention_end_tensor = torch.from_numpy(mention_end).to(self.device).index_select(dim=0, | |||
index=mention_ids) # [lamda*word_num] | |||
mention_emb = candidate_mention_emb.index_select(index=mention_ids, dim=0) # [lamda*word_num,emb] | |||
assert mention_score.shape[0] == preserve_mention_num | |||
assert mention_start_tensor.shape[0] == preserve_mention_num | |||
assert mention_end_tensor.shape[0] == preserve_mention_num | |||
assert mention_emb.shape[0] == preserve_mention_num | |||
# TODO 不交叉没做处理 | |||
# 对start进行再排序,实际位置在前面 | |||
# TODO 这里只考虑了start没有考虑end | |||
mention_start_tensor, temp_index = torch.sort(mention_start_tensor) | |||
mention_end_tensor = mention_end_tensor.index_select(dim=0, index=temp_index) | |||
mention_emb = mention_emb.index_select(dim=0, index=temp_index) | |||
mention_score = mention_score.index_select(dim=0, index=temp_index) | |||
return mention_start_tensor, mention_end_tensor, mention_score, mention_emb | |||
def get_antecedents(self, mention_starts, max_antecedents): | |||
num_mention = mention_starts.shape[0] | |||
max_antecedents = min(max_antecedents, num_mention) | |||
# mention和它是第几个mention之间的对应关系 | |||
antecedents = np.zeros((num_mention, max_antecedents), dtype=int) # [num_mention,max_an] | |||
# 记录长度 | |||
antecedents_len = [0] * num_mention | |||
for i in range(num_mention): | |||
ante_count = 0 | |||
for j in range(max(0, i - max_antecedents), i): | |||
antecedents[i, ante_count] = j | |||
ante_count += 1 | |||
# 补位操作 | |||
for j in range(ante_count, max_antecedents): | |||
antecedents[i, j] = 0 | |||
antecedents_len[i] = ante_count | |||
assert antecedents.shape[1] == max_antecedents | |||
return antecedents, antecedents_len | |||
def get_antecedents_score(self, span_represent, mention_score, antecedents, antecedents_len, mention_speakers_ids, | |||
genre): | |||
num_mention = mention_score.shape[0] | |||
max_antecedent = antecedents.shape[1] | |||
pair_emb = self.get_pair_emb(span_represent, antecedents, mention_speakers_ids, genre) # [span_num,max_ant,emb] | |||
antecedent_scores = self.sa(pair_emb) | |||
mask01 = self.sequence_mask(antecedents_len, max_antecedent) | |||
maskinf = torch.log(mask01).to(self.device) | |||
assert maskinf.shape[1] <= max_antecedent | |||
assert antecedent_scores.shape[0] == num_mention | |||
antecedent_scores = antecedent_scores + maskinf | |||
antecedents = torch.from_numpy(antecedents).to(self.device) | |||
mention_scoreij = mention_score.unsqueeze(1) + torch.gather( | |||
mention_score.unsqueeze(0).expand(num_mention, num_mention), dim=1, index=antecedents) | |||
antecedent_scores += mention_scoreij | |||
antecedent_scores = torch.cat([torch.zeros([mention_score.shape[0], 1]).to(self.device), antecedent_scores], | |||
1) # [num_mentions, max_ant + 1] | |||
return antecedent_scores | |||
############################## | |||
def distance_bin(self, mention_distance): | |||
bins = torch.zeros(mention_distance.size()).byte().to(self.device) | |||
rg = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 7], [8, 15], [16, 31], [32, 63], [64, 300]] | |||
for t, k in enumerate(rg): | |||
i, j = k[0], k[1] | |||
b = torch.LongTensor([i]).unsqueeze(-1).expand(mention_distance.size()).to(self.device) | |||
m1 = torch.ge(mention_distance, b) | |||
e = torch.LongTensor([j]).unsqueeze(-1).expand(mention_distance.size()).to(self.device) | |||
m2 = torch.le(mention_distance, e) | |||
bins = bins + (t + 1) * (m1 & m2) | |||
return bins.long() | |||
def get_distance_emb(self, antecedents_tensor): | |||
num_mention = antecedents_tensor.shape[0] | |||
max_ant = antecedents_tensor.shape[1] | |||
assert max_ant <= self.config.max_antecedents | |||
source = torch.arange(0, num_mention).expand(max_ant, num_mention).transpose(0,1).to(self.device) # [num_mention,max_ant] | |||
mention_distance = source - antecedents_tensor | |||
mention_distance_bin = self.distance_bin(mention_distance) | |||
distance_emb = self.mention_distance_emb(mention_distance_bin) | |||
distance_emb = self.distance_drop(distance_emb) | |||
return distance_emb | |||
def get_pair_emb(self, span_emb, antecedents, mention_speakers_ids, genre): | |||
emb_dim = span_emb.shape[1] | |||
num_span = span_emb.shape[0] | |||
max_ant = antecedents.shape[1] | |||
assert span_emb.shape[0] == antecedents.shape[0] | |||
antecedents = torch.from_numpy(antecedents).to(self.device) | |||
# [num_span,max_ant,emb] | |||
antecedent_emb = torch.gather(span_emb.unsqueeze(0).expand(num_span, num_span, emb_dim), dim=1, | |||
index=antecedents.unsqueeze(2).expand(num_span, max_ant, emb_dim)) | |||
# [num_span,max_ant,emb] | |||
target_emb_tiled = span_emb.expand((max_ant, num_span, emb_dim)) | |||
target_emb_tiled = target_emb_tiled.transpose(0, 1) | |||
similarity_emb = antecedent_emb * target_emb_tiled | |||
pair_emb_list = [target_emb_tiled, antecedent_emb, similarity_emb] | |||
# get speakers and genre | |||
if self.config.use_metadata: | |||
antecedent_speaker_ids = mention_speakers_ids.unsqueeze(0).expand(num_span, num_span).gather(dim=1, | |||
index=antecedents) | |||
same_speaker = torch.eq(mention_speakers_ids.unsqueeze(1).expand(num_span, max_ant), | |||
antecedent_speaker_ids) # [num_mention,max_ant] | |||
speaker_embedding = self.speaker_emb(same_speaker.long().to(self.device)) # [mention_num.max_ant,emb] | |||
genre_embedding = self.genre_emb( | |||
torch.LongTensor([genre]).expand(num_span, max_ant).to(self.device)) # [mention_num,max_ant,emb] | |||
pair_emb_list.append(speaker_embedding) | |||
pair_emb_list.append(genre_embedding) | |||
# get distance emb | |||
if self.config.use_distance: | |||
distance_emb = self.get_distance_emb(antecedents) | |||
pair_emb_list.append(distance_emb) | |||
pair_emb = torch.cat(pair_emb_list, 2) | |||
return pair_emb | |||
def sequence_mask(self, len_list, max_len): | |||
x = np.zeros((len(len_list), max_len)) | |||
for i in range(len(len_list)): | |||
l = len_list[i] | |||
for j in range(l): | |||
x[i][j] = 1 | |||
return torch.from_numpy(x).float() | |||
def logsumexp(self, value, dim=None, keepdim=False): | |||
"""Numerically stable implementation of the operation | |||
value.exp().sum(dim, keepdim).log() | |||
""" | |||
# TODO: torch.max(value, dim=None) threw an error at time of writing | |||
if dim is not None: | |||
m, _ = torch.max(value, dim=dim, keepdim=True) | |||
value0 = value - m | |||
if keepdim is False: | |||
m = m.squeeze(dim) | |||
return m + torch.log(torch.sum(torch.exp(value0), | |||
dim=dim, keepdim=keepdim)) | |||
else: | |||
m = torch.max(value) | |||
sum_exp = torch.sum(torch.exp(value - m)) | |||
return m + torch.log(sum_exp) | |||
def softmax_loss(self, antecedent_scores, antecedent_labels): | |||
antecedent_labels = torch.from_numpy(antecedent_labels * 1).to(self.device) | |||
gold_scores = antecedent_scores + torch.log(antecedent_labels.float()) # [num_mentions, max_ant + 1] | |||
marginalized_gold_scores = self.logsumexp(gold_scores, 1) # [num_mentions] | |||
log_norm = self.logsumexp(antecedent_scores, 1) # [num_mentions] | |||
return torch.sum(log_norm - marginalized_gold_scores) # [num_mentions]reduce_logsumexp | |||
def get_predicted_antecedents(self, antecedents, antecedent_scores): | |||
predicted_antecedents = [] | |||
for i, index in enumerate(np.argmax(antecedent_scores.detach(), axis=1) - 1): | |||
if index < 0: | |||
predicted_antecedents.append(-1) | |||
else: | |||
predicted_antecedents.append(antecedents[i, index]) | |||
return predicted_antecedents | |||
def get_predicted_clusters(self, mention_starts, mention_ends, predicted_antecedents): | |||
mention_to_predicted = {} | |||
predicted_clusters = [] | |||
for i, predicted_index in enumerate(predicted_antecedents): | |||
if predicted_index < 0: | |||
continue | |||
assert i > predicted_index | |||
predicted_antecedent = (int(mention_starts[predicted_index]), int(mention_ends[predicted_index])) | |||
if predicted_antecedent in mention_to_predicted: | |||
predicted_cluster = mention_to_predicted[predicted_antecedent] | |||
else: | |||
predicted_cluster = len(predicted_clusters) | |||
predicted_clusters.append([predicted_antecedent]) | |||
mention_to_predicted[predicted_antecedent] = predicted_cluster | |||
mention = (int(mention_starts[i]), int(mention_ends[i])) | |||
predicted_clusters[predicted_cluster].append(mention) | |||
mention_to_predicted[mention] = predicted_cluster | |||
predicted_clusters = [tuple(pc) for pc in predicted_clusters] | |||
mention_to_predicted = {m: predicted_clusters[i] for m, i in mention_to_predicted.items()} | |||
return predicted_clusters, mention_to_predicted | |||
def evaluate_coref(self, mention_starts, mention_ends, predicted_antecedents, gold_clusters, evaluator): | |||
gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters] | |||
mention_to_gold = {} | |||
for gc in gold_clusters: | |||
for mention in gc: | |||
mention_to_gold[mention] = gc | |||
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(mention_starts, mention_ends, | |||
predicted_antecedents) | |||
evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) | |||
return predicted_clusters | |||
def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): | |||
""" | |||
实际输入都是tensor | |||
:param sentences: 句子,被fastNLP转化成了numpy, | |||
:param doc_np: 被fastNLP转化成了Tensor | |||
:param speaker_ids_np: 被fastNLP转化成了Tensor | |||
:param genre: 被fastNLP转化成了Tensor | |||
:param char_index: 被fastNLP转化成了Tensor | |||
:param seq_len: 被fastNLP转化成了Tensor | |||
:return: | |||
""" | |||
# change for fastNLP | |||
sentences = sentences[0].tolist() | |||
doc_tensor = doc_np[0] | |||
speakers_tensor = speaker_ids_np[0] | |||
genre = genre[0].item() | |||
char_index = char_index[0] | |||
seq_len = seq_len[0].cpu().numpy() | |||
# 类型 | |||
# doc_tensor = torch.from_numpy(doc_np).to(self.device) | |||
# speakers_tensor = torch.from_numpy(speaker_ids_np).to(self.device) | |||
mention_emb_list = [] | |||
word_emb = self.emb(doc_tensor) | |||
word_emb_list = [word_emb] | |||
if self.config.use_CNN: | |||
# [batch, length, char_length, char_dim] | |||
char = self.char_emb(char_index) | |||
char_size = char.size() | |||
# first transform to [batch *length, char_length, char_dim] | |||
# then transpose to [batch * length, char_dim, char_length] | |||
char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2) | |||
# put into cnn [batch*length, char_filters, char_length] | |||
# then put into maxpooling [batch * length, char_filters] | |||
char_over_cnn, _ = self.conv1(char).max(dim=2) | |||
# reshape to [batch, length, char_filters] | |||
char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) | |||
word_emb_list.append(char_over_cnn) | |||
char_over_cnn, _ = self.conv2(char).max(dim=2) | |||
char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) | |||
word_emb_list.append(char_over_cnn) | |||
char_over_cnn, _ = self.conv3(char).max(dim=2) | |||
char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) | |||
word_emb_list.append(char_over_cnn) | |||
# word_emb = torch.cat(word_emb_list, dim=2) | |||
# use elmo or not | |||
if self.config.use_elmo: | |||
# 如果确实被截断了 | |||
if doc_tensor.shape[0] == 50 and len(sentences) > 50: | |||
sentences = sentences[0:50] | |||
elmo_embedding, elmo_mask = self.elmo.batch_to_embeddings(sentences) | |||
elmo_embedding = elmo_embedding.to( | |||
self.device) # [sentence_num,max_sent_len,3,1024]--[sentence_num,max_sent,1024] | |||
elmo_embedding = elmo_embedding[:, 0, :, :] * self.elmo_args[0] + elmo_embedding[:, 1, :, :] * \ | |||
self.elmo_args[1] + elmo_embedding[:, 2, :, :] * self.elmo_args[2] | |||
word_emb_list.append(elmo_embedding) | |||
# print(word_emb_list[0].shape) | |||
# print(word_emb_list[1].shape) | |||
# print(word_emb_list[2].shape) | |||
# print(word_emb_list[3].shape) | |||
# print(word_emb_list[4].shape) | |||
word_emb = torch.cat(word_emb_list, dim=2) | |||
word_emb = self.emb_dropout(word_emb) | |||
# word_emb_elmo = self.emb_dropout(word_emb_elmo) | |||
lstm_out = self._reorder_lstm(word_emb, seq_len) | |||
flatten_lstm = self.flat_lstm(lstm_out, seq_len) # [word_num,emb] | |||
flatten_lstm = self.bilstm_drop(flatten_lstm) | |||
# TODO 没有按照论文写 | |||
flatten_word_emb = self.flat_lstm(word_emb, seq_len) # [word_num,emb] | |||
mention_start, mention_end = self.get_mention_start_end(seq_len) # [mention_num] | |||
self.mention_start_np = mention_start # [mention_num] np | |||
self.mention_end_np = mention_end | |||
mention_num = mention_start.shape[0] | |||
emb_start, emb_end = self.get_mention_emb(flatten_lstm, mention_start, mention_end) # [mention_num,emb] | |||
# list | |||
mention_emb_list.append(emb_start) | |||
mention_emb_list.append(emb_end) | |||
if self.config.use_width: | |||
mention_width_index = mention_end - mention_start | |||
mention_width_tensor = torch.from_numpy(mention_width_index).to(self.device) # [mention_num] | |||
mention_width_emb = self.feature_emb(mention_width_tensor) | |||
mention_width_emb = self.feature_emb_dropout(mention_width_emb) | |||
mention_emb_list.append(mention_width_emb) | |||
if self.config.model_heads: | |||
mention_index = self.get_mention_index(mention_start, self.config.span_width) # [mention_num,max_mention] | |||
log_mask_tensor = self.get_mask(mention_start, mention_end).float().to( | |||
self.device) # [mention_num,max_mention] | |||
alpha = self.atten(flatten_lstm).to(self.device) # [word_num] | |||
# 得到attention | |||
mention_head_score = torch.gather(alpha.expand(mention_num, -1), 1, | |||
mention_index).float().to(self.device) # [mention_num,max_mention] | |||
mention_attention = F.softmax(mention_head_score + log_mask_tensor, dim=1) # [mention_num,max_mention] | |||
# TODO flatte lstm | |||
word_num = flatten_lstm.shape[0] | |||
lstm_emb = flatten_lstm.shape[1] | |||
emb_num = flatten_word_emb.shape[1] | |||
# [num_mentions, max_mention_width, emb] | |||
mention_text_emb = torch.gather( | |||
flatten_word_emb.unsqueeze(1).expand(word_num, self.config.span_width, emb_num), | |||
0, mention_index.unsqueeze(2).expand(mention_num, self.config.span_width, | |||
emb_num)) | |||
# [mention_num,emb] | |||
mention_head_emb = torch.sum( | |||
mention_attention.unsqueeze(2).expand(mention_num, self.config.span_width, emb_num) * mention_text_emb, | |||
dim=1) | |||
mention_emb_list.append(mention_head_emb) | |||
candidate_mention_emb = torch.cat(mention_emb_list, 1) # [candidate_mention_num,emb] | |||
candidate_mention_score = self.mention_score(candidate_mention_emb) # [candidate_mention_num] | |||
antecedent_scores, antecedents, mention_start_tensor, mention_end_tensor = (None, None, None, None) | |||
mention_start_tensor, mention_end_tensor, mention_score, mention_emb = \ | |||
self.sort_mention(mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_len) | |||
mention_speakers_ids = speakers_tensor.index_select(dim=0, index=mention_start_tensor) # num_mention | |||
antecedents, antecedents_len = self.get_antecedents(mention_start_tensor, self.config.max_antecedents) | |||
antecedent_scores = self.get_antecedents_score(mention_emb, mention_score, antecedents, antecedents_len, | |||
mention_speakers_ids, genre) | |||
ans = {"candidate_mention_score": candidate_mention_score, "antecedent_scores": antecedent_scores, | |||
"antecedents": antecedents, "mention_start_tensor": mention_start_tensor, | |||
"mention_end_tensor": mention_end_tensor} | |||
return ans | |||
def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): | |||
ans = self(sentences, | |||
doc_np, | |||
speaker_ids_np, | |||
genre, | |||
char_index, | |||
seq_len) | |||
predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"]) | |||
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"], | |||
ans["mention_end_tensor"], | |||
predicted_antecedents) | |||
return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted} | |||
if __name__ == '__main__': | |||
pass |
@@ -0,0 +1,225 @@ | |||
import json | |||
import numpy as np | |||
from . import util | |||
import collections | |||
def load(path): | |||
""" | |||
load the file from jsonline | |||
:param path: | |||
:return: examples with many example(dict): {"clusters":[[[mention],[mention]],[another cluster]], | |||
"doc_key":"str","speakers":[[,,,],[]...],"sentence":[[][]]} | |||
""" | |||
with open(path) as f: | |||
train_examples = [json.loads(jsonline) for jsonline in f.readlines()] | |||
return train_examples | |||
def get_vocab(): | |||
""" | |||
从所有的句子中得到最终的字典,被main调用,不止是train,还有dev和test | |||
:param examples: | |||
:return: word2id & id2word | |||
""" | |||
word2id = {'PAD':0,'UNK':1} | |||
id2word = {0:'PAD',1:'UNK'} | |||
index = 2 | |||
data = [load("../data/train.english.jsonlines"),load("../data/dev.english.jsonlines"),load("../data/test.english.jsonlines")] | |||
for examples in data: | |||
for example in examples: | |||
for sent in example["sentences"]: | |||
for word in sent: | |||
if(word not in word2id): | |||
word2id[word]=index | |||
id2word[index] = word | |||
index += 1 | |||
return word2id,id2word | |||
def normalize(v): | |||
norm = np.linalg.norm(v) | |||
if norm > 0: | |||
return v / norm | |||
else: | |||
return v | |||
# 加载glove得到embedding | |||
def get_emb(id2word,embedding_size): | |||
glove_oov = 0 | |||
turian_oov = 0 | |||
both = 0 | |||
glove_emb_path = "../data/glove.840B.300d.txt.filtered" | |||
turian_emb_path = "../data/turian.50d.txt" | |||
word_num = len(id2word) | |||
emb = np.zeros((word_num,embedding_size)) | |||
glove_emb_dict = util.load_embedding_dict(glove_emb_path,300,"txt") | |||
turian_emb_dict = util.load_embedding_dict(turian_emb_path,50,"txt") | |||
for i in range(word_num): | |||
if id2word[i] in glove_emb_dict: | |||
word_embedding = glove_emb_dict.get(id2word[i]) | |||
emb[i][0:300] = np.array(word_embedding) | |||
else: | |||
# print(id2word[i]) | |||
glove_oov += 1 | |||
if id2word[i] in turian_emb_dict: | |||
word_embedding = turian_emb_dict.get(id2word[i]) | |||
emb[i][300:350] = np.array(word_embedding) | |||
else: | |||
# print(id2word[i]) | |||
turian_oov += 1 | |||
if id2word[i] not in glove_emb_dict and id2word[i] not in turian_emb_dict: | |||
both += 1 | |||
emb[i] = normalize(emb[i]) | |||
print("embedding num:"+str(word_num)) | |||
print("glove num:"+str(glove_oov)) | |||
print("glove oov rate:"+str(glove_oov/word_num)) | |||
print("turian num:"+str(turian_oov)) | |||
print("turian oov rate:"+str(turian_oov/word_num)) | |||
print("both num:"+str(both)) | |||
return emb | |||
def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train): | |||
max_len = 0 | |||
max_word_length = 0 | |||
docvex = [] | |||
length = [] | |||
if is_train: | |||
sent_num = min(max_sentences,len(doc)) | |||
else: | |||
sent_num = len(doc) | |||
for i in range(sent_num): | |||
sent = doc[i] | |||
length.append(len(sent)) | |||
if (len(sent) > max_len): | |||
max_len = len(sent) | |||
sent_vec =[] | |||
for j,word in enumerate(sent): | |||
if len(word)>max_word_length: | |||
max_word_length = len(word) | |||
if word in word2id: | |||
sent_vec.append(word2id[word]) | |||
else: | |||
sent_vec.append(word2id["UNK"]) | |||
docvex.append(sent_vec) | |||
char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int) | |||
for i in range(sent_num): | |||
sent = doc[i] | |||
for j,word in enumerate(sent): | |||
char_index[i, j, :len(word)] = [char_dict[c] for c in word] | |||
return docvex,char_index,length,max_len | |||
# TODO 修改了接口,确认所有该修改的地方都修改好 | |||
def doc2numpy(doc,word2id,chardict,max_filter,max_sentences,is_train): | |||
docvec, char_index, length, max_len = _doc2vec(doc,word2id,chardict,max_filter,max_sentences,is_train) | |||
assert max(length) == max_len | |||
assert char_index.shape[0]==len(length) | |||
assert char_index.shape[1]==max_len | |||
doc_np = np.zeros((len(docvec), max_len), int) | |||
for i in range(len(docvec)): | |||
for j in range(len(docvec[i])): | |||
doc_np[i][j] = docvec[i][j] | |||
return doc_np,char_index,length | |||
# TODO 没有测试 | |||
def speaker2numpy(speakers_raw,max_sentences,is_train): | |||
if is_train and len(speakers_raw)> max_sentences: | |||
speakers_raw = speakers_raw[0:max_sentences] | |||
speakers = flatten(speakers_raw) | |||
speaker_dict = {s: i for i, s in enumerate(set(speakers))} | |||
speaker_ids = np.array([speaker_dict[s] for s in speakers]) | |||
return speaker_ids | |||
def flat_cluster(clusters): | |||
flatted = [] | |||
for cluster in clusters: | |||
for item in cluster: | |||
flatted.append(item) | |||
return flatted | |||
def get_right_mention(clusters,mention_start_np,mention_end_np): | |||
flatted = flat_cluster(clusters) | |||
cluster_num = len(flatted) | |||
mention_num = mention_start_np.shape[0] | |||
right_mention = np.zeros(mention_num,dtype=int) | |||
for i in range(mention_num): | |||
if [mention_start_np[i],mention_end_np[i]] in flatted: | |||
right_mention[i]=1 | |||
return right_mention,cluster_num | |||
def handle_cluster(clusters): | |||
gold_mentions = sorted(tuple(m) for m in flatten(clusters)) | |||
gold_mention_map = {m: i for i, m in enumerate(gold_mentions)} | |||
cluster_ids = np.zeros(len(gold_mentions), dtype=int) | |||
for cluster_id, cluster in enumerate(clusters): | |||
for mention in cluster: | |||
cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id | |||
gold_starts, gold_ends = tensorize_mentions(gold_mentions) | |||
return cluster_ids, gold_starts, gold_ends | |||
# 展平 | |||
def flatten(l): | |||
return [item for sublist in l for item in sublist] | |||
# 把mention分成start end | |||
def tensorize_mentions(mentions): | |||
if len(mentions) > 0: | |||
starts, ends = zip(*mentions) | |||
else: | |||
starts, ends = [], [] | |||
return np.array(starts), np.array(ends) | |||
def get_char_dict(path): | |||
vocab = ["<UNK>"] | |||
with open(path) as f: | |||
vocab.extend(c.strip() for c in f.readlines()) | |||
char_dict = collections.defaultdict(int) | |||
char_dict.update({c: i for i, c in enumerate(vocab)}) | |||
return char_dict | |||
def get_labels(clusters,mention_starts,mention_ends,max_antecedents): | |||
cluster_ids, gold_starts, gold_ends = handle_cluster(clusters) | |||
num_mention = mention_starts.shape[0] | |||
num_gold = gold_starts.shape[0] | |||
max_antecedents = min(max_antecedents, num_mention) | |||
mention_indices = {} | |||
for i in range(num_mention): | |||
mention_indices[(mention_starts[i].detach().item(), mention_ends[i].detach().item())] = i | |||
# 用来记录哪些mention是对的,-1表示错误,正数代表这个mention实际上对应哪个gold cluster的id | |||
mention_cluster_ids = [-1] * num_mention | |||
# test | |||
right_mention_count = 0 | |||
for i in range(num_gold): | |||
right_mention = mention_indices.get((gold_starts[i], gold_ends[i])) | |||
if (right_mention != None): | |||
right_mention_count += 1 | |||
mention_cluster_ids[right_mention] = cluster_ids[i] | |||
# i j 是否属于同一个cluster | |||
labels = np.zeros((num_mention, max_antecedents + 1), dtype=bool) # [num_mention,max_an+1] | |||
for i in range(num_mention): | |||
ante_count = 0 | |||
null_label = True | |||
for j in range(max(0, i - max_antecedents), i): | |||
if (mention_cluster_ids[i] >= 0 and mention_cluster_ids[i] == mention_cluster_ids[j]): | |||
labels[i, ante_count + 1] = True | |||
null_label = False | |||
else: | |||
labels[i, ante_count + 1] = False | |||
ante_count += 1 | |||
for j in range(ante_count, max_antecedents): | |||
labels[i, j + 1] = False | |||
labels[i, 0] = null_label | |||
return labels | |||
# test=========================== | |||
if __name__=="__main__": | |||
word2id,id2word = get_vocab() | |||
get_emb(id2word,350) | |||
@@ -0,0 +1,32 @@ | |||
from fastNLP.core.losses import LossBase | |||
from reproduction.coreference_resolution.model.preprocess import get_labels | |||
from reproduction.coreference_resolution.model.config import Config | |||
import torch | |||
class SoftmaxLoss(LossBase): | |||
""" | |||
交叉熵loss | |||
允许多标签分类 | |||
""" | |||
def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None): | |||
""" | |||
:param pred: | |||
:param target: | |||
""" | |||
super().__init__() | |||
self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters, | |||
mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor) | |||
def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor): | |||
antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor, | |||
Config().max_antecedents) | |||
antecedent_labels = torch.from_numpy(antecedent_labels*1).to(torch.device("cuda:" + Config().cuda)) | |||
gold_scores = antecedent_scores + torch.log(antecedent_labels.float()).to(torch.device("cuda:" + Config().cuda)) # [num_mentions, max_ant + 1] | |||
marginalized_gold_scores = gold_scores.logsumexp(dim=1) # [num_mentions] | |||
log_norm = antecedent_scores.logsumexp(dim=1) # [num_mentions] | |||
return torch.sum(log_norm - marginalized_gold_scores) |
@@ -0,0 +1,101 @@ | |||
import os | |||
import errno | |||
import collections | |||
import torch | |||
import numpy as np | |||
import pyhocon | |||
# flatten the list | |||
def flatten(l): | |||
return [item for sublist in l for item in sublist] | |||
def get_config(filename): | |||
return pyhocon.ConfigFactory.parse_file(filename) | |||
# safe make directions | |||
def mkdirs(path): | |||
try: | |||
os.makedirs(path) | |||
except OSError as exception: | |||
if exception.errno != errno.EEXIST: | |||
raise | |||
return path | |||
def load_char_dict(char_vocab_path): | |||
vocab = ["<unk>"] | |||
with open(char_vocab_path) as f: | |||
vocab.extend(c.strip() for c in f.readlines()) | |||
char_dict = collections.defaultdict(int) | |||
char_dict.update({c: i for i, c in enumerate(vocab)}) | |||
return char_dict | |||
# 加载embedding | |||
def load_embedding_dict(embedding_path, embedding_size, embedding_format): | |||
print("Loading word embeddings from {}...".format(embedding_path)) | |||
default_embedding = np.zeros(embedding_size) | |||
embedding_dict = collections.defaultdict(lambda: default_embedding) | |||
skip_first = embedding_format == "vec" | |||
with open(embedding_path) as f: | |||
for i, line in enumerate(f.readlines()): | |||
if skip_first and i == 0: | |||
continue | |||
splits = line.split() | |||
assert len(splits) == embedding_size + 1 | |||
word = splits[0] | |||
embedding = np.array([float(s) for s in splits[1:]]) | |||
embedding_dict[word] = embedding | |||
print("Done loading word embeddings.") | |||
return embedding_dict | |||
# safe devide | |||
def maybe_divide(x, y): | |||
return 0 if y == 0 else x / float(y) | |||
def shape(x, dim): | |||
return x.get_shape()[dim].value or torch.shape(x)[dim] | |||
def normalize(v): | |||
norm = np.linalg.norm(v) | |||
if norm > 0: | |||
return v / norm | |||
else: | |||
return v | |||
class RetrievalEvaluator(object): | |||
def __init__(self): | |||
self._num_correct = 0 | |||
self._num_gold = 0 | |||
self._num_predicted = 0 | |||
def update(self, gold_set, predicted_set): | |||
self._num_correct += len(gold_set & predicted_set) | |||
self._num_gold += len(gold_set) | |||
self._num_predicted += len(predicted_set) | |||
def recall(self): | |||
return maybe_divide(self._num_correct, self._num_gold) | |||
def precision(self): | |||
return maybe_divide(self._num_correct, self._num_predicted) | |||
def metrics(self): | |||
recall = self.recall() | |||
precision = self.precision() | |||
f1 = maybe_divide(2 * recall * precision, precision + recall) | |||
return recall, precision, f1 | |||
if __name__=="__main__": | |||
print(load_char_dict("../data/char_vocab.english.txt")) | |||
embedding_dict = load_embedding_dict("../data/glove.840B.300d.txt.filtered",300,"txt") | |||
print("hello") |
@@ -0,0 +1,49 @@ | |||
# 共指消解复现 | |||
## 介绍 | |||
Coreference resolution是查找文本中指向同一现实实体的所有表达式的任务。 | |||
对于涉及自然语言理解的许多更高级别的NLP任务来说, | |||
这是一个重要的步骤,例如文档摘要,问题回答和信息提取。 | |||
代码的实现主要基于[ End-to-End Coreference Resolution (Lee et al, 2017)](https://arxiv.org/pdf/1707.07045). | |||
## 数据获取与预处理 | |||
论文在[OntoNote5.0](https://allennlp.org/models)数据集上取得了当时的sota结果。 | |||
由于版权问题,本文无法提供数据集的下载,请自行下载。 | |||
原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。 | |||
代码实现采用了论文作者Lee的预处理方法,具体细节参加[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 | |||
处理之后的数据集为json格式,例子: | |||
``` | |||
{ | |||
"clusters": [], | |||
"doc_key": "nw", | |||
"sentences": [["This", "is", "the", "first", "sentence", "."], ["This", "is", "the", "second", "."]], | |||
"speakers": [["spk1", "spk1", "spk1", "spk1", "spk1", "spk1"], ["spk2", "spk2", "spk2", "spk2", "spk2"]] | |||
} | |||
``` | |||
### embedding 数据集下载 | |||
[turian emdedding](https://lil.cs.washington.edu/coref/turian.50d.txt) | |||
[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip) | |||
## 运行 | |||
```python | |||
# 训练代码 | |||
CUDA_VISIBLE_DEVICES=0 python train.py | |||
# 测试代码 | |||
CUDA_VISIBLE_DEVICES=0 python valid.py | |||
``` | |||
## 结果 | |||
原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。 | |||
其中allenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 | |||
在与allenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 | |||
## 问题 | |||
如果您有什么问题或者反馈,请提issue或者邮件联系我: | |||
yexu_i@qq.com |
@@ -0,0 +1,14 @@ | |||
import unittest | |||
from ..data_load.cr_loader import CRLoader | |||
class Test_CRLoader(unittest.TestCase): | |||
def test_cr_loader(self): | |||
train_path = 'data/train.english.jsonlines.mini' | |||
dev_path = 'data/dev.english.jsonlines.minid' | |||
test_path = 'data/test.english.jsonlines' | |||
cr = CRLoader() | |||
data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path}) | |||
print(data_info.datasets['train'][0]) | |||
print(data_info.datasets['dev'][0]) | |||
print(data_info.datasets['test'][0]) |
@@ -0,0 +1,69 @@ | |||
import sys | |||
sys.path.append('../..') | |||
import torch | |||
from torch.optim import Adam | |||
from fastNLP.core.callback import Callback, GradientClipCallback | |||
from fastNLP.core.trainer import Trainer | |||
from reproduction.coreference_resolution.data_load.cr_loader import CRLoader | |||
from reproduction.coreference_resolution.model.config import Config | |||
from reproduction.coreference_resolution.model.model_re import Model | |||
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss | |||
from reproduction.coreference_resolution.model.metric import CRMetric | |||
from fastNLP import SequentialSampler | |||
from fastNLP import cache_results | |||
# torch.backends.cudnn.benchmark = False | |||
# torch.backends.cudnn.deterministic = True | |||
class LRCallback(Callback): | |||
def __init__(self, parameters, decay_rate=1e-3): | |||
super().__init__() | |||
self.paras = parameters | |||
self.decay_rate = decay_rate | |||
def on_step_end(self): | |||
if self.step % 100 == 0: | |||
for para in self.paras: | |||
para['lr'] = para['lr'] * (1 - self.decay_rate) | |||
if __name__ == "__main__": | |||
config = Config() | |||
print(config) | |||
@cache_results('cache.pkl') | |||
def cache(): | |||
cr_train_dev_test = CRLoader() | |||
data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path, | |||
'test': config.test_path}) | |||
return data_info | |||
data_info = cache() | |||
print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])), | |||
"\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"]))) | |||
# print(data_info) | |||
model = Model(data_info.vocabs, config) | |||
print(model) | |||
loss = SoftmaxLoss() | |||
metric = CRMetric() | |||
optim = Adam(model.parameters(), lr=config.lr) | |||
lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay) | |||
trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"], | |||
loss=loss, metrics=metric, check_code_level=-1,sampler=None, | |||
batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, | |||
optimizer=optim, | |||
save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save', | |||
callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) | |||
print() | |||
trainer.train() |
@@ -0,0 +1,24 @@ | |||
import torch | |||
from reproduction.coreference_resolution.model.config import Config | |||
from reproduction.coreference_resolution.model.metric import CRMetric | |||
from reproduction.coreference_resolution.data_load.cr_loader import CRLoader | |||
from fastNLP import Tester | |||
import argparse | |||
if __name__=='__main__': | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--path') | |||
args = parser.parse_args() | |||
cr_loader = CRLoader() | |||
config = Config() | |||
data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path, | |||
'test': config.test_path}) | |||
metirc = CRMetric() | |||
model = torch.load(args.path) | |||
tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0") | |||
tester.test() | |||
print('test over') | |||
@@ -0,0 +1,284 @@ | |||
from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||
from fastNLP.io.dataset_loader import ConllLoader | |||
import numpy as np | |||
from itertools import chain | |||
from fastNLP import DataSet, Vocabulary | |||
from functools import partial | |||
import os | |||
from typing import Union, Dict | |||
from reproduction.utils import check_dataloader_paths | |||
class CTBxJointLoader(DataSetLoader): | |||
""" | |||
文件夹下应该具有以下的文件结构 | |||
-train.conllx | |||
-dev.conllx | |||
-test.conllx | |||
每个文件中的内容如下(空格隔开不同的句子, 共有) | |||
1 费孝通 _ NR NR _ 3 nsubjpass _ _ | |||
2 被 _ SB SB _ 3 pass _ _ | |||
3 授予 _ VV VV _ 0 root _ _ | |||
4 麦格赛赛 _ NR NR _ 5 nn _ _ | |||
5 奖 _ NN NN _ 3 dobj _ _ | |||
1 新华社 _ NR NR _ 7 dep _ _ | |||
2 马尼拉 _ NR NR _ 7 dep _ _ | |||
3 8月 _ NT NT _ 7 dep _ _ | |||
4 31日 _ NT NT _ 7 dep _ _ | |||
... | |||
""" | |||
def __init__(self): | |||
self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7]) | |||
def load(self, path:str): | |||
""" | |||
给定一个文件路径,将数据读取为DataSet格式。DataSet中包含以下的内容 | |||
words: list[str] | |||
pos_tags: list[str] | |||
heads: list[int] | |||
labels: list[str] | |||
:param path: | |||
:return: | |||
""" | |||
dataset = self._loader.load(path) | |||
dataset.heads.int() | |||
return dataset | |||
def process(self, paths): | |||
""" | |||
:param paths: | |||
:return: | |||
Dataset包含以下的field | |||
chars: | |||
bigrams: | |||
trigrams: | |||
pre_chars: | |||
pre_bigrams: | |||
pre_trigrams: | |||
seg_targets: | |||
seg_masks: | |||
seq_lens: | |||
char_labels: | |||
char_heads: | |||
gold_word_pairs: | |||
seg_targets: | |||
seg_masks: | |||
char_labels: | |||
char_heads: | |||
pun_masks: | |||
gold_label_word_pairs: | |||
""" | |||
paths = check_dataloader_paths(paths) | |||
data = DataInfo() | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
data.datasets[name] = dataset | |||
char_labels_vocab = Vocabulary(padding=None, unknown=None) | |||
def process(dataset, char_label_vocab): | |||
dataset.apply(add_word_lst, new_field_name='word_lst') | |||
dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars') | |||
dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams') | |||
dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams') | |||
dataset.apply(add_char_heads, new_field_name='char_heads') | |||
dataset.apply(add_char_labels, new_field_name='char_labels') | |||
dataset.apply(add_segs, new_field_name='seg_targets') | |||
dataset.apply(add_mask, new_field_name='seg_masks') | |||
dataset.add_seq_len('chars', new_field_name='seq_lens') | |||
dataset.apply(add_pun_masks, new_field_name='pun_masks') | |||
if len(char_label_vocab.word_count)==0: | |||
char_label_vocab.from_dataset(dataset, field_name='char_labels') | |||
char_label_vocab.index_dataset(dataset, field_name='char_labels') | |||
new_dataset = add_root(dataset) | |||
new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True) | |||
global add_label_word_pairs | |||
add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab) | |||
new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True) | |||
new_dataset.set_pad_val('char_labels', -1) | |||
new_dataset.set_pad_val('char_heads', -1) | |||
return new_dataset | |||
for name in list(paths.keys()): | |||
dataset = data.datasets[name] | |||
dataset = process(dataset, char_labels_vocab) | |||
data.datasets[name] = dataset | |||
data.vocabs['char_labels'] = char_labels_vocab | |||
char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars') | |||
bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams') | |||
trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams') | |||
for name in ['chars', 'bigrams', 'trigrams']: | |||
vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values())) | |||
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name) | |||
data.vocabs['pre_{}'.format(name)] = vocab | |||
for name, vocab in zip(['chars', 'bigrams', 'trigrams'], | |||
[char_vocab, bigram_vocab, trigram_vocab]): | |||
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name) | |||
data.vocabs[name] = vocab | |||
for name, dataset in data.datasets.items(): | |||
dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars', | |||
'pre_bigrams', 'pre_trigrams') | |||
dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels', | |||
'char_heads', | |||
'pun_masks', 'gold_label_word_pairs') | |||
return data | |||
def add_label_word_pairs(instance, label_vocab): | |||
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]] | |||
word_end_indexes = np.array(list(map(len, instance['word_lst']))) | |||
word_end_indexes = np.cumsum(word_end_indexes).tolist() | |||
word_end_indexes.insert(0, 0) | |||
word_pairs = [] | |||
labels = instance['labels'] | |||
pos_tags = instance['pos_tags'] | |||
for idx, head in enumerate(instance['heads']): | |||
if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 | |||
continue | |||
label = label_vocab.to_index(labels[idx]) | |||
if head==0: | |||
word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1])))) | |||
else: | |||
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label, | |||
(word_end_indexes[idx], word_end_indexes[idx + 1]))) | |||
return word_pairs | |||
def add_word_pairs(instance): | |||
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]] | |||
word_end_indexes = np.array(list(map(len, instance['word_lst']))) | |||
word_end_indexes = np.cumsum(word_end_indexes).tolist() | |||
word_end_indexes.insert(0, 0) | |||
word_pairs = [] | |||
pos_tags = instance['pos_tags'] | |||
for idx, head in enumerate(instance['heads']): | |||
if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 | |||
continue | |||
if head==0: | |||
word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1])))) | |||
else: | |||
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), | |||
(word_end_indexes[idx], word_end_indexes[idx + 1]))) | |||
return word_pairs | |||
def add_root(dataset): | |||
new_dataset = DataSet() | |||
for sample in dataset: | |||
chars = ['char_root'] + sample['chars'] | |||
bigrams = ['bigram_root'] + sample['bigrams'] | |||
trigrams = ['trigram_root'] + sample['trigrams'] | |||
seq_lens = sample['seq_lens']+1 | |||
char_labels = [0] + sample['char_labels'] | |||
char_heads = [0] + sample['char_heads'] | |||
sample['chars'] = chars | |||
sample['bigrams'] = bigrams | |||
sample['trigrams'] = trigrams | |||
sample['seq_lens'] = seq_lens | |||
sample['char_labels'] = char_labels | |||
sample['char_heads'] = char_heads | |||
new_dataset.append(sample) | |||
return new_dataset | |||
def add_pun_masks(instance): | |||
tags = instance['pos_tags'] | |||
pun_masks = [] | |||
for word, tag in zip(instance['words'], tags): | |||
if tag=='PU': | |||
pun_masks.extend([1]*len(word)) | |||
else: | |||
pun_masks.extend([0]*len(word)) | |||
return pun_masks | |||
def add_word_lst(instance): | |||
words = instance['words'] | |||
word_lst = [list(word) for word in words] | |||
return word_lst | |||
def add_bigram(instance): | |||
chars = instance['chars'] | |||
length = len(chars) | |||
chars = chars + ['<eos>'] | |||
bigrams = [] | |||
for i in range(length): | |||
bigrams.append(''.join(chars[i:i + 2])) | |||
return bigrams | |||
def add_trigram(instance): | |||
chars = instance['chars'] | |||
length = len(chars) | |||
chars = chars + ['<eos>'] * 2 | |||
trigrams = [] | |||
for i in range(length): | |||
trigrams.append(''.join(chars[i:i + 3])) | |||
return trigrams | |||
def add_char_heads(instance): | |||
words = instance['word_lst'] | |||
heads = instance['heads'] | |||
char_heads = [] | |||
char_index = 1 # 因此存在root节点所以需要从1开始 | |||
head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0] # 因为root是0,0-1=-1 | |||
for word, head in zip(words, heads): | |||
char_head = [] | |||
if len(word)>1: | |||
char_head.append(char_index+1) | |||
char_index += 1 | |||
for _ in range(len(word)-2): | |||
char_index += 1 | |||
char_head.append(char_index) | |||
char_index += 1 | |||
char_head.append(head_end_indexes[head-1]) | |||
char_heads.extend(char_head) | |||
return char_heads | |||
def add_char_labels(instance): | |||
""" | |||
将word_lst中的数据按照下面的方式设置label | |||
比如"复旦大学 位于 ", 对应的分词是"B M M E B E", 则对应的dependency是"复(dep)->旦(head)", "旦(dep)->大(head)".. | |||
对应的label是'app', 'app', 'app', , 而学的label就是复旦大学这个词的dependency label | |||
:param instance: | |||
:return: | |||
""" | |||
words = instance['word_lst'] | |||
labels = instance['labels'] | |||
char_labels = [] | |||
for word, label in zip(words, labels): | |||
for _ in range(len(word)-1): | |||
char_labels.append('APP') | |||
char_labels.append(label) | |||
return char_labels | |||
# add seg_targets | |||
def add_segs(instance): | |||
words = instance['word_lst'] | |||
segs = [0]*len(instance['chars']) | |||
index = 0 | |||
for word in words: | |||
index = index + len(word) - 1 | |||
segs[index] = len(word)-1 | |||
index = index + 1 | |||
return segs | |||
# add target_masks | |||
def add_mask(instance): | |||
words = instance['word_lst'] | |||
mask = [] | |||
for word in words: | |||
mask.extend([0] * (len(word) - 1)) | |||
mask.append(1) | |||
return mask |
@@ -0,0 +1,311 @@ | |||
from fastNLP.models.biaffine_parser import BiaffineParser | |||
from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
from fastNLP.modules.dropout import TimestepDropout | |||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||
from fastNLP import seq_len_to_mask | |||
from fastNLP.modules import Embedding | |||
def drop_input_independent(word_embeddings, dropout_emb): | |||
batch_size, seq_length, _ = word_embeddings.size() | |||
word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb) | |||
word_masks = torch.bernoulli(word_masks) | |||
word_masks = word_masks.unsqueeze(dim=2) | |||
word_embeddings = word_embeddings * word_masks | |||
return word_embeddings | |||
class CharBiaffineParser(BiaffineParser): | |||
def __init__(self, char_vocab_size, | |||
emb_dim, | |||
bigram_vocab_size, | |||
trigram_vocab_size, | |||
num_label, | |||
rnn_layers=3, | |||
rnn_hidden_size=800, #单向的数量 | |||
arc_mlp_size=500, | |||
label_mlp_size=100, | |||
dropout=0.3, | |||
encoder='lstm', | |||
use_greedy_infer=False, | |||
app_index = 0, | |||
pre_chars_embed=None, | |||
pre_bigrams_embed=None, | |||
pre_trigrams_embed=None): | |||
super(BiaffineParser, self).__init__() | |||
rnn_out_size = 2 * rnn_hidden_size | |||
self.char_embed = Embedding((char_vocab_size, emb_dim)) | |||
self.bigram_embed = Embedding((bigram_vocab_size, emb_dim)) | |||
self.trigram_embed = Embedding((trigram_vocab_size, emb_dim)) | |||
if pre_chars_embed: | |||
self.pre_char_embed = Embedding(pre_chars_embed) | |||
self.pre_char_embed.requires_grad = False | |||
if pre_bigrams_embed: | |||
self.pre_bigram_embed = Embedding(pre_bigrams_embed) | |||
self.pre_bigram_embed.requires_grad = False | |||
if pre_trigrams_embed: | |||
self.pre_trigram_embed = Embedding(pre_trigrams_embed) | |||
self.pre_trigram_embed.requires_grad = False | |||
self.timestep_drop = TimestepDropout(dropout) | |||
self.encoder_name = encoder | |||
if encoder == 'var-lstm': | |||
self.encoder = VarLSTM(input_size=emb_dim*3, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
batch_first=True, | |||
input_dropout=dropout, | |||
hidden_dropout=dropout, | |||
bidirectional=True) | |||
elif encoder == 'lstm': | |||
self.encoder = nn.LSTM(input_size=emb_dim*3, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
batch_first=True, | |||
dropout=dropout, | |||
bidirectional=True) | |||
else: | |||
raise ValueError('unsupported encoder type: {}'.format(encoder)) | |||
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), | |||
nn.LeakyReLU(0.1), | |||
TimestepDropout(p=dropout),) | |||
self.arc_mlp_size = arc_mlp_size | |||
self.label_mlp_size = label_mlp_size | |||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||
self.use_greedy_infer = use_greedy_infer | |||
self.reset_parameters() | |||
self.dropout = dropout | |||
self.app_index = app_index | |||
self.num_label = num_label | |||
if self.app_index != 0: | |||
raise ValueError("现在app_index必须等于0") | |||
def reset_parameters(self): | |||
for name, m in self.named_modules(): | |||
if 'embed' in name: | |||
pass | |||
elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'): | |||
pass | |||
else: | |||
for p in m.parameters(): | |||
if len(p.size())>1: | |||
nn.init.xavier_normal_(p, gain=0.1) | |||
else: | |||
nn.init.uniform_(p, -0.1, 0.1) | |||
def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None, | |||
pre_trigrams=None): | |||
""" | |||
max_len是包含root的 | |||
:param chars: batch_size x max_len | |||
:param ngrams: batch_size x max_len*ngram_per_char | |||
:param seq_lens: batch_size | |||
:param gold_heads: batch_size x max_len | |||
:param pre_chars: batch_size x max_len | |||
:param pre_ngrams: batch_size x max_len*ngram_per_char | |||
:return dict: parsing results | |||
arc_pred: [batch_size, seq_len, seq_len] | |||
label_pred: [batch_size, seq_len, seq_len] | |||
mask: [batch_size, seq_len] | |||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | |||
""" | |||
# prepare embeddings | |||
batch_size, seq_len = chars.shape | |||
# print('forward {} {}'.format(batch_size, seq_len)) | |||
# get sequence mask | |||
mask = seq_len_to_mask(seq_lens).long() | |||
chars = self.char_embed(chars) # [N,L] -> [N,L,C_0] | |||
bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1] | |||
trigrams = self.trigram_embed(trigrams) | |||
if pre_chars is not None: | |||
pre_chars = self.pre_char_embed(pre_chars) | |||
# pre_chars = self.pre_char_fc(pre_chars) | |||
chars = pre_chars + chars | |||
if pre_bigrams is not None: | |||
pre_bigrams = self.pre_bigram_embed(pre_bigrams) | |||
# pre_bigrams = self.pre_bigram_fc(pre_bigrams) | |||
bigrams = bigrams + pre_bigrams | |||
if pre_trigrams is not None: | |||
pre_trigrams = self.pre_trigram_embed(pre_trigrams) | |||
# pre_trigrams = self.pre_trigram_fc(pre_trigrams) | |||
trigrams = trigrams + pre_trigrams | |||
x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C] | |||
# encoder, extract features | |||
if self.training: | |||
x = drop_input_independent(x, self.dropout) | |||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||
x = x[sort_idx] | |||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||
feat, _ = self.encoder(x) # -> [N,L,C] | |||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
feat = feat[unsort_idx] | |||
feat = self.timestep_drop(feat) | |||
# for arc biaffine | |||
# mlp, reduce dim | |||
feat = self.mlp(feat) | |||
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size | |||
arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] | |||
label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] | |||
# biaffine arc classifier | |||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||
# use gold or predicted arc to predict label | |||
if gold_heads is None or not self.training: | |||
# use greedy decoding in training | |||
if self.training or self.use_greedy_infer: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
else: | |||
heads = self.mst_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
assert self.training # must be training mode | |||
if gold_heads is None: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
head_pred = None | |||
heads = gold_heads | |||
# heads: batch_size x max_len | |||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1) | |||
label_head = label_head[batch_range, heads].contiguous() | |||
label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label] | |||
# 这里限制一下,只有当head为下一个时,才能预测app这个label | |||
arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\ | |||
.repeat(batch_size, 1) # batch_size x max_len | |||
app_masks = heads.ne(arange_index) # batch_size x max_len, 为1的位置不可以预测app | |||
app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label) | |||
app_masks[:, :, 1:] = 0 | |||
label_pred = label_pred.masked_fill(app_masks, -np.inf) | |||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} | |||
if head_pred is not None: | |||
res_dict['head_pred'] = head_pred | |||
return res_dict | |||
@staticmethod | |||
def loss(arc_pred, label_pred, arc_true, label_true, mask): | |||
""" | |||
Compute loss. | |||
:param arc_pred: [batch_size, seq_len, seq_len] | |||
:param label_pred: [batch_size, seq_len, n_tags] | |||
:param arc_true: [batch_size, seq_len] | |||
:param label_true: [batch_size, seq_len] | |||
:param mask: [batch_size, seq_len] | |||
:return: loss value | |||
""" | |||
batch_size, seq_len, _ = arc_pred.shape | |||
flip_mask = (mask == 0) | |||
_arc_pred = arc_pred.clone() | |||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||
arc_true[:, 0].fill_(-1) | |||
label_true[:, 0].fill_(-1) | |||
arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) | |||
label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) | |||
return arc_nll + label_nll | |||
def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams): | |||
""" | |||
max_len是包含root的 | |||
:param chars: batch_size x max_len | |||
:param ngrams: batch_size x max_len*ngram_per_char | |||
:param seq_lens: batch_size | |||
:param pre_chars: batch_size x max_len | |||
:param pre_ngrams: batch_size x max_len*ngram_per_cha | |||
:return: | |||
""" | |||
res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams, | |||
pre_trigrams=pre_trigrams, gold_heads=None) | |||
output = {} | |||
output['arc_pred'] = res.pop('head_pred') | |||
_, label_pred = res.pop('label_pred').max(2) | |||
output['label_pred'] = label_pred | |||
return output | |||
class CharParser(nn.Module): | |||
def __init__(self, char_vocab_size, | |||
emb_dim, | |||
bigram_vocab_size, | |||
trigram_vocab_size, | |||
num_label, | |||
rnn_layers=3, | |||
rnn_hidden_size=400, #单向的数量 | |||
arc_mlp_size=500, | |||
label_mlp_size=100, | |||
dropout=0.3, | |||
encoder='var-lstm', | |||
use_greedy_infer=False, | |||
app_index = 0, | |||
pre_chars_embed=None, | |||
pre_bigrams_embed=None, | |||
pre_trigrams_embed=None): | |||
super().__init__() | |||
self.parser = CharBiaffineParser(char_vocab_size, | |||
emb_dim, | |||
bigram_vocab_size, | |||
trigram_vocab_size, | |||
num_label, | |||
rnn_layers, | |||
rnn_hidden_size, #单向的数量 | |||
arc_mlp_size, | |||
label_mlp_size, | |||
dropout, | |||
encoder, | |||
use_greedy_infer, | |||
app_index, | |||
pre_chars_embed=pre_chars_embed, | |||
pre_bigrams_embed=pre_bigrams_embed, | |||
pre_trigrams_embed=pre_trigrams_embed) | |||
def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None, | |||
pre_trigrams=None): | |||
res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars, | |||
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) | |||
arc_pred = res_dict['arc_pred'] | |||
label_pred = res_dict['label_pred'] | |||
masks = res_dict['mask'] | |||
loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks) | |||
return {'loss': loss} | |||
def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None): | |||
res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars, | |||
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) | |||
output = {} | |||
output['head_preds'] = res.pop('head_pred') | |||
_, label_pred = res.pop('label_pred').max(2) | |||
output['label_preds'] = label_pred | |||
return output |
@@ -0,0 +1,65 @@ | |||
from fastNLP.core.callback import Callback | |||
import torch | |||
from torch import nn | |||
class OptimizerCallback(Callback): | |||
def __init__(self, optimizer, scheduler, update_every=4): | |||
super().__init__() | |||
self._optimizer = optimizer | |||
self.scheduler = scheduler | |||
self._update_every = update_every | |||
def on_backward_end(self): | |||
if self.step % self._update_every==0: | |||
# nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5) | |||
# self._optimizer.step() | |||
self.scheduler.step() | |||
# self.model.zero_grad() | |||
class DevCallback(Callback): | |||
def __init__(self, tester, metric_key='u_f1'): | |||
super().__init__() | |||
self.tester = tester | |||
setattr(tester, 'verbose', 0) | |||
self.metric_key = metric_key | |||
self.record_best = False | |||
self.best_eval_value = 0 | |||
self.best_eval_res = None | |||
self.best_dev_res = None # 存取dev的表现 | |||
def on_valid_begin(self): | |||
eval_res = self.tester.test() | |||
metric_name = self.tester.metrics[0].__class__.__name__ | |||
metric_value = eval_res[metric_name][self.metric_key] | |||
if metric_value>self.best_eval_value: | |||
self.best_eval_value = metric_value | |||
self.best_epoch = self.trainer.epoch | |||
self.record_best = True | |||
self.best_eval_res = eval_res | |||
self.test_eval_res = eval_res | |||
eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \ | |||
self.tester._format_eval_results(eval_res) | |||
self.pbar.write(eval_str) | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
if self.record_best: | |||
self.best_dev_res = eval_result | |||
self.record_best = False | |||
if is_better_eval: | |||
self.best_dev_res_on_dev = eval_result | |||
self.best_test_res_on_dev = self.test_eval_res | |||
self.dev_epoch = self.epoch | |||
def on_train_end(self): | |||
print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch, | |||
self.tester._format_eval_results(self.best_eval_res), | |||
self.tester._format_eval_results(self.best_dev_res))) | |||
print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch, | |||
self.tester._format_eval_results(self.best_test_res_on_dev), | |||
self.tester._format_eval_results(self.best_dev_res_on_dev))) |
@@ -0,0 +1,184 @@ | |||
from fastNLP.core.metrics import MetricBase | |||
from fastNLP.core.utils import seq_len_to_mask | |||
import torch | |||
class SegAppCharParseF1Metric(MetricBase): | |||
# | |||
def __init__(self, app_index): | |||
super().__init__() | |||
self.app_index = app_index | |||
self.parse_head_tp = 0 | |||
self.parse_label_tp = 0 | |||
self.rec_tol = 0 | |||
self.pre_tol = 0 | |||
def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens, | |||
pun_masks): | |||
""" | |||
max_len是不包含root的character的长度 | |||
:param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size | |||
:param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size | |||
:param head_preds: batch_size x max_len | |||
:param label_preds: batch_size x max_len | |||
:param seq_lens: | |||
:param pun_masks: batch_size x | |||
:return: | |||
""" | |||
# 去掉root | |||
head_preds = head_preds[:, 1:].tolist() | |||
label_preds = label_preds[:, 1:].tolist() | |||
seq_lens = (seq_lens - 1).tolist() | |||
# 先解码出words,POS,heads, labels, 对应的character范围 | |||
for b in range(len(head_preds)): | |||
seq_len = seq_lens[b] | |||
head_pred = head_preds[b][:seq_len] | |||
label_pred = label_preds[b][:seq_len] | |||
words = [] # 存放[word_start, word_end),相对起始位置,不考虑root | |||
heads = [] | |||
labels = [] | |||
ranges = [] # 对应该char是第几个word,长度是seq_len+1 | |||
word_idx = 0 | |||
word_start_idx = 0 | |||
for idx, (label, head) in enumerate(zip(label_pred, head_pred)): | |||
ranges.append(word_idx) | |||
if label == self.app_index: | |||
pass | |||
else: | |||
labels.append(label) | |||
heads.append(head) | |||
words.append((word_start_idx, idx+1)) | |||
word_start_idx = idx+1 | |||
word_idx += 1 | |||
head_dep_tuple = [] # head在前面 | |||
head_label_dep_tuple = [] | |||
for idx, head in enumerate(heads): | |||
span = words[idx] | |||
if span[0]==span[1]-1 and pun_masks[b, span[0]]: | |||
continue # exclude punctuations | |||
if head == 0: | |||
head_dep_tuple.append((('root', words[idx]))) | |||
head_label_dep_tuple.append(('root', labels[idx], words[idx])) | |||
else: | |||
head_word_idx = ranges[head-1] | |||
head_word_span = words[head_word_idx] | |||
head_dep_tuple.append(((head_word_span, words[idx]))) | |||
head_label_dep_tuple.append((head_word_span, labels[idx], words[idx])) | |||
gold_head_dep_tuple = set(gold_word_pairs[b]) | |||
gold_head_label_dep_tuple = set(gold_label_word_pairs[b]) | |||
for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple): | |||
if head_dep in gold_head_dep_tuple: | |||
self.parse_head_tp += 1 | |||
if head_label_dep in gold_head_label_dep_tuple: | |||
self.parse_label_tp += 1 | |||
self.pre_tol += len(head_dep_tuple) | |||
self.rec_tol += len(gold_head_dep_tuple) | |||
def get_metric(self, reset=True): | |||
u_p = self.parse_head_tp / self.pre_tol | |||
u_r = self.parse_head_tp / self.rec_tol | |||
u_f = 2*u_p*u_r/(1e-6 + u_p + u_r) | |||
l_p = self.parse_label_tp / self.pre_tol | |||
l_r = self.parse_label_tp / self.rec_tol | |||
l_f = 2*l_p*l_r/(1e-6 + l_p + l_r) | |||
if reset: | |||
self.parse_head_tp = 0 | |||
self.parse_label_tp = 0 | |||
self.rec_tol = 0 | |||
self.pre_tol = 0 | |||
return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4), | |||
'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)} | |||
class CWSMetric(MetricBase): | |||
def __init__(self, app_index): | |||
super().__init__() | |||
self.app_index = app_index | |||
self.pre = 0 | |||
self.rec = 0 | |||
self.tp = 0 | |||
def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens): | |||
""" | |||
:param seg_targets: batch_size x max_len, 每个位置预测的是该word的长度-1,在word结束的地方。 | |||
:param seg_masks: batch_size x max_len,只有在word结束的地方为1 | |||
:param label_preds: batch_size x max_len | |||
:param seq_lens: batch_size | |||
:return: | |||
""" | |||
pred_masks = torch.zeros_like(seg_masks) | |||
pred_segs = torch.zeros_like(seg_targets) | |||
seq_lens = (seq_lens - 1).tolist() | |||
for idx, label_pred in enumerate(label_preds[:, 1:].tolist()): | |||
seq_len = seq_lens[idx] | |||
label_pred = label_pred[:seq_len] | |||
word_len = 0 | |||
for l_i, label in enumerate(label_pred): | |||
if label==self.app_index and l_i!=len(label_pred)-1: | |||
word_len += 1 | |||
else: | |||
pred_segs[idx, l_i] = word_len # 这个词的长度为word_len | |||
pred_masks[idx, l_i] = 1 | |||
word_len = 0 | |||
right_mask = seg_targets.eq(pred_segs) # 对长度的预测一致 | |||
self.rec += seg_masks.sum().item() | |||
self.pre += pred_masks.sum().item() | |||
# 且pred和target在同一个地方有值 | |||
self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item() | |||
def get_metric(self, reset=True): | |||
res = {} | |||
res['rec'] = round(self.tp/(self.rec+1e-6), 4) | |||
res['pre'] = round(self.tp/(self.pre+1e-6), 4) | |||
res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4) | |||
if reset: | |||
self.pre = 0 | |||
self.rec = 0 | |||
self.tp = 0 | |||
return res | |||
class ParserMetric(MetricBase): | |||
def __init__(self, ): | |||
super().__init__() | |||
self.num_arc = 0 | |||
self.num_label = 0 | |||
self.num_sample = 0 | |||
def get_metric(self, reset=True): | |||
res = {'UAS': round(self.num_arc*1.0 / self.num_sample, 4), | |||
'LAS': round(self.num_label*1.0 / self.num_sample, 4)} | |||
if reset: | |||
self.num_sample = self.num_label = self.num_arc = 0 | |||
return res | |||
def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None): | |||
"""Evaluate the performance of prediction. | |||
""" | |||
if seq_lens is None: | |||
seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.byte) | |||
else: | |||
seq_mask = seq_len_to_mask(seq_lens.long(), float=False) | |||
# mask out <root> tag | |||
seq_mask[:, 0] = 0 | |||
head_pred_correct = (head_preds == heads).__and__(seq_mask) | |||
label_pred_correct = (label_preds == labels).__and__(head_pred_correct) | |||
self.num_arc += head_pred_correct.float().sum().item() | |||
self.num_label += label_pred_correct.float().sum().item() | |||
self.num_sample += seq_mask.sum().item() | |||
@@ -0,0 +1,16 @@ | |||
Code for paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697) | |||
### 准备数据 | |||
1. 数据应该为conll格式,1, 3, 6, 7列应该对应为'words', 'pos_tags', 'heads', 'labels'. | |||
2. 将train, dev, test放在同一个folder下,并将该folder路径填入train.py中的data_folder变量里。 | |||
3. 从[百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA)(提取:ua53)下载预训练vector,放到同一个folder下,并将train.py中vector_folder变量正确设置。 | |||
### 运行代码 | |||
``` | |||
python train.py | |||
``` | |||
### 其它 | |||
ctb5上跑出论文中报道的结果使用以上的默认参数应该就可以了(应该会更高一些); ctb7上使用默认参数会低0.1%左右,需要调节 | |||
learning rate scheduler. |
@@ -0,0 +1,124 @@ | |||
import sys | |||
sys.path.append('../..') | |||
from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader | |||
from fastNLP.modules.encoder.embedding import StaticEmbedding | |||
from torch import nn | |||
from functools import partial | |||
from reproduction.joint_cws_parse.models.CharParser import CharParser | |||
from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric | |||
from fastNLP import cache_results, BucketSampler, Trainer | |||
from torch import optim | |||
from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback | |||
from torch.optim.lr_scheduler import LambdaLR, StepLR | |||
from fastNLP import Tester | |||
from fastNLP import GradientClipCallback, LRScheduler | |||
import os | |||
def set_random_seed(random_seed=666): | |||
import random, numpy, torch | |||
random.seed(random_seed) | |||
numpy.random.seed(random_seed) | |||
torch.cuda.manual_seed(random_seed) | |||
torch.random.manual_seed(random_seed) | |||
uniform_init = partial(nn.init.normal_, std=0.02) | |||
################################################### | |||
# 需要变动的超参放到这里 | |||
lr = 0.002 # 0.01~0.001 | |||
dropout = 0.33 # 0.3~0.6 | |||
weight_decay = 0 # 1e-5, 1e-6, 0 | |||
arc_mlp_size = 500 # 200, 300 | |||
rnn_hidden_size = 400 # 200, 300, 400 | |||
rnn_layers = 3 # 2, 3 | |||
encoder = 'var-lstm' # var-lstm, lstm | |||
emb_size = 100 # 64 , 100 | |||
label_mlp_size = 100 | |||
batch_size = 32 | |||
update_every = 4 | |||
n_epochs = 100 | |||
data_folder = '' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件 | |||
vector_folder = '' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt | |||
#################################################### | |||
set_random_seed(1234) | |||
device = 0 | |||
# @cache_results('caches/{}.pkl'.format(data_name)) | |||
# def get_data(): | |||
data = CTBxJointLoader().process(data_folder) | |||
char_labels_vocab = data.vocabs['char_labels'] | |||
pre_chars_vocab = data.vocabs['pre_chars'] | |||
pre_bigrams_vocab = data.vocabs['pre_bigrams'] | |||
pre_trigrams_vocab = data.vocabs['pre_trigrams'] | |||
chars_vocab = data.vocabs['chars'] | |||
bigrams_vocab = data.vocabs['bigrams'] | |||
trigrams_vocab = data.vocabs['trigrams'] | |||
pre_chars_embed = StaticEmbedding(pre_chars_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std() | |||
pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std() | |||
pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std() | |||
# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data | |||
# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() | |||
print(data) | |||
model = CharParser(char_vocab_size=len(chars_vocab), | |||
emb_dim=emb_size, | |||
bigram_vocab_size=len(bigrams_vocab), | |||
trigram_vocab_size=len(trigrams_vocab), | |||
num_label=len(char_labels_vocab), | |||
rnn_layers=rnn_layers, | |||
rnn_hidden_size=rnn_hidden_size, | |||
arc_mlp_size=arc_mlp_size, | |||
label_mlp_size=label_mlp_size, | |||
dropout=dropout, | |||
encoder=encoder, | |||
use_greedy_infer=False, | |||
app_index=char_labels_vocab['APP'], | |||
pre_chars_embed=pre_chars_embed, | |||
pre_bigrams_embed=pre_bigrams_embed, | |||
pre_trigrams_embed=pre_trigrams_embed) | |||
metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP']) | |||
metric2 = CWSMetric(char_labels_vocab['APP']) | |||
metrics = [metric1, metric2] | |||
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr, | |||
weight_decay=weight_decay, betas=[0.9, 0.9]) | |||
sampler = BucketSampler(seq_len_field_name='seq_lens') | |||
callbacks = [] | |||
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) | |||
scheduler = StepLR(optimizer, step_size=18, gamma=0.75) | |||
# optim_callback = OptimizerCallback(optimizer, scheduler, update_every) | |||
# callbacks.append(optim_callback) | |||
scheduler_callback = LRScheduler(scheduler) | |||
callbacks.append(scheduler_callback) | |||
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) | |||
tester = Tester(data=data.datasets['test'], model=model, metrics=metrics, | |||
batch_size=64, device=device, verbose=0) | |||
dev_callback = DevCallback(tester) | |||
callbacks.append(dev_callback) | |||
trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3, | |||
validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, | |||
check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True, | |||
device=device, callbacks=callbacks, update_every=update_every) | |||
trainer.train() |
@@ -0,0 +1,100 @@ | |||
# Matching任务模型复现 | |||
这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). | |||
复现的模型有(按论文发表时间顺序排序): | |||
- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||
论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | |||
- ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). | |||
论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | |||
- DIIN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||
论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf). | |||
- MwAN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||
论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf). | |||
- BERT:[模型代码](model/bert.py); [训练代码](matching_bert.py). | |||
论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). | |||
# 数据集及复现结果汇总 | |||
使用fastNLP复现的结果vs论文汇报结果,在前面的表示使用fastNLP复现的结果 | |||
'\-'表示我们仍未复现或者论文原文没有汇报 | |||
model name | SNLI | MNLI | RTE | QNLI | Quora | |||
:---: | :---: | :---: | :---: | :---: | :---: | |||
CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - | | |||
ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | | |||
DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | |||
MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | | |||
BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - | | |||
*ESIM模型由MNLI官方复现的结果为72.4/72.1,ESIM原论文当中没有汇报MNLI数据集的结果。 | |||
# 数据集复现结果及其他主要模型对比 | |||
## SNLI | |||
[Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) | |||
Performance on Test set: | |||
model name | ESIM | DIIN | MwAN | [GPT1.0](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) | [BERT-Large+SRL](https://arxiv.org/pdf/1809.02794.pdf) | [MT-DNN](https://arxiv.org/pdf/1901.11504.pdf) | |||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | |||
__performance__ | 88.0 | 88.0 | 88.3 | 89.9 | 91.3 | 91.6 | | |||
### 基于fastNLP复现的结果 | |||
Performance on Test set: | |||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | |||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | |||
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16 | |||
## MNLI | |||
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | |||
Performance on Test set(matched/mismatched): | |||
model name | ESIM | DIIN | MwAN | GPT1.0 | BERT-Base | MT-DNN | |||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | |||
__performance__ | 72.4/72.1 | 78.8/77.8 | 78.5/77.7 | 82.1/81.4 | 84.6/83.4 | 87.9/87.4 | | |||
### 基于fastNLP复现的结果 | |||
Performance on Test set(matched/mismatched): | |||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | |||
:---: | :---: | :---: | :---: | :---: | :---: | | |||
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||
## RTE | |||
Still in progress. | |||
## QNLI | |||
### From GLUE baselines | |||
[Link to GLUE leaderboard](https://gluebenchmark.com/leaderboard) | |||
Performance on Test set: | |||
#### LSTM-based | |||
model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo | |||
:---: | :---: | :---: | :---: | :---: | | |||
__performance__ | 74.6 | 74.3 | 75.5 | 79.8 | | |||
*这些LSTM-based的baseline是由QNLI官方实现并测试的。 | |||
#### Transformer-based | |||
model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN | |||
:---: | :---: | :---: | :---: | :---: | | |||
__performance__ | 87.4 | 90.5 | 92.7 | 96.0 | | |||
### 基于fastNLP复现的结果 | |||
Performance on __Dev__ set: | |||
model name | CNTN | ESIM | DIIN | MwAN | BERT | |||
:---: | :---: | :---: | :---: | :---: | :---: | |||
__performance__ | - | 76.97 | - | 74.6 | - | |||
## Quora | |||
Still in progress. | |||
@@ -5,8 +5,8 @@ from typing import Union, Dict | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.base_loader import DataInfo | |||
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader | |||
from fastNLP.io.base_loader import DataInfo, DataSetLoader | |||
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | |||
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from fastNLP.modules.encoder._bert import BertTokenizer | |||
@@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader): | |||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||
读取Matching任务的数据集 | |||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||
""" | |||
def __init__(self, paths: dict=None): | |||
""" | |||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||
""" | |||
self.paths = paths | |||
def _load(self, path): | |||
@@ -34,7 +33,8 @@ class MatchingLoader(DataSetLoader): | |||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||
cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, | |||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | |||
""" | |||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||
@@ -49,6 +49,8 @@ class MatchingLoader(DataSetLoader): | |||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||
:param bool get_index: 是否需要根据词表将文本转为index | |||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||
:param str auto_pad_token: 自动pad的内容 | |||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||
于此同时其他field不会被设置为input。默认值为True。 | |||
@@ -169,6 +171,9 @@ class MatchingLoader(DataSetLoader): | |||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||
if auto_pad_length is not None: | |||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||
if cut_text is not None: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
@@ -180,7 +185,7 @@ class MatchingLoader(DataSetLoader): | |||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||
if bert_tokenizer is None: | |||
words_vocab = Vocabulary() | |||
words_vocab = Vocabulary(padding=auto_pad_token) | |||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=[n for n in data_set_list[0].get_field_names() | |||
if (Const.INPUT in n)], | |||
@@ -202,6 +207,20 @@ class MatchingLoader(DataSetLoader): | |||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
is_input=auto_set_input, is_target=auto_set_target) | |||
if auto_pad_length is not None: | |||
if seq_len_type == 'seq_len': | |||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||
new_field_name=fields, is_input=auto_set_input) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if isinstance(set_input, list): | |||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||
@@ -267,7 +286,7 @@ class RTELoader(MatchingLoader, CSVLoader): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
# 'test': 'test.tsv' # test set has not label | |||
'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
@@ -281,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
ds.rename_field(k, v) | |||
if v in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
@@ -306,7 +326,7 @@ class QNLILoader(MatchingLoader, CSVLoader): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
# 'test': 'test.tsv' # test set has not label | |||
'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
@@ -320,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
ds.rename_field(k, v) | |||
if v in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
@@ -332,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | |||
读取SNLI数据集,读取的DataSet包含fields:: | |||
读取MNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
@@ -348,6 +369,10 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||
'dev_mismatched': 'dev_mismatched.tsv', | |||
'test_matched': 'test_matched.tsv', | |||
'test_mismatched': 'test_mismatched.tsv', | |||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t') | |||
@@ -364,6 +389,10 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
if Const.TARGET in ds.get_field_names(): | |||
if ds[0][Const.TARGET] == 'hidden': | |||
ds.delete_field(Const.TARGET) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
@@ -376,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||
class QuoraLoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` | |||
读取MNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
@@ -0,0 +1,102 @@ | |||
import random | |||
import numpy as np | |||
import torch | |||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam | |||
from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ | |||
MNLILoader, QNLILoader, QuoraLoader | |||
from reproduction.matching.model.bert import BertForNLI | |||
# define hyper-parameters | |||
class BERTConfig: | |||
task = 'snli' | |||
batch_size_per_gpu = 6 | |||
n_epochs = 6 | |||
lr = 2e-5 | |||
seq_len_type = 'bert' | |||
seed = 42 | |||
train_dataset_name = 'train' | |||
dev_dataset_name = 'dev' | |||
test_dataset_name = 'test' | |||
save_path = None # 模型存储的位置,None表示不存储模型。 | |||
bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹 | |||
arg = BERTConfig() | |||
# set random seed | |||
random.seed(arg.seed) | |||
np.random.seed(arg.seed) | |||
torch.manual_seed(arg.seed) | |||
n_gpu = torch.cuda.device_count() | |||
if n_gpu > 0: | |||
torch.cuda.manual_seed_all(arg.seed) | |||
# load data set | |||
if arg.task == 'snli': | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
elif arg.task == 'rte': | |||
data_info = RTELoader().process( | |||
paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
elif arg.task == 'qnli': | |||
data_info = QNLILoader().process( | |||
paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
elif arg.task == 'mnli': | |||
data_info = MNLILoader().process( | |||
paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
elif arg.task == 'quora': | |||
data_info = QuoraLoader().process( | |||
paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
else: | |||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||
# define model | |||
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir) | |||
# define trainer | |||
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
n_epochs=arg.n_epochs, print_every=-1, | |||
dev_data=data_info.datasets[arg.dev_dataset_name], | |||
metrics=AccuracyMetric(), metric_key='acc', | |||
device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1, | |||
save_path=arg.save_path) | |||
# train model | |||
trainer.train(load_best_model=True) | |||
# define tester | |||
tester = Tester( | |||
data=data_info.datasets[arg.test_dataset_name], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
device=[i for i in range(torch.cuda.device_count())], | |||
) | |||
# test model | |||
tester.test() | |||
@@ -1,47 +1,103 @@ | |||
import argparse | |||
import random | |||
import numpy as np | |||
import torch | |||
from torch.optim import Adamax | |||
from torch.optim.lr_scheduler import StepLR | |||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const | |||
from fastNLP.core.callback import GradientClipCallback, LRScheduler | |||
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding | |||
from reproduction.matching.data.MatchingDataLoader import SNLILoader | |||
from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ | |||
MNLILoader, QNLILoader, QuoraLoader | |||
from reproduction.matching.model.esim import ESIMModel | |||
argument = argparse.ArgumentParser() | |||
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove') | |||
argument.add_argument('--batch-size-per-gpu', type=int, default=128) | |||
argument.add_argument('--n-epochs', type=int, default=100) | |||
argument.add_argument('--lr', type=float, default=1e-4) | |||
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len') | |||
argument.add_argument('--save-dir', type=str, default=None) | |||
arg = argument.parse_args() | |||
bert_dirs = 'path/to/bert/dir' | |||
# define hyper-parameters | |||
class ESIMConfig: | |||
task = 'snli' | |||
embedding = 'glove' | |||
batch_size_per_gpu = 196 | |||
n_epochs = 30 | |||
lr = 2e-3 | |||
seq_len_type = 'seq_len' | |||
# seq_len表示在process的时候用len(words)来表示长度信息; | |||
# mask表示用0/1掩码矩阵来表示长度信息; | |||
seed = 42 | |||
train_dataset_name = 'train' | |||
dev_dataset_name = 'dev' | |||
test_dataset_name = 'test' | |||
save_path = None # 模型存储的位置,None表示不存储模型。 | |||
arg = ESIMConfig() | |||
# set random seed | |||
random.seed(arg.seed) | |||
np.random.seed(arg.seed) | |||
torch.manual_seed(arg.seed) | |||
n_gpu = torch.cuda.device_count() | |||
if n_gpu > 0: | |||
torch.cuda.manual_seed_all(arg.seed) | |||
# load data set | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||
get_index=True, concat=False, | |||
) | |||
if arg.task == 'snli': | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
elif arg.task == 'rte': | |||
data_info = RTELoader().process( | |||
paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
elif arg.task == 'qnli': | |||
data_info = QNLILoader().process( | |||
paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
elif arg.task == 'mnli': | |||
data_info = MNLILoader().process( | |||
paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
elif arg.task == 'quora': | |||
data_info = QuoraLoader().process( | |||
paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
else: | |||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||
# load embedding | |||
if arg.embedding == 'elmo': | |||
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||
elif arg.embedding == 'glove': | |||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False) | |||
else: | |||
raise ValueError(f'now we only support elmo or glove embedding for esim model!') | |||
raise RuntimeError(f'NOT support {arg.embedding} embedding yet!') | |||
# define model | |||
model = ESIMModel(embedding) | |||
model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET])) | |||
# define optimizer and callback | |||
optimizer = Adamax(lr=arg.lr, params=model.parameters()) | |||
scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍 | |||
callbacks = [ | |||
GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10) | |||
LRScheduler(scheduler), | |||
] | |||
# define trainer | |||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||
optimizer=optimizer, | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
n_epochs=arg.n_epochs, print_every=-1, | |||
dev_data=data_info.datasets['dev'], | |||
dev_data=data_info.datasets[arg.dev_dataset_name], | |||
metrics=AccuracyMetric(), metric_key='acc', | |||
device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1, | |||
@@ -52,7 +108,7 @@ trainer.train(load_best_model=True) | |||
# define tester | |||
tester = Tester( | |||
data=data_info.datasets['test'], | |||
data=data_info.datasets[arg.test_dataset_name], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
@@ -81,6 +81,7 @@ class ESIMModel(BaseModel): | |||
out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | |||
logits = torch.tanh(self.classifier(out)) | |||
# logits = self.classifier(out) | |||
if target is not None: | |||
loss_fct = CrossEntropyLoss() | |||
@@ -91,7 +92,8 @@ class ESIMModel(BaseModel): | |||
return {Const.OUTPUT: logits} | |||
def predict(self, **kwargs): | |||
return self.forward(**kwargs) | |||
pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) | |||
return {Const.OUTPUT: pred} | |||
# input [batch_size, len , hidden] | |||
# mask [batch_size, len] (111...00) | |||
@@ -127,7 +129,7 @@ class BiRNN(nn.Module): | |||
def forward(self, x, x_mask): | |||
# Sort x | |||
lengths = x_mask.data.eq(1).long().sum(1).squeeze() | |||
lengths = x_mask.data.eq(1).long().sum(1) | |||
_, idx_sort = torch.sort(lengths, dim=0, descending=True) | |||
_, idx_unsort = torch.sort(idx_sort, dim=0) | |||
lengths = list(lengths[idx_sort]) | |||
@@ -0,0 +1,13 @@ | |||
# NER任务模型复现 | |||
这里使用fastNLP复现经典的BiLSTM-CNN的NER任务的模型,旨在达到与论文中相符的性能。 | |||
论文链接[Named Entity Recognition with Bidirectional LSTM-CNNs](https://arxiv.org/pdf/1511.08308.pdf) | |||
# 数据集及复现结果汇总 | |||
使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道) | |||
model name | Conll2003 | Ontonotes | |||
:---: | :---: | :---: | |||
BiLSTM-CNN | 91.17/90.91 | 86.47/86.35 | | |||
@@ -13,22 +13,30 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
} | |||
如果paths为不合法的,将直接进行raise相应的错误 | |||
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt, | |||
test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 | |||
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||
:return: | |||
""" | |||
if isinstance(paths, str): | |||
if os.path.isfile(paths): | |||
return {'train': paths} | |||
elif os.path.isdir(paths): | |||
train_fp = os.path.join(paths, 'train.txt') | |||
if not os.path.isfile(train_fp): | |||
raise FileNotFoundError(f"train.txt is not found in folder {paths}.") | |||
files = {'train': train_fp} | |||
for filename in ['dev.txt', 'test.txt']: | |||
fp = os.path.join(paths, filename) | |||
if os.path.isfile(fp): | |||
files[filename.split('.')[0]] = fp | |||
filenames = os.listdir(paths) | |||
files = {} | |||
for filename in filenames: | |||
path_pair = None | |||
if 'train' in filename: | |||
path_pair = ('train', filename) | |||
if 'dev' in filename: | |||
if path_pair: | |||
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) | |||
path_pair = ('dev', filename) | |||
if 'test' in filename: | |||
if path_pair: | |||
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) | |||
path_pair = ('test', filename) | |||
if path_pair: | |||
files[path_pair[0]] = os.path.join(paths, path_pair[1]) | |||
return files | |||
else: | |||
raise FileNotFoundError(f"{paths} is not a valid file path.") | |||
@@ -1,5 +1,5 @@ | |||
numpy | |||
torch>=0.4.0 | |||
tqdm | |||
nltk | |||
numpy>=1.14.2 | |||
torch>=1.0.0 | |||
tqdm>=4.28.1 | |||
nltk>=3.4.1 | |||
requests |
@@ -1,7 +1,7 @@ | |||
import unittest | |||
import os | |||
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader | |||
from fastNLP.io.dataset_loader import SSTLoader | |||
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader | |||
from fastNLP.io.data_loader import SSTLoader, SNLILoader | |||
from reproduction.text_classification.data.yelpLoader import yelpLoader | |||
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase): | |||
print(info.vocabs) | |||
print(info.datasets) | |||
os.remove(train), os.remove(test) | |||
def test_import(self): | |||
import fastNLP | |||
from fastNLP.io import SNLILoader | |||
ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True, | |||
get_index=True, seq_len_type='seq_len') | |||
assert 'train' in ds.datasets | |||
assert len(ds.datasets) == 1 | |||
assert len(ds.datasets['train']) == 3 |
@@ -8,8 +8,9 @@ from fastNLP.models.bert import * | |||
class TestBert(unittest.TestCase): | |||
def test_bert_1(self): | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder._bert import BertConfig | |||
model = BertForSequenceClassification(2) | |||
model = BertForSequenceClassification(2, BertConfig(32000)) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
@@ -22,8 +23,9 @@ class TestBert(unittest.TestCase): | |||
def test_bert_2(self): | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder._bert import BertConfig | |||
model = BertForMultipleChoice(2) | |||
model = BertForMultipleChoice(2, BertConfig(32000)) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
@@ -36,8 +38,9 @@ class TestBert(unittest.TestCase): | |||
def test_bert_3(self): | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder._bert import BertConfig | |||
model = BertForTokenClassification(7) | |||
model = BertForTokenClassification(7, BertConfig(32000)) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
@@ -50,8 +53,9 @@ class TestBert(unittest.TestCase): | |||
def test_bert_4(self): | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder._bert import BertConfig | |||
model = BertForQuestionAnswering() | |||
model = BertForQuestionAnswering(BertConfig(32000)) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
@@ -8,8 +8,9 @@ from fastNLP.models.bert import BertModel | |||
class TestBert(unittest.TestCase): | |||
def test_bert_1(self): | |||
model = BertModel(vocab_size=32000, hidden_size=768, | |||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||
from fastNLP.modules.encoder._bert import BertConfig | |||
config = BertConfig(32000) | |||
model = BertModel(config) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
@@ -18,4 +19,4 @@ class TestBert(unittest.TestCase): | |||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | |||
for layer in all_encoder_layers: | |||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | |||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) | |||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) |
@@ -79,7 +79,7 @@ class TestTutorial(unittest.TestCase): | |||
train_data.rename_field('label', 'label_seq') | |||
test_data.rename_field('label', 'label_seq') | |||
loss = CrossEntropyLoss(pred="output", target="label_seq") | |||
loss = CrossEntropyLoss(target="label_seq") | |||
metric = AccuracyMetric(target="label_seq") | |||
# 实例化Trainer,传入模型和数据,进行训练 | |||
@@ -91,7 +91,7 @@ class TestTutorial(unittest.TestCase): | |||
# 用train_data训练,在test_data验证 | |||
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | |||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
loss=CrossEntropyLoss(target="label_seq"), | |||
metrics=AccuracyMetric(target="label_seq"), | |||
save_path=None, | |||
batch_size=32, | |||