diff --git a/README.md b/README.md index 9d949482..a5ce3c64 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,14 @@ ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg) [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest) -fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个命名实体识别(NER)、中文分词或文本分类任务; 也可以使用他构建许多复杂的网络模型,进行科研。它具有如下的特性: +fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner/)、POS-Tagging等)、中文分词、文本分类、[Matching](reproduction/matching/)、指代消解、摘要等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性: -- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码。 -- 各种方便的NLP工具,例如预处理embedding加载; 中间数据cache等; -- 详尽的中文文档以供查阅; +- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码; +- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等; +- 各种方便的NLP工具,例如预处理embedding加载(包括EMLo和BERT); 中间数据cache等; +- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、教程以供查阅; - 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等; -- 封装CNNText,Biaffine等模型可供直接使用; +- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用; [详细链接](reproduction/) - 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。 @@ -20,13 +21,14 @@ fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地 fastNLP 依赖如下包: -+ numpy -+ torch>=0.4.0 -+ tqdm -+ nltk ++ numpy>=1.14.2 ++ torch>=1.0.0 ++ tqdm>=4.28.1 ++ nltk>=3.4.1 ++ requests -其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 PyTorch 官网 。 -在依赖包安装完成的情况,您可以在命令行执行如下指令完成安装 +其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。 +在依赖包安装完成后,您可以在命令行执行如下指令完成安装 ```shell pip install fastNLP @@ -77,8 +79,8 @@ fastNLP 在 modules 模块中内置了三种模块的诸多组件,可以帮助 fastNLP 为不同的 NLP 任务实现了许多完整的模型,它们都经过了训练和测试。 你可以在以下两个地方查看相关信息 -- [介绍](reproduction/) -- [源码](fastNLP/models/) +- [模型介绍](reproduction/) +- [模型源码](fastNLP/models/) ## 项目结构 @@ -93,7 +95,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下: fastNLP.core - 实现了核心功能,包括数据处理组件、训练器、测速器等 + 实现了核心功能,包括数据处理组件、训练器、测试器等 fastNLP.models diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py new file mode 100644 index 00000000..4a7757d3 --- /dev/null +++ b/fastNLP/core/_parallel_utils.py @@ -0,0 +1,88 @@ + +import threading +import torch +from torch.nn.parallel.parallel_apply import get_a_var + +from torch.nn.parallel.scatter_gather import scatter_kwargs, gather +from torch.nn.parallel.replicate import replicate + + +def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): + r"""Applies each `module` in :attr:`modules` in parallel on arguments + contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) + on each of :attr:`devices`. + + :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and + :attr:`devices` (if given) should all have same length. Moreover, each + element of :attr:`inputs` can either be a single object as the only argument + to a module, or a collection of positional arguments. + """ + assert len(modules) == len(inputs) + if kwargs_tup is not None: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = ({},) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + + lock = threading.Lock() + results = {} + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input, kwargs, device=None): + torch.set_grad_enabled(grad_enabled) + if device is None: + device = get_a_var(input).get_device() + try: + with torch.cuda.device(device): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + output = getattr(module, func_name)(*input, **kwargs) + with lock: + results[i] = output + except Exception as e: + with lock: + results[i] = e + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, + args=(i, module, input, kwargs, device)) + for i, (module, input, kwargs, device) in + enumerate(zip(modules, inputs, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, Exception): + raise output + outputs.append(output) + return outputs + + +def _data_parallel_wrapper(func_name, device_ids, output_device): + """ + 这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 + + :param str, func_name: 对network中的这个函数进行多卡运行 + :param device_ids: nn.DataParallel中的device_ids + :param output_device: nn.DataParallel中的output_device + :return: + """ + def wrapper(network, *inputs, **kwargs): + inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) + if len(device_ids) == 1: + return getattr(network, func_name)(*inputs[0], **kwargs[0]) + replicas = replicate(network, device_ids[:len(inputs)]) + outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) + return gather(outputs, output_device) + return wrapper diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 5dfd889b..0b1890f8 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -113,7 +113,7 @@ class Callback(object): @property def n_steps(self): - """Trainer一共会运行多少步""" + """Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数""" return self._trainer.n_steps @property @@ -181,7 +181,7 @@ class Callback(object): :param dict batch_x: DataSet中被设置为input的field的batch。 :param dict batch_y: DataSet中被设置为target的field的batch。 :param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 - 情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。 + 情况下可以帮助定位是哪个Sample导致了错误。仅当num_workers=0时有效。 :return: """ pass @@ -399,10 +399,11 @@ class GradientClipCallback(Callback): self.clip_value = clip_value def on_backward_end(self): - if self.parameters is None: - self.clip_fun(self.model.parameters(), self.clip_value) - else: - self.clip_fun(self.parameters, self.clip_value) + if self.step%self.update_every==0: + if self.parameters is None: + self.clip_fun(self.model.parameters(), self.clip_value) + else: + self.clip_fun(self.parameters, self.clip_value) class EarlyStopCallback(Callback): diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 526bf37a..14aacef0 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -20,6 +20,7 @@ from collections import defaultdict import torch import torch.nn.functional as F +from ..core.const import Const from .utils import _CheckError from .utils import _CheckRes from .utils import _build_args @@ -28,6 +29,7 @@ from .utils import _check_function_or_method from .utils import _get_func_signature from .utils import seq_len_to_mask + class LossBase(object): """ 所有loss的基类。如果想了解其中的原理,请查看源码。 @@ -95,22 +97,7 @@ class LossBase(object): # if func_spect.varargs: # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " # f"positional argument.).") - - def _fast_param_map(self, pred_dict, target_dict): - """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. - such as pred_dict has one element, target_dict has one element - :param pred_dict: - :param target_dict: - :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. - """ - fast_param = {} - if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: - fast_param['pred'] = list(pred_dict.values())[0] - fast_param['target'] = list(target_dict.values())[0] - return fast_param - return fast_param - def __call__(self, pred_dict, target_dict, check=False): """ :param dict pred_dict: 模型的forward函数返回的dict @@ -118,11 +105,7 @@ class LossBase(object): :param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 :return: """ - fast_param = self._fast_param_map(pred_dict, target_dict) - if fast_param: - loss = self.get_loss(**fast_param) - return loss - + if not self._checked: # 1. check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.get_loss) @@ -212,7 +195,6 @@ class LossFunc(LossBase): if not isinstance(key_map, dict): raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") self._init_param_map(key_map, **kwargs) - class CrossEntropyLoss(LossBase): @@ -226,6 +208,7 @@ class CrossEntropyLoss(LossBase): :param seq_len: 句子的长度, 长度之外的token不会计算loss。。 :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 传入seq_len. + :param str reduction: 支持'mean','sum'和'none'. Example:: @@ -233,21 +216,25 @@ class CrossEntropyLoss(LossBase): """ - def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): + def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'): super(CrossEntropyLoss, self).__init__() self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.padding_idx = padding_idx + assert reduction in ('mean', 'sum', 'none') + self.reduction = reduction def get_loss(self, pred, target, seq_len=None): - if pred.dim()>2: - pred = pred.view(-1, pred.size(-1)) - target = target.view(-1) + if pred.dim() > 2: + if pred.size(1) != target.size(1): + pred = pred.transpose(1, 2) + pred = pred.reshape(-1, pred.size(-1)) + target = target.reshape(-1) if seq_len is not None: - mask = seq_len_to_mask(seq_len).view(-1).eq(0) + mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, - ignore_index=self.padding_idx) + ignore_index=self.padding_idx, reduction=self.reduction) class L1Loss(LossBase): @@ -258,15 +245,18 @@ class L1Loss(LossBase): :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target` + :param str reduction: 支持'mean','sum'和'none'. """ - def __init__(self, pred=None, target=None): + def __init__(self, pred=None, target=None, reduction='mean'): super(L1Loss, self).__init__() self._init_param_map(pred=pred, target=target) + assert reduction in ('mean', 'sum', 'none') + self.reduction = reduction def get_loss(self, pred, target): - return F.l1_loss(input=pred, target=target) + return F.l1_loss(input=pred, target=target, reduction=self.reduction) class BCELoss(LossBase): @@ -277,14 +267,17 @@ class BCELoss(LossBase): :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` + :param str reduction: 支持'mean','sum'和'none'. """ - def __init__(self, pred=None, target=None): + def __init__(self, pred=None, target=None, reduction='mean'): super(BCELoss, self).__init__() self._init_param_map(pred=pred, target=target) + assert reduction in ('mean', 'sum', 'none') + self.reduction = reduction def get_loss(self, pred, target): - return F.binary_cross_entropy(input=pred, target=target) + return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction) class NLLLoss(LossBase): @@ -295,14 +288,20 @@ class NLLLoss(LossBase): :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` + :param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 + 传入seq_len. + :param str reduction: 支持'mean','sum'和'none'. """ - def __init__(self, pred=None, target=None): + def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): super(NLLLoss, self).__init__() self._init_param_map(pred=pred, target=target) + assert reduction in ('mean', 'sum', 'none') + self.reduction = reduction + self.ignore_idx = ignore_idx def get_loss(self, pred, target): - return F.nll_loss(input=pred, target=target) + return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction) class LossInForward(LossBase): @@ -314,7 +313,7 @@ class LossInForward(LossBase): :param str loss_key: 在forward函数中loss的键名,默认为loss """ - def __init__(self, loss_key='loss'): + def __init__(self, loss_key=Const.LOSS): super().__init__() if not isinstance(loss_key, str): raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index 0849b35d..1fe035bf 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -9,6 +9,9 @@ __all__ = [ ] import torch +import math +import torch +from torch.optim.optimizer import Optimizer as TorchOptimizer class Optimizer(object): @@ -97,3 +100,110 @@ class Adam(Optimizer): return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings) else: return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings) + + +class AdamW(TorchOptimizer): + r"""对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入 + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index ce016bb6..2d6a7380 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device class Predictor(object): """ - An interface for predicting outputs based on trained models. + 一个根据训练模型预测输出的预测器(Predictor) - It does not care about evaluations of the model, which is different from Tester. - This is a high-level model wrapper to be called by FastNLP. - This class does not share any operations with Trainer and Tester. - Currently, Predictor does not support GPU. + 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 + 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 + + :param torch.nn.Module network: 用来完成预测任务的模型 """ def __init__(self, network): @@ -30,18 +30,19 @@ class Predictor(object): self.batch_size = 1 self.batch_output = [] - def predict(self, data, seq_len_field_name=None): - """Perform inference using the trained model. + def predict(self, data: DataSet, seq_len_field_name=None): + """用已经训练好的模型进行inference. - :param data: a DataSet object. - :param str seq_len_field_name: field name indicating sequence lengths - :return: list of batch outputs + :param fastNLP.DataSet data: 待预测的数据集 + :param str seq_len_field_name: 表示序列长度信息的field名字 + :return: dict dict里面的内容为模型预测的结果 """ if not isinstance(data, DataSet): raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) + prev_training = self.network.training self.network.eval() network_device = _get_model_device(self.network) batch_output = defaultdict(list) @@ -74,4 +75,5 @@ class Predictor(object): else: batch_output[key].append(value) + self.network.train(prev_training) return batch_output diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 4cdd4ffb..7048d0ae 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -32,8 +32,6 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation """ -import warnings - import torch import torch.nn as nn @@ -48,6 +46,8 @@ from .utils import _move_dict_value_to_device from .utils import _get_func_signature from .utils import _get_model_device from .utils import _move_model_to_device +from ._parallel_utils import _data_parallel_wrapper +from functools import partial __all__ = [ "Tester" @@ -104,26 +104,27 @@ class Tester(object): self.data_iterator = data else: raise TypeError("data type {} not support".format(type(data))) - - # 如果是DataParallel将没有办法使用predict方法 - if isinstance(self._model, nn.DataParallel): - if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): - warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," - " while DataParallel has no predict() function.") - self._model = self._model.module - + # check predict - if hasattr(self._model, 'predict'): - self._predict_func = self._model.predict - if not callable(self._predict_func): - _model_name = model.__class__.__name__ - raise TypeError(f"`{_model_name}.predict` must be callable to be used " - f"for evaluation, not `{type(self._predict_func)}`.") + if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ + (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and + callable(self._model.module.predict)): + if isinstance(self._model, nn.DataParallel): + self._predict_func_wrapper = partial(_data_parallel_wrapper('predict', + self._model.device_ids, + self._model.output_device), + network=self._model.module) + self._predict_func = self._model.module.predict + else: + self._predict_func = self._model.predict + self._predict_func_wrapper = self._model.predict else: - if isinstance(model, nn.DataParallel): + if isinstance(self._model, nn.DataParallel): + self._predict_func_wrapper = self._model.forward self._predict_func = self._model.module.forward else: self._predict_func = self._model.forward + self._predict_func_wrapper = self._model.forward def test(self): """开始进行验证,并返回验证结果。 @@ -180,7 +181,7 @@ class Tester(object): def _data_forward(self, func, x): """A forward pass of the model. """ x = _build_args(func, **x) - y = func(**x) + y = self._predict_func_wrapper(**x) return y def _format_eval_results(self, results): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 6edeb4a0..eabda99c 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -454,7 +454,7 @@ class Trainer(object): if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, - metric_key=metric_key, check_level=check_code_level, + metric_key=self.metric_key, check_level=check_code_level, batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 self.model = _move_model_to_device(model, device=device) @@ -473,7 +473,7 @@ class Trainer(object): self.best_dev_step = None self.best_dev_perf = None self.n_steps = (len(self.train_data) // self.batch_size + int( - len(self.train_data) % self.batch_size != 0)) * self.n_epochs + len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index d26df966..490f9f8f 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -17,7 +17,6 @@ import numpy as np import torch import torch.nn as nn - _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', 'varargs']) @@ -278,6 +277,7 @@ def _move_model_to_device(model, device): return model + def _get_model_device(model): """ 传入一个nn.Module的模型,获取它所在的device diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index 28f466a8..05d75f43 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -11,21 +11,35 @@ """ __all__ = [ 'EmbedLoader', - + + 'DataInfo', 'DataSetLoader', + 'CSVLoader', 'JsonLoader', 'ConllLoader', - 'SNLILoader', - 'SSTLoader', 'PeopleDailyCorpusLoader', 'Conll2003Loader', 'ModelLoader', 'ModelSaver', + + 'SSTLoader', + + 'MatchingLoader', + 'SNLILoader', + 'MNLILoader', + 'QNLILoader', + 'QuoraLoader', + 'RTELoader', ] from .embed_loader import EmbedLoader -from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \ - SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader +from .base_loader import DataInfo, DataSetLoader +from .dataset_loader import CSVLoader, JsonLoader, ConllLoader, \ + PeopleDailyCorpusLoader, Conll2003Loader from .model_io import ModelLoader, ModelSaver + +from .data_loader.sst import SSTLoader +from .data_loader.matching import MatchingLoader, SNLILoader, \ + MNLILoader, QNLILoader, QuoraLoader, RTELoader diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index 465fb7e8..8cff1da1 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -10,6 +10,7 @@ from typing import Union, Dict import os from ..core.dataset import DataSet + class BaseLoader(object): """ 各个 Loader 的基类,提供了 API 的参考。 @@ -55,8 +56,6 @@ class BaseLoader(object): return obj - - def _download_from_url(url, path): try: from tqdm.auto import tqdm @@ -115,13 +114,11 @@ class DataInfo: 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict - :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict """ - def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): + def __init__(self, vocabs: dict = None, datasets: dict = None): self.vocabs = vocabs or {} - self.embeddings = embeddings or {} self.datasets = datasets or {} def __repr__(self): @@ -133,6 +130,7 @@ class DataInfo: _str += '\t{} has {} entries.\n'.format(name, len(vocab)) return _str + class DataSetLoader: """ 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` @@ -213,7 +211,6 @@ class DataSetLoader: 返回的 :class:`DataInfo` 对象有如下属性: - vocabs: 由从数据集中获取的词表组成的字典,每个词表 - - embeddings: (可选) 数据集对应的词嵌入 - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` :param paths: 原始数据读取的路径 diff --git a/fastNLP/io/data_loader/__init__.py b/fastNLP/io/data_loader/__init__.py new file mode 100644 index 00000000..6f4dd973 --- /dev/null +++ b/fastNLP/io/data_loader/__init__.py @@ -0,0 +1,19 @@ +""" +用于读数据集的模块, 具体包括: + +这些模块的使用方法如下: +""" +__all__ = [ + 'SSTLoader', + + 'MatchingLoader', + 'SNLILoader', + 'MNLILoader', + 'QNLILoader', + 'QuoraLoader', + 'RTELoader', +] + +from .sst import SSTLoader +from .matching import MatchingLoader, SNLILoader, \ + MNLILoader, QNLILoader, QuoraLoader, RTELoader diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py new file mode 100644 index 00000000..3d131bcb --- /dev/null +++ b/fastNLP/io/data_loader/matching.py @@ -0,0 +1,430 @@ +import os + +from typing import Union, Dict + +from ...core.const import Const +from ...core.vocabulary import Vocabulary +from ..base_loader import DataInfo, DataSetLoader +from ..dataset_loader import JsonLoader, CSVLoader +from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR +from ...modules.encoder._bert import BertTokenizer + + +class MatchingLoader(DataSetLoader): + """ + 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` + + 读取Matching任务的数据集 + + :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 + """ + + def __init__(self, paths: dict=None): + self.paths = paths + + def _load(self, path): + """ + :param str path: 待读取数据集的路径名 + :return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 + 的原始字符串文本,第三个为标签 + """ + raise NotImplementedError + + def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, + to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, + cut_text: int = None, get_index=True, auto_pad_length: int=None, + auto_pad_token: str='', set_input: Union[list, str, bool]=True, + set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: + """ + :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, + 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 + 对应的全路径文件名。 + :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 + 这个数据集的名字,如果不定义则默认为train。 + :param bool to_lower: 是否将文本自动转为小写。默认值为False。 + :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : + 提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 + attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len + :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 + :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 + :param bool get_index: 是否需要根据词表将文本转为index + :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad + :param str auto_pad_token: 自动pad的内容 + :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False + 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, + 于此同时其他field不会被设置为input。默认值为True。 + :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 + :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个。 + 如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 + 传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. + :return: + """ + if isinstance(set_input, str): + set_input = [set_input] + if isinstance(set_target, str): + set_target = [set_target] + if isinstance(set_input, bool): + auto_set_input = set_input + else: + auto_set_input = False + if isinstance(set_target, bool): + auto_set_target = set_target + else: + auto_set_target = False + if isinstance(paths, str): + if os.path.isdir(paths): + path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} + else: + path = {dataset_name if dataset_name is not None else 'train': paths} + else: + path = paths + + data_info = DataInfo() + for data_name in path.keys(): + data_info.datasets[data_name] = self._load(path[data_name]) + + for data_name, data_set in data_info.datasets.items(): + if auto_set_input: + data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) + if auto_set_target: + if Const.TARGET in data_set.get_field_names(): + data_set.set_target(Const.TARGET) + + if to_lower: + for data_name, data_set in data_info.datasets.items(): + data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), + is_input=auto_set_input) + data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), + is_input=auto_set_input) + + if bert_tokenizer is not None: + if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: + PRETRAIN_URL = _get_base_url('bert') + model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] + model_url = PRETRAIN_URL + model_name + model_dir = cached_path(model_url) + # 检查是否存在 + elif os.path.isdir(bert_tokenizer): + model_dir = bert_tokenizer + else: + raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") + + words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') + with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + words_vocab.add_word_lst(lines) + words_vocab.build_vocab() + + tokenizer = BertTokenizer.from_pretrained(model_dir) + + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if Const.INPUT in fields: + data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, + is_input=auto_set_input) + + if isinstance(concat, bool): + concat = 'default' if concat else None + if concat is not None: + if isinstance(concat, str): + CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], + 'default': ['', '', '', '']} + if concat.lower() in CONCAT_MAP: + concat = CONCAT_MAP[concat] + else: + concat = 4 * [concat] + assert len(concat) == 4, \ + f'Please choose a list with 4 symbols which at the beginning of first sentence ' \ + f'the end of first sentence, the begin of second sentence, and the end of second' \ + f'sentence. Your input is {concat}' + + for data_name, data_set in data_info.datasets.items(): + data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + + x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) + data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, + is_input=auto_set_input) + + if seq_len_type is not None: + if seq_len_type == 'seq_len': # + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if Const.INPUT in fields: + data_set.apply(lambda x: len(x[fields]), + new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), + is_input=auto_set_input) + elif seq_len_type == 'mask': + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if Const.INPUT in fields: + data_set.apply(lambda x: [1] * len(x[fields]), + new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), + is_input=auto_set_input) + elif seq_len_type == 'bert': + for data_name, data_set in data_info.datasets.items(): + if Const.INPUT not in data_set.get_field_names(): + raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' + f'got {data_set.get_field_names()}') + data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), + new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) + data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), + new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) + + if auto_pad_length is not None: + cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) + + if cut_text is not None: + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): + data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, + is_input=auto_set_input) + + data_set_list = [d for n, d in data_info.datasets.items()] + assert len(data_set_list) > 0, f'There are NO data sets in data info!' + + if bert_tokenizer is None: + words_vocab = Vocabulary(padding=auto_pad_token) + words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], + field_name=[n for n in data_set_list[0].get_field_names() + if (Const.INPUT in n)], + no_create_entry_dataset=[d for n, d in data_info.datasets.items() + if 'train' not in n]) + target_vocab = Vocabulary(padding=None, unknown=None) + target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], + field_name=Const.TARGET) + data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} + + if get_index: + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if Const.INPUT in fields: + data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, + is_input=auto_set_input) + + if Const.TARGET in data_set.get_field_names(): + data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, + is_input=auto_set_input, is_target=auto_set_target) + + if auto_pad_length is not None: + if seq_len_type == 'seq_len': + raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' + f'so the seq_len_type cannot be `{seq_len_type}`!') + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if Const.INPUT in fields: + data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * + (auto_pad_length - len(x[fields])), new_field_name=fields, + is_input=auto_set_input) + elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): + data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), + new_field_name=fields, is_input=auto_set_input) + + for data_name, data_set in data_info.datasets.items(): + if isinstance(set_input, list): + data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) + if isinstance(set_target, list): + data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) + + return data_info + + +class SNLILoader(MatchingLoader, JsonLoader): + """ + 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` + + 读取SNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + """ + + def __init__(self, paths: dict=None): + fields = { + 'sentence1_binary_parse': Const.INPUTS(0), + 'sentence2_binary_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, + } + paths = paths if paths is not None else { + 'train': 'snli_1.0_train.jsonl', + 'dev': 'snli_1.0_dev.jsonl', + 'test': 'snli_1.0_test.jsonl'} + MatchingLoader.__init__(self, paths=paths) + JsonLoader.__init__(self, fields=fields) + + def _load(self, path): + ds = JsonLoader._load(self, path) + + parentheses_table = str.maketrans({'(': None, ')': None}) + + ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(0)) + ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(1)) + ds.drop(lambda x: x[Const.TARGET] == '-') + return ds + + +class RTELoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` + + 读取RTE数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + 'test': 'test.tsv' # test set has not label + } + MatchingLoader.__init__(self, paths=paths) + self.fields = { + 'sentence1': Const.INPUTS(0), + 'sentence2': Const.INPUTS(1), + 'label': Const.TARGET, + } + CSVLoader.__init__(self, sep='\t') + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + if v in ds.get_field_names(): + ds.rename_field(k, v) + for fields in ds.get_all_fields(): + if Const.INPUT in fields: + ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) + + return ds + + +class QNLILoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` + + 读取QNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + 'test': 'test.tsv' # test set has not label + } + MatchingLoader.__init__(self, paths=paths) + self.fields = { + 'question': Const.INPUTS(0), + 'sentence': Const.INPUTS(1), + 'label': Const.TARGET, + } + CSVLoader.__init__(self, sep='\t') + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + if v in ds.get_field_names(): + ds.rename_field(k, v) + for fields in ds.get_all_fields(): + if Const.INPUT in fields: + ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) + + return ds + + +class MNLILoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` + + 读取MNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev_matched': 'dev_matched.tsv', + 'dev_mismatched': 'dev_mismatched.tsv', + 'test_matched': 'test_matched.tsv', + 'test_mismatched': 'test_mismatched.tsv', + # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', + # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', + + # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) + } + MatchingLoader.__init__(self, paths=paths) + CSVLoader.__init__(self, sep='\t') + self.fields = { + 'sentence1_binary_parse': Const.INPUTS(0), + 'sentence2_binary_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, + } + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + if k in ds.get_field_names(): + ds.rename_field(k, v) + + if Const.TARGET in ds.get_field_names(): + if ds[0][Const.TARGET] == 'hidden': + ds.delete_field(Const.TARGET) + + parentheses_table = str.maketrans({'(': None, ')': None}) + + ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(0)) + ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(1)) + if Const.TARGET in ds.get_field_names(): + ds.drop(lambda x: x[Const.TARGET] == '-') + return ds + + +class QuoraLoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` + + 读取MNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + 'test': 'test.tsv', + } + MatchingLoader.__init__(self, paths=paths) + CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) + + def _load(self, path): + ds = CSVLoader._load(self, path) + return ds diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index 1e1b8bef..021a79b7 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -1,11 +1,14 @@ from typing import Iterable from nltk import Tree +import spacy from ..base_loader import DataInfo, DataSetLoader from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet from ...core.instance import Instance from ..embed_loader import EmbeddingOption, EmbedLoader +spacy.prefer_gpu() +sptk = spacy.load('en') class SSTLoader(DataSetLoader): URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' @@ -56,8 +59,8 @@ class SSTLoader(DataSetLoader): def _get_one(data, subtree): tree = Tree.fromstring(data) if subtree: - return [(t.leaves(), t.label()) for t in tree.subtrees()] - return [(tree.leaves(), tree.label())] + return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] + return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())] def process(self, paths, diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 558fe20e..2881e6e9 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -16,8 +16,6 @@ __all__ = [ 'CSVLoader', 'JsonLoader', 'ConllLoader', - 'SNLILoader', - 'SSTLoader', 'PeopleDailyCorpusLoader', 'Conll2003Loader', ] @@ -30,7 +28,6 @@ from ..core.dataset import DataSet from ..core.instance import Instance from .file_reader import _read_csv, _read_json, _read_conll from .base_loader import DataSetLoader, DataInfo -from .data_loader.sst import SSTLoader from ..core.const import Const from ..modules.encoder._bert import BertTokenizer @@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): else: instance = Instance(words=sent_words) data_set.append(instance) - data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") + data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) return data_set @@ -249,42 +246,6 @@ class JsonLoader(DataSetLoader): return ds -class SNLILoader(JsonLoader): - """ - 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` - - 读取SNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip - """ - - def __init__(self): - fields = { - 'sentence1_parse': Const.INPUTS(0), - 'sentence2_parse': Const.INPUTS(1), - 'gold_label': Const.TARGET, - } - super(SNLILoader, self).__init__(fields=fields) - - def _load(self, path): - ds = super(SNLILoader, self)._load(path) - - def parse_tree(x): - t = Tree.fromstring(x) - return t.leaves() - - ds.apply(lambda ins: parse_tree( - ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) - ds.apply(lambda ins: parse_tree( - ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) - ds.drop(lambda x: x[Const.TARGET] == '-') - return ds - - class CSVLoader(DataSetLoader): """ 别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 34b5d7c0..0ae0a319 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -104,7 +104,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): except Exception as e: if dropna: continue - raise ValueError('invalid instance at line: {}'.format(line_idx)) + raise ValueError('invalid instance ends at line: {}'.format(line_idx)) elif line.startswith('#'): continue else: @@ -117,5 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): except Exception as e: if dropna: return - print('invalid instance at line: {}'.format(line_idx)) + print('invalid instance ends at line: {}'.format(line_idx)) raise e diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index 4846c7fa..fb186ce4 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -8,35 +8,7 @@ from torch import nn from .base_model import BaseModel from ..core.const import Const from ..modules.encoder import BertModel - - -class BertConfig: - - def __init__( - self, - vocab_size=30522, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02 - ): - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range +from ..modules.encoder._bert import BertConfig class BertForSequenceClassification(BaseModel): @@ -84,11 +56,17 @@ class BertForSequenceClassification(BaseModel): self.bert = BertModel.from_pretrained(bert_dir) else: if config is None: - config = BertConfig() - self.bert = BertModel(**config.__dict__) + config = BertConfig(30522) + self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) + @classmethod + def from_pretrained(cls, num_labels, pretrained_model_dir): + config = BertConfig(pretrained_model_dir) + model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) + return model + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) @@ -151,11 +129,17 @@ class BertForMultipleChoice(BaseModel): self.bert = BertModel.from_pretrained(bert_dir) else: if config is None: - config = BertConfig() - self.bert = BertModel(**config.__dict__) + config = BertConfig(30522) + self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) + @classmethod + def from_pretrained(cls, num_choices, pretrained_model_dir): + config = BertConfig(pretrained_model_dir) + model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) + return model + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) @@ -224,11 +208,17 @@ class BertForTokenClassification(BaseModel): self.bert = BertModel.from_pretrained(bert_dir) else: if config is None: - config = BertConfig() - self.bert = BertModel(**config.__dict__) + config = BertConfig(30522) + self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) + @classmethod + def from_pretrained(cls, num_labels, pretrained_model_dir): + config = BertConfig(pretrained_model_dir) + model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) + return model + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) sequence_output = self.dropout(sequence_output) @@ -302,12 +292,18 @@ class BertForQuestionAnswering(BaseModel): self.bert = BertModel.from_pretrained(bert_dir) else: if config is None: - config = BertConfig() - self.bert = BertModel(**config.__dict__) + config = BertConfig(30522) + self.bert = BertModel(config) # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version # self.dropout = nn.Dropout(config.hidden_dropout_prob) self.qa_outputs = nn.Linear(config.hidden_size, 2) + @classmethod + def from_pretrained(cls, pretrained_model_dir): + config = BertConfig(pretrained_model_dir) + model = cls(config=config, bert_dir=pretrained_model_dir) + return model + def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) logits = self.qa_outputs(sequence_output) diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py index 4c944a54..1aba5a8c 100644 --- a/fastNLP/models/star_transformer.py +++ b/fastNLP/models/star_transformer.py @@ -46,7 +46,7 @@ class StarTransEnc(nn.Module): super(StarTransEnc, self).__init__() self.embedding = get_embeddings(init_embed) emb_dim = self.embedding.embedding_dim - self.emb_fc = nn.Linear(emb_dim, hidden_size) + #self.emb_fc = nn.Linear(emb_dim, hidden_size) self.emb_drop = nn.Dropout(emb_dropout) self.encoder = StarTransformer(hidden_size=hidden_size, num_layers=num_layers, @@ -65,7 +65,7 @@ class StarTransEnc(nn.Module): [batch, hidden] 全局 relay 节点, 详见论文 """ x = self.embedding(x) - x = self.emb_fc(self.emb_drop(x)) + #x = self.emb_fc(self.emb_drop(x)) nodes, relay = self.encoder(x, mask) return nodes, relay @@ -205,7 +205,7 @@ class STSeqCls(nn.Module): max_len=max_len, emb_dropout=emb_dropout, dropout=dropout) - self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) + self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout) def forward(self, words, seq_len): """ diff --git a/fastNLP/modules/decoder/mlp.py b/fastNLP/modules/decoder/mlp.py index c1579224..418b3a77 100644 --- a/fastNLP/modules/decoder/mlp.py +++ b/fastNLP/modules/decoder/mlp.py @@ -15,7 +15,8 @@ class MLP(nn.Module): 多层感知器 :param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 - :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和sigmoid,默认值为relu + :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 + sigmoid,默认值为relu :param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 :param str initial_method: 参数初始化方式 :param float dropout: dropout概率,默认值为0 diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py index 254917e5..61a5d7d1 100644 --- a/fastNLP/modules/encoder/_bert.py +++ b/fastNLP/modules/encoder/_bert.py @@ -2,7 +2,8 @@ """ -这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码 +这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 + 有用,也请引用一下他们。 """ @@ -11,7 +12,6 @@ from ...core.vocabulary import Vocabulary import collections import unicodedata -from ...io.file_utils import _get_base_url, cached_path import numpy as np from itertools import chain import copy @@ -22,9 +22,106 @@ import os import torch from torch import nn import glob +import sys CONFIG_FILE = 'bert_config.json' -MODEL_WEIGHTS = 'pytorch_model.bin' + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_eps: The epsilon used by LayerNorm. + """ + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) def gelu(x): @@ -40,6 +137,8 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class BertLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ super(BertLayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) @@ -53,16 +152,18 @@ class BertLayerNorm(nn.Module): class BertEmbeddings(nn.Module): - def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None): seq_length = input_ids.size(1) @@ -82,21 +183,21 @@ class BertEmbeddings(nn.Module): class BertSelfAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + def __init__(self, config): super(BertSelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: + if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (hidden_size, num_attention_heads)) - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = nn.Dropout(attention_probs_dropout_prob) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -133,11 +234,11 @@ class BertSelfAttention(nn.Module): class BertSelfOutput(nn.Module): - def __init__(self, hidden_size, hidden_dropout_prob): + def __init__(self, config): super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) @@ -147,10 +248,10 @@ class BertSelfOutput(nn.Module): class BertAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + def __init__(self, config): super(BertAttention, self).__init__() - self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) - self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) def forward(self, input_tensor, attention_mask): self_output = self.self(input_tensor, attention_mask) @@ -159,11 +260,13 @@ class BertAttention(nn.Module): class BertIntermediate(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_act): + def __init__(self, config): super(BertIntermediate, self).__init__() - self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = ACT2FN[hidden_act] \ - if isinstance(hidden_act, str) else hidden_act + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -172,11 +275,11 @@ class BertIntermediate(nn.Module): class BertOutput(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): + def __init__(self, config): super(BertOutput, self).__init__() - self.dense = nn.Linear(intermediate_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) @@ -186,13 +289,11 @@ class BertOutput(nn.Module): class BertLayer(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act): + def __init__(self, config): super(BertLayer, self).__init__() - self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob) - self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) - self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) def forward(self, hidden_states, attention_mask): attention_output = self.attention(hidden_states, attention_mask) @@ -202,13 +303,10 @@ class BertLayer(nn.Module): class BertEncoder(nn.Module): - def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob, - intermediate_size, hidden_act): + def __init__(self, config): super(BertEncoder, self).__init__() - layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + layer = BertLayer(config) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): all_encoder_layers = [] @@ -222,9 +320,9 @@ class BertEncoder(nn.Module): class BertPooler(nn.Module): - def __init__(self, hidden_size): + def __init__(self, config): super(BertPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): @@ -242,13 +340,19 @@ class BertModel(nn.Module): 如果你想使用预训练好的权重矩阵,请在以下网址下载. sources:: - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", + 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin" 用预训练权重矩阵来建立BERT模型:: @@ -272,34 +376,30 @@ class BertModel(nn.Module): :param int initializer_range: 初始化权重范围,默认值为0.02 """ - def __init__(self, vocab_size=30522, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02): + def __init__(self, config, *inputs, **kwargs): super(BertModel, self).__init__() - self.hidden_size = hidden_size - self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, - type_vocab_size, hidden_dropout_prob) - self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, - attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, - hidden_act) - self.pooler = BertPooler(hidden_size) - self.initializer_range = initializer_range - + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + super(BertModel, self).__init__() + self.config = config + self.hidden_size = self.config.hidden_size + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) self.apply(self.init_bert_weights) def init_bert_weights(self, module): + """ Initialize the weights. + """ if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.initializer_range) + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, BertLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -338,14 +438,19 @@ class BertModel(nn.Module): return encoded_layers, pooled_output @classmethod - def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs): + state_dict = kwargs.get('state_dict', None) + kwargs.pop('state_dict', None) + cache_dir = kwargs.get('cache_dir', None) + kwargs.pop('cache_dir', None) + from_tf = kwargs.get('from_tf', False) + kwargs.pop('from_tf', None) # Load config config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) - config = json.load(open(config_file, "r")) - # config = BertConfig.from_json_file(config_file) + config = BertConfig.from_json_file(config_file) # logger.info("Model config {}".format(config)) # Instantiate model. - model = cls(*inputs, **config, **kwargs) + model = cls(config, *inputs, **kwargs) if state_dict is None: files = glob.glob(os.path.join(pretrained_model_dir, '*.bin')) if len(files)==0: @@ -353,7 +458,7 @@ class BertModel(nn.Module): elif len(files)>1: raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}") weights_path = files[0] - state_dict = torch.load(weights_path) + state_dict = torch.load(weights_path, map_location='cpu') old_keys = [] new_keys = [] @@ -464,6 +569,7 @@ class WordpieceTokenizer(object): output_tokens.extend(sub_tokens) return output_tokens + def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() @@ -594,6 +700,7 @@ class BasicTokenizer(object): output.append(char) return "".join(output) + def _is_whitespace(char): """Checks whether `chars` is a whitespace character.""" # \t, \n, and \r are technically contorl characters but we treat them @@ -840,6 +947,7 @@ class _WordBertModel(nn.Module): word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) attn_masks[i, :len(word_pieces_i)+2].fill_(1) + # TODO 截掉长度超过的部分。 # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index c48cb806..005cfe75 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -202,18 +202,12 @@ class StaticEmbedding(TokenEmbedding): raise ValueError(f"Cannot recognize {model_dir_or_name}.") # 读取embedding - embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, + embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, normalize=normalize) self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], padding_idx=vocab.padding_idx, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False, _weight=embedding) - if vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk - words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) - for word, idx in vocab: - if vocab._is_word_no_create_entry(word) and not hit_flags[idx]: - words_to_words[idx] = vocab.unknown_idx - self.words_to_words = words_to_words self._embed_size = self.embedding.weight.size(1) self.requires_grad = requires_grad @@ -268,10 +262,8 @@ class StaticEmbedding(TokenEmbedding): else: dim = len(parts) - 1 f.seek(0) - matrix = torch.zeros(len(vocab), dim) - if init_method is not None: - init_method(matrix) - hit_flags = np.zeros(len(vocab), dtype=bool) + matrix = {} + found_count = 0 for idx, line in enumerate(f, start_idx): try: parts = line.strip().split() @@ -285,28 +277,49 @@ class StaticEmbedding(TokenEmbedding): if word in vocab: index = vocab.to_index(word) matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) - hit_flags[index] = True + found_count += 1 except Exception as e: if error == 'ignore': warnings.warn("Error occurred at the {} line.".format(idx)) else: print("Error occurred at the {} line.".format(idx)) raise e - found_count = sum(hit_flags) print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) - if init_method is None: - if len(vocab)-found_count>0 and found_count>0: # 有的没找到 - found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()] - mean = found_vecs.mean(dim=0, keepdim=True) - std = found_vecs.std(dim=0, keepdim=True) - unfound_vec_num = np.sum(hit_flags==False) - unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean - matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs + for word, index in vocab: + if index not in matrix and not vocab._is_word_no_create_entry(word): + if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化 + matrix[index] = matrix[vocab.unknown_idx] + else: + matrix[index] = None + + vectors = torch.zeros(len(matrix), dim) + if init_method: + init_method(vectors) + else: + nn.init.uniform_(vectors, -np.sqrt(3/dim), np.sqrt(3/dim)) + + if vocab._no_create_word_length>0: + if vocab.unknown is None: # 创建一个专门的unknown + unknown_idx = len(matrix) + vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous() + else: + unknown_idx = vocab.unknown_idx + words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), + requires_grad=False) + for order, (index, vec) in enumerate(matrix.items()): + if vec is not None: + vectors[order] = vec + words_to_words[index] = order + self.words_to_words = words_to_words + else: + for index, vec in matrix.items(): + if vec is not None: + vectors[index] = vec if normalize: - matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12) + vectors /= (torch.norm(vectors, dim=1, keepdim=True) + 1e-12) - return matrix, hit_flags + return vectors def forward(self, words): """ diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py index 1eec7c13..76b7e922 100644 --- a/fastNLP/modules/encoder/star_transformer.py +++ b/fastNLP/modules/encoder/star_transformer.py @@ -35,11 +35,13 @@ class StarTransformer(nn.Module): self.iters = num_layers self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) + self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) + self.emb_drop = nn.Dropout(dropout) self.ring_att = nn.ModuleList( - [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) + [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) for _ in range(self.iters)]) self.star_att = nn.ModuleList( - [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) + [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) for _ in range(self.iters)]) if max_len is not None: @@ -66,18 +68,19 @@ class StarTransformer(nn.Module): smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 - if self.pos_emb: + if self.pos_emb and False: P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 embs = embs + P - + embs = norm_func(self.emb_drop, embs) nodes = embs relay = embs.mean(2, keepdim=True) ex_mask = mask[:, None, :, None].expand(B, H, L, 1) r_embs = embs.view(B, H, 1, L) for i in range(self.iters): ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) - nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) + nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) + #nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) nodes = nodes.masked_fill_(ex_mask, 0) diff --git a/reproduction/README.md b/reproduction/README.md index bb21c067..b6f61903 100644 --- a/reproduction/README.md +++ b/reproduction/README.md @@ -3,6 +3,8 @@ 复现的模型有: - [Star-Transformer](Star_transformer/) +- [Biaffine](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/biaffine_parser.py#L239) +- [CNNText](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/cnn_text_classification.py#L12) - ... # 任务复现 @@ -11,11 +13,11 @@ ## Matching (自然语言推理/句子匹配) -- still in progress +- [Matching 任务复现](matching) ## Sequence Labeling (序列标注) -- still in progress +- [NER](seqence_labelling/ner) ## Coreference resolution (指代消解) diff --git a/reproduction/Star_transformer/datasets.py b/reproduction/Star_transformer/datasets.py index a9257fd4..1173d1a0 100644 --- a/reproduction/Star_transformer/datasets.py +++ b/reproduction/Star_transformer/datasets.py @@ -2,7 +2,8 @@ import torch import json import os from fastNLP import Vocabulary -from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader +from fastNLP.io.dataset_loader import ConllLoader +from fastNLP.io.data_loader import SSTLoader, SNLILoader from fastNLP.core import Const as C import numpy as np @@ -50,13 +51,15 @@ def load_sst(path, files): for sub in [True, False, False]] ds_list = [loader.load(os.path.join(path, fn)) for fn, loader in zip(files, loaders)] - word_v = Vocabulary(min_freq=2) + word_v = Vocabulary(min_freq=0) tag_v = Vocabulary(unknown=None, padding=None) for ds in ds_list: ds.apply(lambda x: [w.lower() for w in x['words']], new_field_name='words') - ds_list[0].drop(lambda x: len(x['words']) < 3) + #ds_list[0].drop(lambda x: len(x['words']) < 3) update_v(word_v, ds_list[0], 'words') + update_v(word_v, ds_list[1], 'words') + update_v(word_v, ds_list[2], 'words') ds_list[0].apply(lambda x: tag_v.add_word( x['target']), new_field_name=None) @@ -151,7 +154,10 @@ class EmbedLoader: # some words from vocab are missing in pre-trained embedding # we normally sample each dimension vocab_embed = embedding_matrix[np.where(hit_flags)] - sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), + #sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), + # size=(len(vocab) - np.sum(hit_flags), emb_dim)) + sampled_vectors = np.random.uniform(-0.01, 0.01, size=(len(vocab) - np.sum(hit_flags), emb_dim)) + embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors return embedding_matrix diff --git a/reproduction/Star_transformer/run.sh b/reproduction/Star_transformer/run.sh index 0972c662..5cd6954b 100644 --- a/reproduction/Star_transformer/run.sh +++ b/reproduction/Star_transformer/run.sh @@ -1,5 +1,5 @@ #python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 & #python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 & -#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log & +python -u train.py --task cls --ds sst --mode train --gpu 0 --lr 1e-4 --w_decay 5e-5 --lr_decay 1.0 --drop 0.4 --ep 20 --bsz 64 > sst_cls.log & #python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log & -python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & +#python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & diff --git a/reproduction/Star_transformer/train.py b/reproduction/Star_transformer/train.py index 6fb58daf..480748df 100644 --- a/reproduction/Star_transformer/train.py +++ b/reproduction/Star_transformer/train.py @@ -1,4 +1,6 @@ from util import get_argparser, set_gpu, set_rng_seeds, add_model_args +seed = set_rng_seeds(15360) +print('RNG SEED {}'.format(seed)) from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN import torch.nn as nn import torch @@ -7,8 +9,8 @@ import fastNLP as FN from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls from fastNLP.core.const import Const as C import sys -sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') - +#sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') +pre_dir = '/home/ec2-user/fast_data/' g_model_select = { 'pos': STSeqLabel, @@ -17,8 +19,8 @@ g_model_select = { 'nli': STNLICls, } -g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt', - 'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'} +g_emb_file_path = {'en': pre_dir + 'glove.840B.300d.txt', + 'zh': pre_dir + 'cc.zh.300.vec'} g_args = None g_model_cfg = None @@ -53,7 +55,7 @@ def get_conll2012_ner(): def get_sst(): - path = '/remote-home/yfshao/workdir/datasets/SST' + path = pre_dir + 'sst' files = ['train.txt', 'dev.txt', 'test.txt'] return load_sst(path, files) @@ -94,6 +96,7 @@ class MyCallback(FN.core.callback.Callback): nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0) def on_step_end(self): + return warm_steps = 6000 # learning rate warm-up & decay if self.step <= warm_steps: @@ -108,12 +111,11 @@ class MyCallback(FN.core.callback.Callback): def train(): - seed = set_rng_seeds(1234) - print('RNG SEED {}'.format(seed)) print('loading data') ds_list, word_v, tag_v = g_datasets['{}-{}'.format( g_args.ds, g_args.task)]() print(ds_list[0][:2]) + print(len(ds_list[0]), len(ds_list[1]), len(ds_list[2])) embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en') g_model_cfg['num_cls'] = len(tag_v) print(g_model_cfg) @@ -123,11 +125,14 @@ def train(): def init_model(model): for p in model.parameters(): if p.size(0) != len(word_v): - nn.init.normal_(p, 0.0, 0.05) + if len(p.size())<2: + nn.init.constant_(p, 0.0) + else: + nn.init.normal_(p, 0.0, 0.05) init_model(model) train_data = ds_list[0] - dev_data = ds_list[2] - test_data = ds_list[1] + dev_data = ds_list[1] + test_data = ds_list[2] print(tag_v.word2idx) if g_args.task in ['pos', 'ner']: @@ -145,14 +150,26 @@ def train(): } metric_key, metric = metrics[g_args.task] device = 'cuda' if torch.cuda.is_available() else 'cpu' - ex_param = [x for x in model.parameters( - ) if x.requires_grad and x.size(0) != len(word_v)] - optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, - {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] - trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss, - batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric, - metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False, - device=device, callbacks=[MyCallback()]) + + params = [(x,y) for x,y in list(model.named_parameters()) if y.requires_grad and y.size(0) != len(word_v)] + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + print([n for n,p in params]) + optim_cfg = [ + #{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, + {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 1.0*g_args.w_decay}, + {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 0.0*g_args.w_decay} + ] + + print(model) + trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, + loss=loss, metrics=metric, metric_key=metric_key, + optimizer=torch.optim.Adam(optim_cfg), + n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=100, validate_every=1000, + device=device, + use_tqdm=False, prefetch=False, + save_path=g_args.log, + sampler=FN.BucketSampler(100, g_args.bsz, C.INPUT_LEN), + callbacks=[MyCallback()]) trainer.train() tester = FN.Tester(data=test_data, model=model, metrics=metric, @@ -195,12 +212,12 @@ def main(): 'init_embed': (None, 300), 'num_cls': None, 'hidden_size': g_args.hidden, - 'num_layers': 4, + 'num_layers': 2, 'num_head': g_args.nhead, 'head_dim': g_args.hdim, 'max_len': MAX_LEN, - 'cls_hidden_size': 600, - 'emb_dropout': 0.3, + 'cls_hidden_size': 200, + 'emb_dropout': g_args.drop, 'dropout': g_args.drop, } run_select[g_args.mode.lower()]() diff --git a/reproduction/coreference_resolution/__init__.py b/reproduction/coreference_resolution/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/coreference_resolution/data_load/__init__.py b/reproduction/coreference_resolution/data_load/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/coreference_resolution/data_load/cr_loader.py b/reproduction/coreference_resolution/data_load/cr_loader.py new file mode 100644 index 00000000..986afcd5 --- /dev/null +++ b/reproduction/coreference_resolution/data_load/cr_loader.py @@ -0,0 +1,68 @@ +from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance +from fastNLP.io.file_reader import _read_json +from fastNLP.core.vocabulary import Vocabulary +from fastNLP.io.base_loader import DataInfo +from reproduction.coreference_resolution.model.config import Config +import reproduction.coreference_resolution.model.preprocess as preprocess + + +class CRLoader(JsonLoader): + def __init__(self, fields=None, dropna=False): + super().__init__(fields, dropna) + + def _load(self, path): + """ + 加载数据 + :param path: + :return: + """ + dataset = DataSet() + for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): + if self.fields: + ins = {self.fields[k]: v for k, v in d.items()} + else: + ins = d + dataset.append(Instance(**ins)) + return dataset + + def process(self, paths, **kwargs): + data_info = DataInfo() + for name in ['train', 'test', 'dev']: + data_info.datasets[name] = self.load(paths[name]) + + config = Config() + vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences') + vocab.build_vocab() + word2id = vocab.word2idx + + char_dict = preprocess.get_char_dict(config.char_path) + data_info.vocabs = vocab + + genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} + + for name, ds in data_info.datasets.items(): + ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), + config.max_sentences, is_train=name=='train')[0], + new_field_name='doc_np') + ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), + config.max_sentences, is_train=name=='train')[1], + new_field_name='char_index') + ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), + config.max_sentences, is_train=name=='train')[2], + new_field_name='seq_len') + ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'), + new_field_name='speaker_ids_np') + ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre') + + ds.set_ignore_type('clusters') + ds.set_padder('clusters', None) + ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len") + ds.set_target("clusters") + + # train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False) + # train, dev = train_dev.split(343 / (2802 + 343), shuffle=False) + + return data_info + + + diff --git a/reproduction/coreference_resolution/model/__init__.py b/reproduction/coreference_resolution/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/coreference_resolution/model/config.py b/reproduction/coreference_resolution/model/config.py new file mode 100644 index 00000000..6011257b --- /dev/null +++ b/reproduction/coreference_resolution/model/config.py @@ -0,0 +1,54 @@ +class Config(): + def __init__(self): + self.is_training = True + # path + self.glove = 'data/glove.840B.300d.txt.filtered' + self.turian = 'data/turian.50d.txt' + self.train_path = "data/train.english.jsonlines" + self.dev_path = "data/dev.english.jsonlines" + self.test_path = "data/test.english.jsonlines" + self.char_path = "data/char_vocab.english.txt" + + self.cuda = "0" + self.max_word = 1500 + self.epoch = 200 + + # config + # self.use_glove = True + # self.use_turian = True #No + self.use_elmo = False + self.use_CNN = True + self.model_heads = True #Yes + self.use_width = True # Yes + self.use_distance = True #Yes + self.use_metadata = True #Yes + + self.mention_ratio = 0.4 + self.max_sentences = 50 + self.span_width = 10 + self.feature_size = 20 #宽度信息emb的size + self.lr = 0.001 + self.lr_decay = 1e-3 + self.max_antecedents = 100 # 这个参数在mention detection中没有用 + self.atten_hidden_size = 150 + self.mention_hidden_size = 150 + self.sa_hidden_size = 150 + + self.char_emb_size = 8 + self.filter = [3,4,5] + + + # decay = 1e-5 + + def __str__(self): + d = self.__dict__ + out = 'config==============\n' + for i in list(d): + out += i+":" + out += str(d[i])+"\n" + out+="config==============\n" + return out + +if __name__=="__main__": + config = Config() + print(config) diff --git a/reproduction/coreference_resolution/model/metric.py b/reproduction/coreference_resolution/model/metric.py new file mode 100644 index 00000000..2c924660 --- /dev/null +++ b/reproduction/coreference_resolution/model/metric.py @@ -0,0 +1,163 @@ +from fastNLP.core.metrics import MetricBase + +import numpy as np + +from collections import Counter +from sklearn.utils.linear_assignment_ import linear_assignment + +""" +Mostly borrowed from https://github.com/clarkkev/deep-coref/blob/master/evaluation.py +""" + + + +class CRMetric(MetricBase): + def __init__(self): + super().__init__() + self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] + + # TODO 改名为evaluate,输入也 + def evaluate(self, predicted, mention_to_predicted,clusters): + for e in self.evaluators: + e.update(predicted,mention_to_predicted, clusters) + + def get_f1(self): + return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) + + def get_recall(self): + return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) + + def get_precision(self): + return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) + + # TODO 原本的getprf + def get_metric(self,reset=False): + res = {"pre":self.get_precision(), "rec":self.get_recall(), "f":self.get_f1()} + self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] + return res + + + + + + +class Evaluator(): + def __init__(self, metric, beta=1): + self.p_num = 0 + self.p_den = 0 + self.r_num = 0 + self.r_den = 0 + self.metric = metric + self.beta = beta + + def update(self, predicted,mention_to_predicted,gold): + gold = gold[0].tolist() + gold = [tuple(tuple(m) for m in gc) for gc in gold] + mention_to_gold = {} + for gc in gold: + for mention in gc: + mention_to_gold[mention] = gc + + if self.metric == ceafe: + pn, pd, rn, rd = self.metric(predicted, gold) + else: + pn, pd = self.metric(predicted, mention_to_gold) + rn, rd = self.metric(gold, mention_to_predicted) + self.p_num += pn + self.p_den += pd + self.r_num += rn + self.r_den += rd + + def get_f1(self): + return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) + + def get_recall(self): + return 0 if self.r_num == 0 else self.r_num / float(self.r_den) + + def get_precision(self): + return 0 if self.p_num == 0 else self.p_num / float(self.p_den) + + def get_prf(self): + return self.get_precision(), self.get_recall(), self.get_f1() + + def get_counts(self): + return self.p_num, self.p_den, self.r_num, self.r_den + + + +def b_cubed(clusters, mention_to_gold): + num, dem = 0, 0 + + for c in clusters: + if len(c) == 1: + continue + + gold_counts = Counter() + correct = 0 + for m in c: + if m in mention_to_gold: + gold_counts[tuple(mention_to_gold[m])] += 1 + for c2, count in gold_counts.items(): + if len(c2) != 1: + correct += count * count + + num += correct / float(len(c)) + dem += len(c) + + return num, dem + + +def muc(clusters, mention_to_gold): + tp, p = 0, 0 + for c in clusters: + p += len(c) - 1 + tp += len(c) + linked = set() + for m in c: + if m in mention_to_gold: + linked.add(mention_to_gold[m]) + else: + tp -= 1 + tp -= len(linked) + return tp, p + + +def phi4(c1, c2): + return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) + + +def ceafe(clusters, gold_clusters): + clusters = [c for c in clusters if len(c) != 1] + scores = np.zeros((len(gold_clusters), len(clusters))) + for i in range(len(gold_clusters)): + for j in range(len(clusters)): + scores[i, j] = phi4(gold_clusters[i], clusters[j]) + matching = linear_assignment(-scores) + similarity = sum(scores[matching[:, 0], matching[:, 1]]) + return similarity, len(clusters), similarity, len(gold_clusters) + + +def lea(clusters, mention_to_gold): + num, dem = 0, 0 + + for c in clusters: + if len(c) == 1: + continue + + common_links = 0 + all_links = len(c) * (len(c) - 1) / 2.0 + for i, m in enumerate(c): + if m in mention_to_gold: + for m2 in c[i + 1:]: + if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: + common_links += 1 + + num += len(c) * common_links / float(all_links) + dem += len(c) + + return num, dem + +def f1(p_num, p_den, r_num, r_den, beta=1): + p = 0 if p_den == 0 else p_num / float(p_den) + r = 0 if r_den == 0 else r_num / float(r_den) + return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) diff --git a/reproduction/coreference_resolution/model/model_re.py b/reproduction/coreference_resolution/model/model_re.py new file mode 100644 index 00000000..9dd90ec4 --- /dev/null +++ b/reproduction/coreference_resolution/model/model_re.py @@ -0,0 +1,576 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from allennlp.commands.elmo import ElmoEmbedder +from fastNLP.models.base_model import BaseModel +from fastNLP.modules.encoder.variational_rnn import VarLSTM +from reproduction.coreference_resolution.model import preprocess +from fastNLP.io.embed_loader import EmbedLoader +import random + +# 设置seed +torch.manual_seed(0) # cpu +torch.cuda.manual_seed(0) # gpu +np.random.seed(0) # numpy +random.seed(0) + + +class ffnn(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(ffnn, self).__init__() + + self.f = nn.Sequential( + # 多少层数 + nn.Linear(input_size, hidden_size), + nn.ReLU(inplace=True), + nn.Dropout(p=0.2), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(inplace=True), + nn.Dropout(p=0.2), + nn.Linear(hidden_size, output_size) + ) + self.reset_param() + + def reset_param(self): + for name, param in self.named_parameters(): + if param.dim() > 1: + nn.init.xavier_normal_(param) + # param.data = torch.tensor(np.random.randn(*param.shape)).float() + else: + nn.init.zeros_(param) + + def forward(self, input): + return self.f(input).squeeze() + + +class Model(BaseModel): + def __init__(self, vocab, config): + word2id = vocab.word2idx + super(Model, self).__init__() + vocab_num = len(word2id) + self.word2id = word2id + self.config = config + self.char_dict = preprocess.get_char_dict('data/char_vocab.english.txt') + self.genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} + self.device = torch.device("cuda:" + config.cuda) + + self.emb = nn.Embedding(vocab_num, 350) + + emb1 = EmbedLoader().load_with_vocab(config.glove, vocab,normalize=False) + emb2 = EmbedLoader().load_with_vocab(config.turian, vocab ,normalize=False) + pre_emb = np.concatenate((emb1, emb2), axis=1) + pre_emb /= (np.linalg.norm(pre_emb, axis=1, keepdims=True) + 1e-12) + + if pre_emb is not None: + self.emb.weight = nn.Parameter(torch.from_numpy(pre_emb).float()) + for param in self.emb.parameters(): + param.requires_grad = False + self.emb_dropout = nn.Dropout(inplace=True) + + + if config.use_elmo: + self.elmo = ElmoEmbedder(options_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_options.json', + weight_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5', + cuda_device=int(config.cuda)) + print("elmo load over.") + self.elmo_args = torch.randn((3), requires_grad=True).to(self.device) + + self.char_emb = nn.Embedding(len(self.char_dict), config.char_emb_size) + self.conv1 = nn.Conv1d(config.char_emb_size, 50, 3) + self.conv2 = nn.Conv1d(config.char_emb_size, 50, 4) + self.conv3 = nn.Conv1d(config.char_emb_size, 50, 5) + + self.feature_emb = nn.Embedding(config.span_width, config.feature_size) + self.feature_emb_dropout = nn.Dropout(p=0.2, inplace=True) + + self.mention_distance_emb = nn.Embedding(10, config.feature_size) + self.distance_drop = nn.Dropout(p=0.2, inplace=True) + + self.genre_emb = nn.Embedding(7, config.feature_size) + self.speaker_emb = nn.Embedding(2, config.feature_size) + + self.bilstm = VarLSTM(input_size=350+150*config.use_CNN+config.use_elmo*1024,hidden_size=200,bidirectional=True,batch_first=True,hidden_dropout=0.2) + # self.bilstm = nn.LSTM(input_size=500, hidden_size=200, bidirectional=True, batch_first=True) + self.h0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device) + self.c0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device) + self.bilstm_drop = nn.Dropout(p=0.2, inplace=True) + + self.atten = ffnn(input_size=400, hidden_size=config.atten_hidden_size, output_size=1) + self.mention_score = ffnn(input_size=1320, hidden_size=config.mention_hidden_size, output_size=1) + self.sa = ffnn(input_size=3980+40*config.use_metadata, hidden_size=config.sa_hidden_size, output_size=1) + self.mention_start_np = None + self.mention_end_np = None + + def _reorder_lstm(self, word_emb, seq_lens): + sort_ind = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True) + seq_lens_re = [seq_lens[i] for i in sort_ind] + emb_seq = self.reorder_sequence(word_emb, sort_ind, batch_first=True) + packed_seq = nn.utils.rnn.pack_padded_sequence(emb_seq, seq_lens_re, batch_first=True) + + h0 = self.h0.repeat(1, len(seq_lens), 1) + c0 = self.c0.repeat(1, len(seq_lens), 1) + packed_out, final_states = self.bilstm(packed_seq, (h0, c0)) + + lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) + back_map = {ind: i for i, ind in enumerate(sort_ind)} + reorder_ind = [back_map[i] for i in range(len(seq_lens_re))] + lstm_out = self.reorder_sequence(lstm_out, reorder_ind, batch_first=True) + return lstm_out + + def reorder_sequence(self, sequence_emb, order, batch_first=True): + """ + sequence_emb: [T, B, D] if not batch_first + order: list of sequence length + """ + batch_dim = 0 if batch_first else 1 + assert len(order) == sequence_emb.size()[batch_dim] + + order = torch.LongTensor(order) + order = order.to(sequence_emb).long() + + sorted_ = sequence_emb.index_select(index=order, dim=batch_dim) + + del order + return sorted_ + + def flat_lstm(self, lstm_out, seq_lens): + batch = lstm_out.shape[0] + seq = lstm_out.shape[1] + dim = lstm_out.shape[2] + l = [j + i * seq for i, seq_len in enumerate(seq_lens) for j in range(seq_len)] + flatted = torch.index_select(lstm_out.view(batch * seq, dim), 0, torch.LongTensor(l).to(self.device)) + return flatted + + def potential_mention_index(self, word_index, max_sent_len): + # get mention index [3,2]:the first sentence is 3 and secend 2 + # [0,0,0,1,1] --> [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [3, 3], [3, 4], [4, 4]] (max =2) + potential_mention = [] + for i in range(len(word_index)): + for j in range(i, i + max_sent_len): + if (j < len(word_index) and word_index[i] == word_index[j]): + potential_mention.append([i, j]) + return potential_mention + + def get_mention_start_end(self, seq_lens): + # 序列长度转换成mention + # [3,2] --> [0,0,0,1,1] + word_index = [0] * sum(seq_lens) + sent_index = 0 + index = 0 + for length in seq_lens: + for l in range(length): + word_index[index] = sent_index + index += 1 + sent_index += 1 + + # [0,0,0,1,1]-->[[0,0],[0,1],[0,2]....] + mention_id = self.potential_mention_index(word_index, self.config.span_width) + mention_start = np.array(mention_id, dtype=int)[:, 0] + mention_end = np.array(mention_id, dtype=int)[:, 1] + return mention_start, mention_end + + def get_mention_emb(self, flatten_lstm, mention_start, mention_end): + mention_start_tensor = torch.from_numpy(mention_start).to(self.device) + mention_end_tensor = torch.from_numpy(mention_end).to(self.device) + emb_start = flatten_lstm.index_select(dim=0, index=mention_start_tensor) # [mention_num,embed] + emb_end = flatten_lstm.index_select(dim=0, index=mention_end_tensor) # [mention_num,embed] + return emb_start, emb_end + + def get_mask(self, mention_start, mention_end): + # big mask for attention + mention_num = mention_start.shape[0] + mask = np.zeros((mention_num, self.config.span_width)) # [mention_num,span_width] + for i in range(mention_num): + start = mention_start[i] + end = mention_end[i] + # 实际上是宽度 + for j in range(end - start + 1): + mask[i][j] = 1 + mask = torch.from_numpy(mask) # [mention_num,max_mention] + # 0-->-inf 1-->0 + log_mask = torch.log(mask) + return log_mask + + def get_mention_index(self, mention_start, max_mention): + # TODO 后面可能要改 + assert len(mention_start.shape) == 1 + mention_start_tensor = torch.from_numpy(mention_start) + num_mention = mention_start_tensor.shape[0] + mention_index = mention_start_tensor.expand(max_mention, num_mention).transpose(0, + 1) # [num_mention,max_mention] + assert mention_index.shape[0] == num_mention + assert mention_index.shape[1] == max_mention + range_add = torch.arange(0, max_mention).expand(num_mention, max_mention).long() # [num_mention,max_mention] + mention_index = mention_index + range_add + mention_index = torch.min(mention_index, torch.LongTensor([mention_start[-1]]).expand(num_mention, max_mention)) + return mention_index.to(self.device) + + def sort_mention(self, mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_lens): + # 排序记录,高分段在前面 + mention_score, mention_ids = torch.sort(candidate_mention_score, descending=True) + preserve_mention_num = int(self.config.mention_ratio * sum(seq_lens)) + mention_ids = mention_ids[0:preserve_mention_num] + mention_score = mention_score[0:preserve_mention_num] + + mention_start_tensor = torch.from_numpy(mention_start).to(self.device).index_select(dim=0, + index=mention_ids) # [lamda*word_num] + mention_end_tensor = torch.from_numpy(mention_end).to(self.device).index_select(dim=0, + index=mention_ids) # [lamda*word_num] + mention_emb = candidate_mention_emb.index_select(index=mention_ids, dim=0) # [lamda*word_num,emb] + assert mention_score.shape[0] == preserve_mention_num + assert mention_start_tensor.shape[0] == preserve_mention_num + assert mention_end_tensor.shape[0] == preserve_mention_num + assert mention_emb.shape[0] == preserve_mention_num + # TODO 不交叉没做处理 + + # 对start进行再排序,实际位置在前面 + # TODO 这里只考虑了start没有考虑end + mention_start_tensor, temp_index = torch.sort(mention_start_tensor) + mention_end_tensor = mention_end_tensor.index_select(dim=0, index=temp_index) + mention_emb = mention_emb.index_select(dim=0, index=temp_index) + mention_score = mention_score.index_select(dim=0, index=temp_index) + return mention_start_tensor, mention_end_tensor, mention_score, mention_emb + + def get_antecedents(self, mention_starts, max_antecedents): + num_mention = mention_starts.shape[0] + max_antecedents = min(max_antecedents, num_mention) + # mention和它是第几个mention之间的对应关系 + antecedents = np.zeros((num_mention, max_antecedents), dtype=int) # [num_mention,max_an] + # 记录长度 + antecedents_len = [0] * num_mention + for i in range(num_mention): + ante_count = 0 + for j in range(max(0, i - max_antecedents), i): + antecedents[i, ante_count] = j + ante_count += 1 + # 补位操作 + for j in range(ante_count, max_antecedents): + antecedents[i, j] = 0 + antecedents_len[i] = ante_count + assert antecedents.shape[1] == max_antecedents + return antecedents, antecedents_len + + def get_antecedents_score(self, span_represent, mention_score, antecedents, antecedents_len, mention_speakers_ids, + genre): + num_mention = mention_score.shape[0] + max_antecedent = antecedents.shape[1] + + pair_emb = self.get_pair_emb(span_represent, antecedents, mention_speakers_ids, genre) # [span_num,max_ant,emb] + antecedent_scores = self.sa(pair_emb) + mask01 = self.sequence_mask(antecedents_len, max_antecedent) + maskinf = torch.log(mask01).to(self.device) + assert maskinf.shape[1] <= max_antecedent + assert antecedent_scores.shape[0] == num_mention + antecedent_scores = antecedent_scores + maskinf + antecedents = torch.from_numpy(antecedents).to(self.device) + mention_scoreij = mention_score.unsqueeze(1) + torch.gather( + mention_score.unsqueeze(0).expand(num_mention, num_mention), dim=1, index=antecedents) + antecedent_scores += mention_scoreij + + antecedent_scores = torch.cat([torch.zeros([mention_score.shape[0], 1]).to(self.device), antecedent_scores], + 1) # [num_mentions, max_ant + 1] + return antecedent_scores + + ############################## + def distance_bin(self, mention_distance): + bins = torch.zeros(mention_distance.size()).byte().to(self.device) + rg = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 7], [8, 15], [16, 31], [32, 63], [64, 300]] + for t, k in enumerate(rg): + i, j = k[0], k[1] + b = torch.LongTensor([i]).unsqueeze(-1).expand(mention_distance.size()).to(self.device) + m1 = torch.ge(mention_distance, b) + e = torch.LongTensor([j]).unsqueeze(-1).expand(mention_distance.size()).to(self.device) + m2 = torch.le(mention_distance, e) + bins = bins + (t + 1) * (m1 & m2) + return bins.long() + + def get_distance_emb(self, antecedents_tensor): + num_mention = antecedents_tensor.shape[0] + max_ant = antecedents_tensor.shape[1] + + assert max_ant <= self.config.max_antecedents + source = torch.arange(0, num_mention).expand(max_ant, num_mention).transpose(0,1).to(self.device) # [num_mention,max_ant] + mention_distance = source - antecedents_tensor + mention_distance_bin = self.distance_bin(mention_distance) + distance_emb = self.mention_distance_emb(mention_distance_bin) + distance_emb = self.distance_drop(distance_emb) + return distance_emb + + def get_pair_emb(self, span_emb, antecedents, mention_speakers_ids, genre): + emb_dim = span_emb.shape[1] + num_span = span_emb.shape[0] + max_ant = antecedents.shape[1] + assert span_emb.shape[0] == antecedents.shape[0] + antecedents = torch.from_numpy(antecedents).to(self.device) + + # [num_span,max_ant,emb] + antecedent_emb = torch.gather(span_emb.unsqueeze(0).expand(num_span, num_span, emb_dim), dim=1, + index=antecedents.unsqueeze(2).expand(num_span, max_ant, emb_dim)) + # [num_span,max_ant,emb] + target_emb_tiled = span_emb.expand((max_ant, num_span, emb_dim)) + target_emb_tiled = target_emb_tiled.transpose(0, 1) + + similarity_emb = antecedent_emb * target_emb_tiled + + pair_emb_list = [target_emb_tiled, antecedent_emb, similarity_emb] + + # get speakers and genre + if self.config.use_metadata: + antecedent_speaker_ids = mention_speakers_ids.unsqueeze(0).expand(num_span, num_span).gather(dim=1, + index=antecedents) + same_speaker = torch.eq(mention_speakers_ids.unsqueeze(1).expand(num_span, max_ant), + antecedent_speaker_ids) # [num_mention,max_ant] + speaker_embedding = self.speaker_emb(same_speaker.long().to(self.device)) # [mention_num.max_ant,emb] + genre_embedding = self.genre_emb( + torch.LongTensor([genre]).expand(num_span, max_ant).to(self.device)) # [mention_num,max_ant,emb] + pair_emb_list.append(speaker_embedding) + pair_emb_list.append(genre_embedding) + + # get distance emb + if self.config.use_distance: + distance_emb = self.get_distance_emb(antecedents) + pair_emb_list.append(distance_emb) + + pair_emb = torch.cat(pair_emb_list, 2) + return pair_emb + + def sequence_mask(self, len_list, max_len): + x = np.zeros((len(len_list), max_len)) + for i in range(len(len_list)): + l = len_list[i] + for j in range(l): + x[i][j] = 1 + return torch.from_numpy(x).float() + + def logsumexp(self, value, dim=None, keepdim=False): + """Numerically stable implementation of the operation + + value.exp().sum(dim, keepdim).log() + """ + # TODO: torch.max(value, dim=None) threw an error at time of writing + if dim is not None: + m, _ = torch.max(value, dim=dim, keepdim=True) + value0 = value - m + if keepdim is False: + m = m.squeeze(dim) + return m + torch.log(torch.sum(torch.exp(value0), + dim=dim, keepdim=keepdim)) + else: + m = torch.max(value) + sum_exp = torch.sum(torch.exp(value - m)) + + return m + torch.log(sum_exp) + + def softmax_loss(self, antecedent_scores, antecedent_labels): + antecedent_labels = torch.from_numpy(antecedent_labels * 1).to(self.device) + gold_scores = antecedent_scores + torch.log(antecedent_labels.float()) # [num_mentions, max_ant + 1] + marginalized_gold_scores = self.logsumexp(gold_scores, 1) # [num_mentions] + log_norm = self.logsumexp(antecedent_scores, 1) # [num_mentions] + return torch.sum(log_norm - marginalized_gold_scores) # [num_mentions]reduce_logsumexp + + def get_predicted_antecedents(self, antecedents, antecedent_scores): + predicted_antecedents = [] + for i, index in enumerate(np.argmax(antecedent_scores.detach(), axis=1) - 1): + if index < 0: + predicted_antecedents.append(-1) + else: + predicted_antecedents.append(antecedents[i, index]) + return predicted_antecedents + + def get_predicted_clusters(self, mention_starts, mention_ends, predicted_antecedents): + mention_to_predicted = {} + predicted_clusters = [] + for i, predicted_index in enumerate(predicted_antecedents): + if predicted_index < 0: + continue + assert i > predicted_index + predicted_antecedent = (int(mention_starts[predicted_index]), int(mention_ends[predicted_index])) + if predicted_antecedent in mention_to_predicted: + predicted_cluster = mention_to_predicted[predicted_antecedent] + else: + predicted_cluster = len(predicted_clusters) + predicted_clusters.append([predicted_antecedent]) + mention_to_predicted[predicted_antecedent] = predicted_cluster + + mention = (int(mention_starts[i]), int(mention_ends[i])) + predicted_clusters[predicted_cluster].append(mention) + mention_to_predicted[mention] = predicted_cluster + + predicted_clusters = [tuple(pc) for pc in predicted_clusters] + mention_to_predicted = {m: predicted_clusters[i] for m, i in mention_to_predicted.items()} + + return predicted_clusters, mention_to_predicted + + def evaluate_coref(self, mention_starts, mention_ends, predicted_antecedents, gold_clusters, evaluator): + gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters] + mention_to_gold = {} + for gc in gold_clusters: + for mention in gc: + mention_to_gold[mention] = gc + predicted_clusters, mention_to_predicted = self.get_predicted_clusters(mention_starts, mention_ends, + predicted_antecedents) + evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) + return predicted_clusters + + + def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): + """ + 实际输入都是tensor + :param sentences: 句子,被fastNLP转化成了numpy, + :param doc_np: 被fastNLP转化成了Tensor + :param speaker_ids_np: 被fastNLP转化成了Tensor + :param genre: 被fastNLP转化成了Tensor + :param char_index: 被fastNLP转化成了Tensor + :param seq_len: 被fastNLP转化成了Tensor + :return: + """ + # change for fastNLP + sentences = sentences[0].tolist() + doc_tensor = doc_np[0] + speakers_tensor = speaker_ids_np[0] + genre = genre[0].item() + char_index = char_index[0] + seq_len = seq_len[0].cpu().numpy() + + # 类型 + + # doc_tensor = torch.from_numpy(doc_np).to(self.device) + # speakers_tensor = torch.from_numpy(speaker_ids_np).to(self.device) + mention_emb_list = [] + + word_emb = self.emb(doc_tensor) + word_emb_list = [word_emb] + if self.config.use_CNN: + # [batch, length, char_length, char_dim] + char = self.char_emb(char_index) + char_size = char.size() + # first transform to [batch *length, char_length, char_dim] + # then transpose to [batch * length, char_dim, char_length] + char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2) + + # put into cnn [batch*length, char_filters, char_length] + # then put into maxpooling [batch * length, char_filters] + char_over_cnn, _ = self.conv1(char).max(dim=2) + # reshape to [batch, length, char_filters] + char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) + word_emb_list.append(char_over_cnn) + + char_over_cnn, _ = self.conv2(char).max(dim=2) + char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) + word_emb_list.append(char_over_cnn) + + char_over_cnn, _ = self.conv3(char).max(dim=2) + char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1) + word_emb_list.append(char_over_cnn) + + # word_emb = torch.cat(word_emb_list, dim=2) + + # use elmo or not + if self.config.use_elmo: + # 如果确实被截断了 + if doc_tensor.shape[0] == 50 and len(sentences) > 50: + sentences = sentences[0:50] + elmo_embedding, elmo_mask = self.elmo.batch_to_embeddings(sentences) + elmo_embedding = elmo_embedding.to( + self.device) # [sentence_num,max_sent_len,3,1024]--[sentence_num,max_sent,1024] + elmo_embedding = elmo_embedding[:, 0, :, :] * self.elmo_args[0] + elmo_embedding[:, 1, :, :] * \ + self.elmo_args[1] + elmo_embedding[:, 2, :, :] * self.elmo_args[2] + word_emb_list.append(elmo_embedding) + # print(word_emb_list[0].shape) + # print(word_emb_list[1].shape) + # print(word_emb_list[2].shape) + # print(word_emb_list[3].shape) + # print(word_emb_list[4].shape) + + word_emb = torch.cat(word_emb_list, dim=2) + + word_emb = self.emb_dropout(word_emb) + # word_emb_elmo = self.emb_dropout(word_emb_elmo) + lstm_out = self._reorder_lstm(word_emb, seq_len) + flatten_lstm = self.flat_lstm(lstm_out, seq_len) # [word_num,emb] + flatten_lstm = self.bilstm_drop(flatten_lstm) + # TODO 没有按照论文写 + flatten_word_emb = self.flat_lstm(word_emb, seq_len) # [word_num,emb] + + mention_start, mention_end = self.get_mention_start_end(seq_len) # [mention_num] + self.mention_start_np = mention_start # [mention_num] np + self.mention_end_np = mention_end + mention_num = mention_start.shape[0] + emb_start, emb_end = self.get_mention_emb(flatten_lstm, mention_start, mention_end) # [mention_num,emb] + + # list + mention_emb_list.append(emb_start) + mention_emb_list.append(emb_end) + + if self.config.use_width: + mention_width_index = mention_end - mention_start + mention_width_tensor = torch.from_numpy(mention_width_index).to(self.device) # [mention_num] + mention_width_emb = self.feature_emb(mention_width_tensor) + mention_width_emb = self.feature_emb_dropout(mention_width_emb) + mention_emb_list.append(mention_width_emb) + + if self.config.model_heads: + mention_index = self.get_mention_index(mention_start, self.config.span_width) # [mention_num,max_mention] + log_mask_tensor = self.get_mask(mention_start, mention_end).float().to( + self.device) # [mention_num,max_mention] + alpha = self.atten(flatten_lstm).to(self.device) # [word_num] + + # 得到attention + mention_head_score = torch.gather(alpha.expand(mention_num, -1), 1, + mention_index).float().to(self.device) # [mention_num,max_mention] + mention_attention = F.softmax(mention_head_score + log_mask_tensor, dim=1) # [mention_num,max_mention] + + # TODO flatte lstm + word_num = flatten_lstm.shape[0] + lstm_emb = flatten_lstm.shape[1] + emb_num = flatten_word_emb.shape[1] + + # [num_mentions, max_mention_width, emb] + mention_text_emb = torch.gather( + flatten_word_emb.unsqueeze(1).expand(word_num, self.config.span_width, emb_num), + 0, mention_index.unsqueeze(2).expand(mention_num, self.config.span_width, + emb_num)) + # [mention_num,emb] + mention_head_emb = torch.sum( + mention_attention.unsqueeze(2).expand(mention_num, self.config.span_width, emb_num) * mention_text_emb, + dim=1) + mention_emb_list.append(mention_head_emb) + + candidate_mention_emb = torch.cat(mention_emb_list, 1) # [candidate_mention_num,emb] + candidate_mention_score = self.mention_score(candidate_mention_emb) # [candidate_mention_num] + + antecedent_scores, antecedents, mention_start_tensor, mention_end_tensor = (None, None, None, None) + mention_start_tensor, mention_end_tensor, mention_score, mention_emb = \ + self.sort_mention(mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_len) + mention_speakers_ids = speakers_tensor.index_select(dim=0, index=mention_start_tensor) # num_mention + + antecedents, antecedents_len = self.get_antecedents(mention_start_tensor, self.config.max_antecedents) + antecedent_scores = self.get_antecedents_score(mention_emb, mention_score, antecedents, antecedents_len, + mention_speakers_ids, genre) + + ans = {"candidate_mention_score": candidate_mention_score, "antecedent_scores": antecedent_scores, + "antecedents": antecedents, "mention_start_tensor": mention_start_tensor, + "mention_end_tensor": mention_end_tensor} + + return ans + + def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): + ans = self(sentences, + doc_np, + speaker_ids_np, + genre, + char_index, + seq_len) + + predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"]) + predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"], + ans["mention_end_tensor"], + predicted_antecedents) + + return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted} + + +if __name__ == '__main__': + pass diff --git a/reproduction/coreference_resolution/model/preprocess.py b/reproduction/coreference_resolution/model/preprocess.py new file mode 100644 index 00000000..d97fcb4d --- /dev/null +++ b/reproduction/coreference_resolution/model/preprocess.py @@ -0,0 +1,225 @@ +import json +import numpy as np +from . import util +import collections + +def load(path): + """ + load the file from jsonline + :param path: + :return: examples with many example(dict): {"clusters":[[[mention],[mention]],[another cluster]], + "doc_key":"str","speakers":[[,,,],[]...],"sentence":[[][]]} + """ + with open(path) as f: + train_examples = [json.loads(jsonline) for jsonline in f.readlines()] + return train_examples + +def get_vocab(): + """ + 从所有的句子中得到最终的字典,被main调用,不止是train,还有dev和test + :param examples: + :return: word2id & id2word + """ + word2id = {'PAD':0,'UNK':1} + id2word = {0:'PAD',1:'UNK'} + index = 2 + data = [load("../data/train.english.jsonlines"),load("../data/dev.english.jsonlines"),load("../data/test.english.jsonlines")] + for examples in data: + for example in examples: + for sent in example["sentences"]: + for word in sent: + if(word not in word2id): + word2id[word]=index + id2word[index] = word + index += 1 + return word2id,id2word + +def normalize(v): + norm = np.linalg.norm(v) + if norm > 0: + return v / norm + else: + return v + +# 加载glove得到embedding +def get_emb(id2word,embedding_size): + glove_oov = 0 + turian_oov = 0 + both = 0 + glove_emb_path = "../data/glove.840B.300d.txt.filtered" + turian_emb_path = "../data/turian.50d.txt" + word_num = len(id2word) + emb = np.zeros((word_num,embedding_size)) + glove_emb_dict = util.load_embedding_dict(glove_emb_path,300,"txt") + turian_emb_dict = util.load_embedding_dict(turian_emb_path,50,"txt") + for i in range(word_num): + if id2word[i] in glove_emb_dict: + word_embedding = glove_emb_dict.get(id2word[i]) + emb[i][0:300] = np.array(word_embedding) + else: + # print(id2word[i]) + glove_oov += 1 + if id2word[i] in turian_emb_dict: + word_embedding = turian_emb_dict.get(id2word[i]) + emb[i][300:350] = np.array(word_embedding) + else: + # print(id2word[i]) + turian_oov += 1 + if id2word[i] not in glove_emb_dict and id2word[i] not in turian_emb_dict: + both += 1 + emb[i] = normalize(emb[i]) + print("embedding num:"+str(word_num)) + print("glove num:"+str(glove_oov)) + print("glove oov rate:"+str(glove_oov/word_num)) + print("turian num:"+str(turian_oov)) + print("turian oov rate:"+str(turian_oov/word_num)) + print("both num:"+str(both)) + return emb + + +def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train): + max_len = 0 + max_word_length = 0 + docvex = [] + length = [] + if is_train: + sent_num = min(max_sentences,len(doc)) + else: + sent_num = len(doc) + + for i in range(sent_num): + sent = doc[i] + length.append(len(sent)) + if (len(sent) > max_len): + max_len = len(sent) + sent_vec =[] + for j,word in enumerate(sent): + if len(word)>max_word_length: + max_word_length = len(word) + if word in word2id: + sent_vec.append(word2id[word]) + else: + sent_vec.append(word2id["UNK"]) + docvex.append(sent_vec) + + char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int) + for i in range(sent_num): + sent = doc[i] + for j,word in enumerate(sent): + char_index[i, j, :len(word)] = [char_dict[c] for c in word] + + return docvex,char_index,length,max_len + +# TODO 修改了接口,确认所有该修改的地方都修改好 +def doc2numpy(doc,word2id,chardict,max_filter,max_sentences,is_train): + docvec, char_index, length, max_len = _doc2vec(doc,word2id,chardict,max_filter,max_sentences,is_train) + assert max(length) == max_len + assert char_index.shape[0]==len(length) + assert char_index.shape[1]==max_len + doc_np = np.zeros((len(docvec), max_len), int) + for i in range(len(docvec)): + for j in range(len(docvec[i])): + doc_np[i][j] = docvec[i][j] + return doc_np,char_index,length + +# TODO 没有测试 +def speaker2numpy(speakers_raw,max_sentences,is_train): + if is_train and len(speakers_raw)> max_sentences: + speakers_raw = speakers_raw[0:max_sentences] + speakers = flatten(speakers_raw) + speaker_dict = {s: i for i, s in enumerate(set(speakers))} + speaker_ids = np.array([speaker_dict[s] for s in speakers]) + return speaker_ids + + +def flat_cluster(clusters): + flatted = [] + for cluster in clusters: + for item in cluster: + flatted.append(item) + return flatted + +def get_right_mention(clusters,mention_start_np,mention_end_np): + flatted = flat_cluster(clusters) + cluster_num = len(flatted) + mention_num = mention_start_np.shape[0] + right_mention = np.zeros(mention_num,dtype=int) + for i in range(mention_num): + if [mention_start_np[i],mention_end_np[i]] in flatted: + right_mention[i]=1 + return right_mention,cluster_num + +def handle_cluster(clusters): + gold_mentions = sorted(tuple(m) for m in flatten(clusters)) + gold_mention_map = {m: i for i, m in enumerate(gold_mentions)} + cluster_ids = np.zeros(len(gold_mentions), dtype=int) + for cluster_id, cluster in enumerate(clusters): + for mention in cluster: + cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + gold_starts, gold_ends = tensorize_mentions(gold_mentions) + return cluster_ids, gold_starts, gold_ends + +# 展平 +def flatten(l): + return [item for sublist in l for item in sublist] + +# 把mention分成start end +def tensorize_mentions(mentions): + if len(mentions) > 0: + starts, ends = zip(*mentions) + else: + starts, ends = [], [] + return np.array(starts), np.array(ends) + +def get_char_dict(path): + vocab = [""] + with open(path) as f: + vocab.extend(c.strip() for c in f.readlines()) + char_dict = collections.defaultdict(int) + char_dict.update({c: i for i, c in enumerate(vocab)}) + return char_dict + +def get_labels(clusters,mention_starts,mention_ends,max_antecedents): + cluster_ids, gold_starts, gold_ends = handle_cluster(clusters) + num_mention = mention_starts.shape[0] + num_gold = gold_starts.shape[0] + max_antecedents = min(max_antecedents, num_mention) + mention_indices = {} + + for i in range(num_mention): + mention_indices[(mention_starts[i].detach().item(), mention_ends[i].detach().item())] = i + # 用来记录哪些mention是对的,-1表示错误,正数代表这个mention实际上对应哪个gold cluster的id + mention_cluster_ids = [-1] * num_mention + # test + right_mention_count = 0 + for i in range(num_gold): + right_mention = mention_indices.get((gold_starts[i], gold_ends[i])) + if (right_mention != None): + right_mention_count += 1 + mention_cluster_ids[right_mention] = cluster_ids[i] + + # i j 是否属于同一个cluster + labels = np.zeros((num_mention, max_antecedents + 1), dtype=bool) # [num_mention,max_an+1] + for i in range(num_mention): + ante_count = 0 + null_label = True + for j in range(max(0, i - max_antecedents), i): + if (mention_cluster_ids[i] >= 0 and mention_cluster_ids[i] == mention_cluster_ids[j]): + labels[i, ante_count + 1] = True + null_label = False + else: + labels[i, ante_count + 1] = False + ante_count += 1 + for j in range(ante_count, max_antecedents): + labels[i, j + 1] = False + labels[i, 0] = null_label + return labels + +# test=========================== + + +if __name__=="__main__": + word2id,id2word = get_vocab() + get_emb(id2word,350) + + diff --git a/reproduction/coreference_resolution/model/softmax_loss.py b/reproduction/coreference_resolution/model/softmax_loss.py new file mode 100644 index 00000000..c75a31d6 --- /dev/null +++ b/reproduction/coreference_resolution/model/softmax_loss.py @@ -0,0 +1,32 @@ +from fastNLP.core.losses import LossBase + +from reproduction.coreference_resolution.model.preprocess import get_labels +from reproduction.coreference_resolution.model.config import Config +import torch + + +class SoftmaxLoss(LossBase): + """ + 交叉熵loss + 允许多标签分类 + """ + + def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None): + """ + + :param pred: + :param target: + """ + super().__init__() + self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters, + mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor) + + def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor): + antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor, + Config().max_antecedents) + + antecedent_labels = torch.from_numpy(antecedent_labels*1).to(torch.device("cuda:" + Config().cuda)) + gold_scores = antecedent_scores + torch.log(antecedent_labels.float()).to(torch.device("cuda:" + Config().cuda)) # [num_mentions, max_ant + 1] + marginalized_gold_scores = gold_scores.logsumexp(dim=1) # [num_mentions] + log_norm = antecedent_scores.logsumexp(dim=1) # [num_mentions] + return torch.sum(log_norm - marginalized_gold_scores) diff --git a/reproduction/coreference_resolution/model/util.py b/reproduction/coreference_resolution/model/util.py new file mode 100644 index 00000000..42cd09fe --- /dev/null +++ b/reproduction/coreference_resolution/model/util.py @@ -0,0 +1,101 @@ +import os +import errno +import collections +import torch +import numpy as np +import pyhocon + + + +# flatten the list +def flatten(l): + return [item for sublist in l for item in sublist] + + +def get_config(filename): + return pyhocon.ConfigFactory.parse_file(filename) + + +# safe make directions +def mkdirs(path): + try: + os.makedirs(path) + except OSError as exception: + if exception.errno != errno.EEXIST: + raise + return path + + +def load_char_dict(char_vocab_path): + vocab = [""] + with open(char_vocab_path) as f: + vocab.extend(c.strip() for c in f.readlines()) + char_dict = collections.defaultdict(int) + char_dict.update({c: i for i, c in enumerate(vocab)}) + return char_dict + +# 加载embedding +def load_embedding_dict(embedding_path, embedding_size, embedding_format): + print("Loading word embeddings from {}...".format(embedding_path)) + default_embedding = np.zeros(embedding_size) + embedding_dict = collections.defaultdict(lambda: default_embedding) + skip_first = embedding_format == "vec" + with open(embedding_path) as f: + for i, line in enumerate(f.readlines()): + if skip_first and i == 0: + continue + splits = line.split() + assert len(splits) == embedding_size + 1 + word = splits[0] + embedding = np.array([float(s) for s in splits[1:]]) + embedding_dict[word] = embedding + print("Done loading word embeddings.") + return embedding_dict + + +# safe devide +def maybe_divide(x, y): + return 0 if y == 0 else x / float(y) + + +def shape(x, dim): + return x.get_shape()[dim].value or torch.shape(x)[dim] + + +def normalize(v): + norm = np.linalg.norm(v) + if norm > 0: + return v / norm + else: + return v + + +class RetrievalEvaluator(object): + def __init__(self): + self._num_correct = 0 + self._num_gold = 0 + self._num_predicted = 0 + + def update(self, gold_set, predicted_set): + self._num_correct += len(gold_set & predicted_set) + self._num_gold += len(gold_set) + self._num_predicted += len(predicted_set) + + def recall(self): + return maybe_divide(self._num_correct, self._num_gold) + + def precision(self): + return maybe_divide(self._num_correct, self._num_predicted) + + def metrics(self): + recall = self.recall() + precision = self.precision() + f1 = maybe_divide(2 * recall * precision, precision + recall) + return recall, precision, f1 + + + +if __name__=="__main__": + print(load_char_dict("../data/char_vocab.english.txt")) + embedding_dict = load_embedding_dict("../data/glove.840B.300d.txt.filtered",300,"txt") + print("hello") diff --git a/reproduction/coreference_resolution/readme.md b/reproduction/coreference_resolution/readme.md new file mode 100644 index 00000000..67d8cdc7 --- /dev/null +++ b/reproduction/coreference_resolution/readme.md @@ -0,0 +1,49 @@ +# 共指消解复现 +## 介绍 +Coreference resolution是查找文本中指向同一现实实体的所有表达式的任务。 +对于涉及自然语言理解的许多更高级别的NLP任务来说, +这是一个重要的步骤,例如文档摘要,问题回答和信息提取。 +代码的实现主要基于[ End-to-End Coreference Resolution (Lee et al, 2017)](https://arxiv.org/pdf/1707.07045). + + +## 数据获取与预处理 +论文在[OntoNote5.0](https://allennlp.org/models)数据集上取得了当时的sota结果。 +由于版权问题,本文无法提供数据集的下载,请自行下载。 +原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。 + +代码实现采用了论文作者Lee的预处理方法,具体细节参加[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 +处理之后的数据集为json格式,例子: +``` +{ + "clusters": [], + "doc_key": "nw", + "sentences": [["This", "is", "the", "first", "sentence", "."], ["This", "is", "the", "second", "."]], + "speakers": [["spk1", "spk1", "spk1", "spk1", "spk1", "spk1"], ["spk2", "spk2", "spk2", "spk2", "spk2"]] +} +``` + +### embedding 数据集下载 +[turian emdedding](https://lil.cs.washington.edu/coref/turian.50d.txt) + +[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip) + + + +## 运行 +```python +# 训练代码 +CUDA_VISIBLE_DEVICES=0 python train.py +# 测试代码 +CUDA_VISIBLE_DEVICES=0 python valid.py +``` + +## 结果 +原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。 +其中allenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 + +在与allenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 + + +## 问题 +如果您有什么问题或者反馈,请提issue或者邮件联系我: +yexu_i@qq.com diff --git a/reproduction/coreference_resolution/test/__init__.py b/reproduction/coreference_resolution/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/coreference_resolution/test/test_dataloader.py b/reproduction/coreference_resolution/test/test_dataloader.py new file mode 100644 index 00000000..0d9dae52 --- /dev/null +++ b/reproduction/coreference_resolution/test/test_dataloader.py @@ -0,0 +1,14 @@ +import unittest +from ..data_load.cr_loader import CRLoader + +class Test_CRLoader(unittest.TestCase): + def test_cr_loader(self): + train_path = 'data/train.english.jsonlines.mini' + dev_path = 'data/dev.english.jsonlines.minid' + test_path = 'data/test.english.jsonlines' + cr = CRLoader() + data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path}) + + print(data_info.datasets['train'][0]) + print(data_info.datasets['dev'][0]) + print(data_info.datasets['test'][0]) diff --git a/reproduction/coreference_resolution/train.py b/reproduction/coreference_resolution/train.py new file mode 100644 index 00000000..a231a575 --- /dev/null +++ b/reproduction/coreference_resolution/train.py @@ -0,0 +1,69 @@ +import sys +sys.path.append('../..') + +import torch +from torch.optim import Adam + +from fastNLP.core.callback import Callback, GradientClipCallback +from fastNLP.core.trainer import Trainer + +from reproduction.coreference_resolution.data_load.cr_loader import CRLoader +from reproduction.coreference_resolution.model.config import Config +from reproduction.coreference_resolution.model.model_re import Model +from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss +from reproduction.coreference_resolution.model.metric import CRMetric +from fastNLP import SequentialSampler +from fastNLP import cache_results + + +# torch.backends.cudnn.benchmark = False +# torch.backends.cudnn.deterministic = True + +class LRCallback(Callback): + def __init__(self, parameters, decay_rate=1e-3): + super().__init__() + self.paras = parameters + self.decay_rate = decay_rate + + def on_step_end(self): + if self.step % 100 == 0: + for para in self.paras: + para['lr'] = para['lr'] * (1 - self.decay_rate) + + +if __name__ == "__main__": + config = Config() + + print(config) + + @cache_results('cache.pkl') + def cache(): + cr_train_dev_test = CRLoader() + + data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path, + 'test': config.test_path}) + return data_info + data_info = cache() + print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])), + "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"]))) + # print(data_info) + model = Model(data_info.vocabs, config) + print(model) + + loss = SoftmaxLoss() + + metric = CRMetric() + + optim = Adam(model.parameters(), lr=config.lr) + + lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay) + + trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"], + loss=loss, metrics=metric, check_code_level=-1,sampler=None, + batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, + optimizer=optim, + save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save', + callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) + print() + + trainer.train() diff --git a/reproduction/coreference_resolution/valid.py b/reproduction/coreference_resolution/valid.py new file mode 100644 index 00000000..826332c6 --- /dev/null +++ b/reproduction/coreference_resolution/valid.py @@ -0,0 +1,24 @@ +import torch +from reproduction.coreference_resolution.model.config import Config +from reproduction.coreference_resolution.model.metric import CRMetric +from reproduction.coreference_resolution.data_load.cr_loader import CRLoader +from fastNLP import Tester +import argparse + + +if __name__=='__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--path') + args = parser.parse_args() + + cr_loader = CRLoader() + config = Config() + data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path, + 'test': config.test_path}) + metirc = CRMetric() + model = torch.load(args.path) + tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0") + tester.test() + print('test over') + + diff --git a/reproduction/joint_cws_parse/__init__.py b/reproduction/joint_cws_parse/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/joint_cws_parse/data/__init__.py b/reproduction/joint_cws_parse/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/joint_cws_parse/data/data_loader.py b/reproduction/joint_cws_parse/data/data_loader.py new file mode 100644 index 00000000..7802ea09 --- /dev/null +++ b/reproduction/joint_cws_parse/data/data_loader.py @@ -0,0 +1,284 @@ + + +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from fastNLP.io.dataset_loader import ConllLoader +import numpy as np + +from itertools import chain +from fastNLP import DataSet, Vocabulary +from functools import partial +import os +from typing import Union, Dict +from reproduction.utils import check_dataloader_paths + + +class CTBxJointLoader(DataSetLoader): + """ + 文件夹下应该具有以下的文件结构 + -train.conllx + -dev.conllx + -test.conllx + 每个文件中的内容如下(空格隔开不同的句子, 共有) + 1 费孝通 _ NR NR _ 3 nsubjpass _ _ + 2 被 _ SB SB _ 3 pass _ _ + 3 授予 _ VV VV _ 0 root _ _ + 4 麦格赛赛 _ NR NR _ 5 nn _ _ + 5 奖 _ NN NN _ 3 dobj _ _ + + 1 新华社 _ NR NR _ 7 dep _ _ + 2 马尼拉 _ NR NR _ 7 dep _ _ + 3 8月 _ NT NT _ 7 dep _ _ + 4 31日 _ NT NT _ 7 dep _ _ + ... + + """ + def __init__(self): + self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7]) + + def load(self, path:str): + """ + 给定一个文件路径,将数据读取为DataSet格式。DataSet中包含以下的内容 + words: list[str] + pos_tags: list[str] + heads: list[int] + labels: list[str] + + :param path: + :return: + """ + dataset = self._loader.load(path) + dataset.heads.int() + return dataset + + def process(self, paths): + """ + + :param paths: + :return: + Dataset包含以下的field + chars: + bigrams: + trigrams: + pre_chars: + pre_bigrams: + pre_trigrams: + seg_targets: + seg_masks: + seq_lens: + char_labels: + char_heads: + gold_word_pairs: + seg_targets: + seg_masks: + char_labels: + char_heads: + pun_masks: + gold_label_word_pairs: + """ + paths = check_dataloader_paths(paths) + data = DataInfo() + + for name, path in paths.items(): + dataset = self.load(path) + data.datasets[name] = dataset + + char_labels_vocab = Vocabulary(padding=None, unknown=None) + + def process(dataset, char_label_vocab): + dataset.apply(add_word_lst, new_field_name='word_lst') + dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars') + dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams') + dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams') + dataset.apply(add_char_heads, new_field_name='char_heads') + dataset.apply(add_char_labels, new_field_name='char_labels') + dataset.apply(add_segs, new_field_name='seg_targets') + dataset.apply(add_mask, new_field_name='seg_masks') + dataset.add_seq_len('chars', new_field_name='seq_lens') + dataset.apply(add_pun_masks, new_field_name='pun_masks') + if len(char_label_vocab.word_count)==0: + char_label_vocab.from_dataset(dataset, field_name='char_labels') + char_label_vocab.index_dataset(dataset, field_name='char_labels') + new_dataset = add_root(dataset) + new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True) + global add_label_word_pairs + add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab) + new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True) + + new_dataset.set_pad_val('char_labels', -1) + new_dataset.set_pad_val('char_heads', -1) + + return new_dataset + + for name in list(paths.keys()): + dataset = data.datasets[name] + dataset = process(dataset, char_labels_vocab) + data.datasets[name] = dataset + + data.vocabs['char_labels'] = char_labels_vocab + + char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars') + bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams') + trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams') + + for name in ['chars', 'bigrams', 'trigrams']: + vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values())) + vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name) + data.vocabs['pre_{}'.format(name)] = vocab + + for name, vocab in zip(['chars', 'bigrams', 'trigrams'], + [char_vocab, bigram_vocab, trigram_vocab]): + vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name) + data.vocabs[name] = vocab + + for name, dataset in data.datasets.items(): + dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars', + 'pre_bigrams', 'pre_trigrams') + dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels', + 'char_heads', + 'pun_masks', 'gold_label_word_pairs') + + return data + + +def add_label_word_pairs(instance, label_vocab): + # List[List[((head_start, head_end], (dep_start, dep_end]), ...]] + word_end_indexes = np.array(list(map(len, instance['word_lst']))) + word_end_indexes = np.cumsum(word_end_indexes).tolist() + word_end_indexes.insert(0, 0) + word_pairs = [] + labels = instance['labels'] + pos_tags = instance['pos_tags'] + for idx, head in enumerate(instance['heads']): + if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 + continue + label = label_vocab.to_index(labels[idx]) + if head==0: + word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1])))) + else: + word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label, + (word_end_indexes[idx], word_end_indexes[idx + 1]))) + return word_pairs + +def add_word_pairs(instance): + # List[List[((head_start, head_end], (dep_start, dep_end]), ...]] + word_end_indexes = np.array(list(map(len, instance['word_lst']))) + word_end_indexes = np.cumsum(word_end_indexes).tolist() + word_end_indexes.insert(0, 0) + word_pairs = [] + pos_tags = instance['pos_tags'] + for idx, head in enumerate(instance['heads']): + if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 + continue + if head==0: + word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1])))) + else: + word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), + (word_end_indexes[idx], word_end_indexes[idx + 1]))) + return word_pairs + +def add_root(dataset): + new_dataset = DataSet() + for sample in dataset: + chars = ['char_root'] + sample['chars'] + bigrams = ['bigram_root'] + sample['bigrams'] + trigrams = ['trigram_root'] + sample['trigrams'] + seq_lens = sample['seq_lens']+1 + char_labels = [0] + sample['char_labels'] + char_heads = [0] + sample['char_heads'] + sample['chars'] = chars + sample['bigrams'] = bigrams + sample['trigrams'] = trigrams + sample['seq_lens'] = seq_lens + sample['char_labels'] = char_labels + sample['char_heads'] = char_heads + new_dataset.append(sample) + return new_dataset + +def add_pun_masks(instance): + tags = instance['pos_tags'] + pun_masks = [] + for word, tag in zip(instance['words'], tags): + if tag=='PU': + pun_masks.extend([1]*len(word)) + else: + pun_masks.extend([0]*len(word)) + return pun_masks + +def add_word_lst(instance): + words = instance['words'] + word_lst = [list(word) for word in words] + return word_lst + +def add_bigram(instance): + chars = instance['chars'] + length = len(chars) + chars = chars + [''] + bigrams = [] + for i in range(length): + bigrams.append(''.join(chars[i:i + 2])) + return bigrams + +def add_trigram(instance): + chars = instance['chars'] + length = len(chars) + chars = chars + [''] * 2 + trigrams = [] + for i in range(length): + trigrams.append(''.join(chars[i:i + 3])) + return trigrams + +def add_char_heads(instance): + words = instance['word_lst'] + heads = instance['heads'] + char_heads = [] + char_index = 1 # 因此存在root节点所以需要从1开始 + head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0] # 因为root是0,0-1=-1 + for word, head in zip(words, heads): + char_head = [] + if len(word)>1: + char_head.append(char_index+1) + char_index += 1 + for _ in range(len(word)-2): + char_index += 1 + char_head.append(char_index) + char_index += 1 + char_head.append(head_end_indexes[head-1]) + char_heads.extend(char_head) + return char_heads + +def add_char_labels(instance): + """ + 将word_lst中的数据按照下面的方式设置label + 比如"复旦大学 位于 ", 对应的分词是"B M M E B E", 则对应的dependency是"复(dep)->旦(head)", "旦(dep)->大(head)".. + 对应的label是'app', 'app', 'app', , 而学的label就是复旦大学这个词的dependency label + :param instance: + :return: + """ + words = instance['word_lst'] + labels = instance['labels'] + char_labels = [] + for word, label in zip(words, labels): + for _ in range(len(word)-1): + char_labels.append('APP') + char_labels.append(label) + return char_labels + +# add seg_targets +def add_segs(instance): + words = instance['word_lst'] + segs = [0]*len(instance['chars']) + index = 0 + for word in words: + index = index + len(word) - 1 + segs[index] = len(word)-1 + index = index + 1 + return segs + +# add target_masks +def add_mask(instance): + words = instance['word_lst'] + mask = [] + for word in words: + mask.extend([0] * (len(word) - 1)) + mask.append(1) + return mask diff --git a/reproduction/joint_cws_parse/models/CharParser.py b/reproduction/joint_cws_parse/models/CharParser.py new file mode 100644 index 00000000..1ed5ea2d --- /dev/null +++ b/reproduction/joint_cws_parse/models/CharParser.py @@ -0,0 +1,311 @@ + + + +from fastNLP.models.biaffine_parser import BiaffineParser +from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from fastNLP.modules.dropout import TimestepDropout +from fastNLP.modules.encoder.variational_rnn import VarLSTM +from fastNLP import seq_len_to_mask +from fastNLP.modules import Embedding + + +def drop_input_independent(word_embeddings, dropout_emb): + batch_size, seq_length, _ = word_embeddings.size() + word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb) + word_masks = torch.bernoulli(word_masks) + word_masks = word_masks.unsqueeze(dim=2) + word_embeddings = word_embeddings * word_masks + + return word_embeddings + + +class CharBiaffineParser(BiaffineParser): + def __init__(self, char_vocab_size, + emb_dim, + bigram_vocab_size, + trigram_vocab_size, + num_label, + rnn_layers=3, + rnn_hidden_size=800, #单向的数量 + arc_mlp_size=500, + label_mlp_size=100, + dropout=0.3, + encoder='lstm', + use_greedy_infer=False, + app_index = 0, + pre_chars_embed=None, + pre_bigrams_embed=None, + pre_trigrams_embed=None): + + + super(BiaffineParser, self).__init__() + rnn_out_size = 2 * rnn_hidden_size + self.char_embed = Embedding((char_vocab_size, emb_dim)) + self.bigram_embed = Embedding((bigram_vocab_size, emb_dim)) + self.trigram_embed = Embedding((trigram_vocab_size, emb_dim)) + if pre_chars_embed: + self.pre_char_embed = Embedding(pre_chars_embed) + self.pre_char_embed.requires_grad = False + if pre_bigrams_embed: + self.pre_bigram_embed = Embedding(pre_bigrams_embed) + self.pre_bigram_embed.requires_grad = False + if pre_trigrams_embed: + self.pre_trigram_embed = Embedding(pre_trigrams_embed) + self.pre_trigram_embed.requires_grad = False + self.timestep_drop = TimestepDropout(dropout) + self.encoder_name = encoder + + if encoder == 'var-lstm': + self.encoder = VarLSTM(input_size=emb_dim*3, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + input_dropout=dropout, + hidden_dropout=dropout, + bidirectional=True) + elif encoder == 'lstm': + self.encoder = nn.LSTM(input_size=emb_dim*3, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + dropout=dropout, + bidirectional=True) + + else: + raise ValueError('unsupported encoder type: {}'.format(encoder)) + + self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), + nn.LeakyReLU(0.1), + TimestepDropout(p=dropout),) + self.arc_mlp_size = arc_mlp_size + self.label_mlp_size = label_mlp_size + self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) + self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) + self.use_greedy_infer = use_greedy_infer + self.reset_parameters() + self.dropout = dropout + + self.app_index = app_index + self.num_label = num_label + if self.app_index != 0: + raise ValueError("现在app_index必须等于0") + + def reset_parameters(self): + for name, m in self.named_modules(): + if 'embed' in name: + pass + elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'): + pass + else: + for p in m.parameters(): + if len(p.size())>1: + nn.init.xavier_normal_(p, gain=0.1) + else: + nn.init.uniform_(p, -0.1, 0.1) + + def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None, + pre_trigrams=None): + """ + max_len是包含root的 + :param chars: batch_size x max_len + :param ngrams: batch_size x max_len*ngram_per_char + :param seq_lens: batch_size + :param gold_heads: batch_size x max_len + :param pre_chars: batch_size x max_len + :param pre_ngrams: batch_size x max_len*ngram_per_char + :return dict: parsing results + arc_pred: [batch_size, seq_len, seq_len] + label_pred: [batch_size, seq_len, seq_len] + mask: [batch_size, seq_len] + head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads + """ + # prepare embeddings + batch_size, seq_len = chars.shape + # print('forward {} {}'.format(batch_size, seq_len)) + + # get sequence mask + mask = seq_len_to_mask(seq_lens).long() + + chars = self.char_embed(chars) # [N,L] -> [N,L,C_0] + bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1] + trigrams = self.trigram_embed(trigrams) + + if pre_chars is not None: + pre_chars = self.pre_char_embed(pre_chars) + # pre_chars = self.pre_char_fc(pre_chars) + chars = pre_chars + chars + if pre_bigrams is not None: + pre_bigrams = self.pre_bigram_embed(pre_bigrams) + # pre_bigrams = self.pre_bigram_fc(pre_bigrams) + bigrams = bigrams + pre_bigrams + if pre_trigrams is not None: + pre_trigrams = self.pre_trigram_embed(pre_trigrams) + # pre_trigrams = self.pre_trigram_fc(pre_trigrams) + trigrams = trigrams + pre_trigrams + + x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C] + + # encoder, extract features + if self.training: + x = drop_input_independent(x, self.dropout) + sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) + x = x[sort_idx] + x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) + feat, _ = self.encoder(x) # -> [N,L,C] + feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) + _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) + feat = feat[unsort_idx] + feat = self.timestep_drop(feat) + + # for arc biaffine + # mlp, reduce dim + feat = self.mlp(feat) + arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size + arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] + label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] + + # biaffine arc classifier + arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] + + # use gold or predicted arc to predict label + if gold_heads is None or not self.training: + # use greedy decoding in training + if self.training or self.use_greedy_infer: + heads = self.greedy_decoder(arc_pred, mask) + else: + heads = self.mst_decoder(arc_pred, mask) + head_pred = heads + else: + assert self.training # must be training mode + if gold_heads is None: + heads = self.greedy_decoder(arc_pred, mask) + head_pred = heads + else: + head_pred = None + heads = gold_heads + # heads: batch_size x max_len + + batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1) + label_head = label_head[batch_range, heads].contiguous() + label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label] + # 这里限制一下,只有当head为下一个时,才能预测app这个label + arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\ + .repeat(batch_size, 1) # batch_size x max_len + app_masks = heads.ne(arange_index) # batch_size x max_len, 为1的位置不可以预测app + app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label) + app_masks[:, :, 1:] = 0 + label_pred = label_pred.masked_fill(app_masks, -np.inf) + + res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} + if head_pred is not None: + res_dict['head_pred'] = head_pred + return res_dict + + @staticmethod + def loss(arc_pred, label_pred, arc_true, label_true, mask): + """ + Compute loss. + + :param arc_pred: [batch_size, seq_len, seq_len] + :param label_pred: [batch_size, seq_len, n_tags] + :param arc_true: [batch_size, seq_len] + :param label_true: [batch_size, seq_len] + :param mask: [batch_size, seq_len] + :return: loss value + """ + + batch_size, seq_len, _ = arc_pred.shape + flip_mask = (mask == 0) + _arc_pred = arc_pred.clone() + _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) + + arc_true[:, 0].fill_(-1) + label_true[:, 0].fill_(-1) + + arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) + label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) + + return arc_nll + label_nll + + def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams): + """ + + max_len是包含root的 + + :param chars: batch_size x max_len + :param ngrams: batch_size x max_len*ngram_per_char + :param seq_lens: batch_size + :param pre_chars: batch_size x max_len + :param pre_ngrams: batch_size x max_len*ngram_per_cha + :return: + """ + res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams, + pre_trigrams=pre_trigrams, gold_heads=None) + output = {} + output['arc_pred'] = res.pop('head_pred') + _, label_pred = res.pop('label_pred').max(2) + output['label_pred'] = label_pred + return output + +class CharParser(nn.Module): + def __init__(self, char_vocab_size, + emb_dim, + bigram_vocab_size, + trigram_vocab_size, + num_label, + rnn_layers=3, + rnn_hidden_size=400, #单向的数量 + arc_mlp_size=500, + label_mlp_size=100, + dropout=0.3, + encoder='var-lstm', + use_greedy_infer=False, + app_index = 0, + pre_chars_embed=None, + pre_bigrams_embed=None, + pre_trigrams_embed=None): + super().__init__() + + self.parser = CharBiaffineParser(char_vocab_size, + emb_dim, + bigram_vocab_size, + trigram_vocab_size, + num_label, + rnn_layers, + rnn_hidden_size, #单向的数量 + arc_mlp_size, + label_mlp_size, + dropout, + encoder, + use_greedy_infer, + app_index, + pre_chars_embed=pre_chars_embed, + pre_bigrams_embed=pre_bigrams_embed, + pre_trigrams_embed=pre_trigrams_embed) + + def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None, + pre_trigrams=None): + res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars, + pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) + arc_pred = res_dict['arc_pred'] + label_pred = res_dict['label_pred'] + masks = res_dict['mask'] + loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks) + return {'loss': loss} + + def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None): + res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars, + pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) + output = {} + output['head_preds'] = res.pop('head_pred') + _, label_pred = res.pop('label_pred').max(2) + output['label_preds'] = label_pred + return output diff --git a/reproduction/joint_cws_parse/models/__init__.py b/reproduction/joint_cws_parse/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/joint_cws_parse/models/callbacks.py b/reproduction/joint_cws_parse/models/callbacks.py new file mode 100644 index 00000000..8de01109 --- /dev/null +++ b/reproduction/joint_cws_parse/models/callbacks.py @@ -0,0 +1,65 @@ + +from fastNLP.core.callback import Callback +import torch +from torch import nn + +class OptimizerCallback(Callback): + def __init__(self, optimizer, scheduler, update_every=4): + super().__init__() + + self._optimizer = optimizer + self.scheduler = scheduler + self._update_every = update_every + + def on_backward_end(self): + if self.step % self._update_every==0: + # nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5) + # self._optimizer.step() + self.scheduler.step() + # self.model.zero_grad() + + +class DevCallback(Callback): + def __init__(self, tester, metric_key='u_f1'): + super().__init__() + self.tester = tester + setattr(tester, 'verbose', 0) + + self.metric_key = metric_key + + self.record_best = False + self.best_eval_value = 0 + self.best_eval_res = None + + self.best_dev_res = None # 存取dev的表现 + + def on_valid_begin(self): + eval_res = self.tester.test() + metric_name = self.tester.metrics[0].__class__.__name__ + metric_value = eval_res[metric_name][self.metric_key] + if metric_value>self.best_eval_value: + self.best_eval_value = metric_value + self.best_epoch = self.trainer.epoch + self.record_best = True + self.best_eval_res = eval_res + self.test_eval_res = eval_res + eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \ + self.tester._format_eval_results(eval_res) + self.pbar.write(eval_str) + + def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): + if self.record_best: + self.best_dev_res = eval_result + self.record_best = False + if is_better_eval: + self.best_dev_res_on_dev = eval_result + self.best_test_res_on_dev = self.test_eval_res + self.dev_epoch = self.epoch + + def on_train_end(self): + print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch, + self.tester._format_eval_results(self.best_eval_res), + self.tester._format_eval_results(self.best_dev_res))) + print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch, + self.tester._format_eval_results(self.best_test_res_on_dev), + self.tester._format_eval_results(self.best_dev_res_on_dev))) \ No newline at end of file diff --git a/reproduction/joint_cws_parse/models/metrics.py b/reproduction/joint_cws_parse/models/metrics.py new file mode 100644 index 00000000..bf0f0622 --- /dev/null +++ b/reproduction/joint_cws_parse/models/metrics.py @@ -0,0 +1,184 @@ +from fastNLP.core.metrics import MetricBase +from fastNLP.core.utils import seq_len_to_mask +import torch + + +class SegAppCharParseF1Metric(MetricBase): + # + def __init__(self, app_index): + super().__init__() + self.app_index = app_index + + self.parse_head_tp = 0 + self.parse_label_tp = 0 + self.rec_tol = 0 + self.pre_tol = 0 + + def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens, + pun_masks): + """ + + max_len是不包含root的character的长度 + :param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size + :param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size + :param head_preds: batch_size x max_len + :param label_preds: batch_size x max_len + :param seq_lens: + :param pun_masks: batch_size x + :return: + """ + # 去掉root + head_preds = head_preds[:, 1:].tolist() + label_preds = label_preds[:, 1:].tolist() + seq_lens = (seq_lens - 1).tolist() + + # 先解码出words,POS,heads, labels, 对应的character范围 + for b in range(len(head_preds)): + seq_len = seq_lens[b] + head_pred = head_preds[b][:seq_len] + label_pred = label_preds[b][:seq_len] + + words = [] # 存放[word_start, word_end),相对起始位置,不考虑root + heads = [] + labels = [] + ranges = [] # 对应该char是第几个word,长度是seq_len+1 + word_idx = 0 + word_start_idx = 0 + for idx, (label, head) in enumerate(zip(label_pred, head_pred)): + ranges.append(word_idx) + if label == self.app_index: + pass + else: + labels.append(label) + heads.append(head) + words.append((word_start_idx, idx+1)) + word_start_idx = idx+1 + word_idx += 1 + + head_dep_tuple = [] # head在前面 + head_label_dep_tuple = [] + for idx, head in enumerate(heads): + span = words[idx] + if span[0]==span[1]-1 and pun_masks[b, span[0]]: + continue # exclude punctuations + if head == 0: + head_dep_tuple.append((('root', words[idx]))) + head_label_dep_tuple.append(('root', labels[idx], words[idx])) + else: + head_word_idx = ranges[head-1] + head_word_span = words[head_word_idx] + head_dep_tuple.append(((head_word_span, words[idx]))) + head_label_dep_tuple.append((head_word_span, labels[idx], words[idx])) + + gold_head_dep_tuple = set(gold_word_pairs[b]) + gold_head_label_dep_tuple = set(gold_label_word_pairs[b]) + + for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple): + if head_dep in gold_head_dep_tuple: + self.parse_head_tp += 1 + if head_label_dep in gold_head_label_dep_tuple: + self.parse_label_tp += 1 + self.pre_tol += len(head_dep_tuple) + self.rec_tol += len(gold_head_dep_tuple) + + def get_metric(self, reset=True): + u_p = self.parse_head_tp / self.pre_tol + u_r = self.parse_head_tp / self.rec_tol + u_f = 2*u_p*u_r/(1e-6 + u_p + u_r) + l_p = self.parse_label_tp / self.pre_tol + l_r = self.parse_label_tp / self.rec_tol + l_f = 2*l_p*l_r/(1e-6 + l_p + l_r) + + if reset: + self.parse_head_tp = 0 + self.parse_label_tp = 0 + self.rec_tol = 0 + self.pre_tol = 0 + + return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4), + 'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)} + + +class CWSMetric(MetricBase): + def __init__(self, app_index): + super().__init__() + self.app_index = app_index + self.pre = 0 + self.rec = 0 + self.tp = 0 + + def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens): + """ + + :param seg_targets: batch_size x max_len, 每个位置预测的是该word的长度-1,在word结束的地方。 + :param seg_masks: batch_size x max_len,只有在word结束的地方为1 + :param label_preds: batch_size x max_len + :param seq_lens: batch_size + :return: + """ + + pred_masks = torch.zeros_like(seg_masks) + pred_segs = torch.zeros_like(seg_targets) + + seq_lens = (seq_lens - 1).tolist() + for idx, label_pred in enumerate(label_preds[:, 1:].tolist()): + seq_len = seq_lens[idx] + label_pred = label_pred[:seq_len] + word_len = 0 + for l_i, label in enumerate(label_pred): + if label==self.app_index and l_i!=len(label_pred)-1: + word_len += 1 + else: + pred_segs[idx, l_i] = word_len # 这个词的长度为word_len + pred_masks[idx, l_i] = 1 + word_len = 0 + + right_mask = seg_targets.eq(pred_segs) # 对长度的预测一致 + self.rec += seg_masks.sum().item() + self.pre += pred_masks.sum().item() + # 且pred和target在同一个地方有值 + self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item() + + def get_metric(self, reset=True): + res = {} + res['rec'] = round(self.tp/(self.rec+1e-6), 4) + res['pre'] = round(self.tp/(self.pre+1e-6), 4) + res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4) + + if reset: + self.pre = 0 + self.rec = 0 + self.tp = 0 + + return res + + +class ParserMetric(MetricBase): + def __init__(self, ): + super().__init__() + self.num_arc = 0 + self.num_label = 0 + self.num_sample = 0 + + def get_metric(self, reset=True): + res = {'UAS': round(self.num_arc*1.0 / self.num_sample, 4), + 'LAS': round(self.num_label*1.0 / self.num_sample, 4)} + if reset: + self.num_sample = self.num_label = self.num_arc = 0 + return res + + def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None): + """Evaluate the performance of prediction. + """ + if seq_lens is None: + seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.byte) + else: + seq_mask = seq_len_to_mask(seq_lens.long(), float=False) + # mask out tag + seq_mask[:, 0] = 0 + head_pred_correct = (head_preds == heads).__and__(seq_mask) + label_pred_correct = (label_preds == labels).__and__(head_pred_correct) + self.num_arc += head_pred_correct.float().sum().item() + self.num_label += label_pred_correct.float().sum().item() + self.num_sample += seq_mask.sum().item() + diff --git a/reproduction/joint_cws_parse/readme.md b/reproduction/joint_cws_parse/readme.md new file mode 100644 index 00000000..7fe77b47 --- /dev/null +++ b/reproduction/joint_cws_parse/readme.md @@ -0,0 +1,16 @@ +Code for paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697) + +### 准备数据 +1. 数据应该为conll格式,1, 3, 6, 7列应该对应为'words', 'pos_tags', 'heads', 'labels'. +2. 将train, dev, test放在同一个folder下,并将该folder路径填入train.py中的data_folder变量里。 +3. 从[百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA)(提取:ua53)下载预训练vector,放到同一个folder下,并将train.py中vector_folder变量正确设置。 + + +### 运行代码 +``` +python train.py +``` + +### 其它 +ctb5上跑出论文中报道的结果使用以上的默认参数应该就可以了(应该会更高一些); ctb7上使用默认参数会低0.1%左右,需要调节 +learning rate scheduler. \ No newline at end of file diff --git a/reproduction/joint_cws_parse/train.py b/reproduction/joint_cws_parse/train.py new file mode 100644 index 00000000..2f8b0d04 --- /dev/null +++ b/reproduction/joint_cws_parse/train.py @@ -0,0 +1,124 @@ +import sys +sys.path.append('../..') + +from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from torch import nn +from functools import partial +from reproduction.joint_cws_parse.models.CharParser import CharParser +from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric +from fastNLP import cache_results, BucketSampler, Trainer +from torch import optim +from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback +from torch.optim.lr_scheduler import LambdaLR, StepLR +from fastNLP import Tester +from fastNLP import GradientClipCallback, LRScheduler +import os + +def set_random_seed(random_seed=666): + import random, numpy, torch + random.seed(random_seed) + numpy.random.seed(random_seed) + torch.cuda.manual_seed(random_seed) + torch.random.manual_seed(random_seed) + +uniform_init = partial(nn.init.normal_, std=0.02) + +################################################### +# 需要变动的超参放到这里 +lr = 0.002 # 0.01~0.001 +dropout = 0.33 # 0.3~0.6 +weight_decay = 0 # 1e-5, 1e-6, 0 +arc_mlp_size = 500 # 200, 300 +rnn_hidden_size = 400 # 200, 300, 400 +rnn_layers = 3 # 2, 3 +encoder = 'var-lstm' # var-lstm, lstm +emb_size = 100 # 64 , 100 +label_mlp_size = 100 + +batch_size = 32 +update_every = 4 +n_epochs = 100 +data_folder = '' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件 +vector_folder = '' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt +#################################################### + +set_random_seed(1234) +device = 0 + +# @cache_results('caches/{}.pkl'.format(data_name)) +# def get_data(): +data = CTBxJointLoader().process(data_folder) + +char_labels_vocab = data.vocabs['char_labels'] + +pre_chars_vocab = data.vocabs['pre_chars'] +pre_bigrams_vocab = data.vocabs['pre_bigrams'] +pre_trigrams_vocab = data.vocabs['pre_trigrams'] + +chars_vocab = data.vocabs['chars'] +bigrams_vocab = data.vocabs['bigrams'] +trigrams_vocab = data.vocabs['trigrams'] + +pre_chars_embed = StaticEmbedding(pre_chars_vocab, + model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) +pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std() +pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, + model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) +pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std() +pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, + model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) +pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std() + + # return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data + +# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() + +print(data) +model = CharParser(char_vocab_size=len(chars_vocab), + emb_dim=emb_size, + bigram_vocab_size=len(bigrams_vocab), + trigram_vocab_size=len(trigrams_vocab), + num_label=len(char_labels_vocab), + rnn_layers=rnn_layers, + rnn_hidden_size=rnn_hidden_size, + arc_mlp_size=arc_mlp_size, + label_mlp_size=label_mlp_size, + dropout=dropout, + encoder=encoder, + use_greedy_infer=False, + app_index=char_labels_vocab['APP'], + pre_chars_embed=pre_chars_embed, + pre_bigrams_embed=pre_bigrams_embed, + pre_trigrams_embed=pre_trigrams_embed) + +metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP']) +metric2 = CWSMetric(char_labels_vocab['APP']) +metrics = [metric1, metric2] + +optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr, + weight_decay=weight_decay, betas=[0.9, 0.9]) + +sampler = BucketSampler(seq_len_field_name='seq_lens') +callbacks = [] +# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) +scheduler = StepLR(optimizer, step_size=18, gamma=0.75) +# optim_callback = OptimizerCallback(optimizer, scheduler, update_every) +# callbacks.append(optim_callback) +scheduler_callback = LRScheduler(scheduler) +callbacks.append(scheduler_callback) +callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) + +tester = Tester(data=data.datasets['test'], model=model, metrics=metrics, + batch_size=64, device=device, verbose=0) +dev_callback = DevCallback(tester) +callbacks.append(dev_callback) + +trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3, + validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, + check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True, + device=device, callbacks=callbacks, update_every=update_every) +trainer.train() \ No newline at end of file diff --git a/reproduction/matching/README.md b/reproduction/matching/README.md new file mode 100644 index 00000000..056b0212 --- /dev/null +++ b/reproduction/matching/README.md @@ -0,0 +1,100 @@ +# Matching任务模型复现 +这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). + +复现的模型有(按论文发表时间顺序排序): +- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). +论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). +- ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). +论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). +- DIIN:模型代码(still in progress)[](); 训练代码(still in progress)[](). +论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf). +- MwAN:模型代码(still in progress)[](); 训练代码(still in progress)[](). +论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf). +- BERT:[模型代码](model/bert.py); [训练代码](matching_bert.py). +论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). + +# 数据集及复现结果汇总 + +使用fastNLP复现的结果vs论文汇报结果,在前面的表示使用fastNLP复现的结果 + +'\-'表示我们仍未复现或者论文原文没有汇报 + +model name | SNLI | MNLI | RTE | QNLI | Quora +:---: | :---: | :---: | :---: | :---: | :---: +CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - | +ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | +DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | +MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | +BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - | + +*ESIM模型由MNLI官方复现的结果为72.4/72.1,ESIM原论文当中没有汇报MNLI数据集的结果。 + +# 数据集复现结果及其他主要模型对比 +## SNLI +[Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) + +Performance on Test set: + +model name | ESIM | DIIN | MwAN | [GPT1.0](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) | [BERT-Large+SRL](https://arxiv.org/pdf/1809.02794.pdf) | [MT-DNN](https://arxiv.org/pdf/1901.11504.pdf) +:---: | :---: | :---: | :---: | :---: | :---: | :---: +__performance__ | 88.0 | 88.0 | 88.3 | 89.9 | 91.3 | 91.6 | + +### 基于fastNLP复现的结果 +Performance on Test set: + +model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large +:---: | :---: | :---: | :---: | :---: | :---: | :---: +__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16 + +## MNLI +[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) + +Performance on Test set(matched/mismatched): + +model name | ESIM | DIIN | MwAN | GPT1.0 | BERT-Base | MT-DNN +:---: | :---: | :---: | :---: | :---: | :---: | :---: +__performance__ | 72.4/72.1 | 78.8/77.8 | 78.5/77.7 | 82.1/81.4 | 84.6/83.4 | 87.9/87.4 | + +### 基于fastNLP复现的结果 +Performance on Test set(matched/mismatched): + +model name | CNTN | ESIM | DIIN | MwAN | BERT-Base +:---: | :---: | :---: | :---: | :---: | :---: | +__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - | + + +## RTE + +Still in progress. + +## QNLI + +### From GLUE baselines +[Link to GLUE leaderboard](https://gluebenchmark.com/leaderboard) + +Performance on Test set: +#### LSTM-based +model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo +:---: | :---: | :---: | :---: | :---: | +__performance__ | 74.6 | 74.3 | 75.5 | 79.8 | + +*这些LSTM-based的baseline是由QNLI官方实现并测试的。 + +#### Transformer-based +model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN +:---: | :---: | :---: | :---: | :---: | +__performance__ | 87.4 | 90.5 | 92.7 | 96.0 | + + + +### 基于fastNLP复现的结果 +Performance on __Dev__ set: + +model name | CNTN | ESIM | DIIN | MwAN | BERT +:---: | :---: | :---: | :---: | :---: | :---: +__performance__ | - | 76.97 | - | 74.6 | - + +## Quora + +Still in progress. + diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index 749b16c8..7c32899c 100644 --- a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -5,8 +5,8 @@ from typing import Union, Dict from fastNLP.core.const import Const from fastNLP.core.vocabulary import Vocabulary -from fastNLP.io.base_loader import DataInfo -from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader +from fastNLP.io.base_loader import DataInfo, DataSetLoader +from fastNLP.io.dataset_loader import JsonLoader, CSVLoader from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from fastNLP.modules.encoder._bert import BertTokenizer @@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader): 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` 读取Matching任务的数据集 + + :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 """ def __init__(self, paths: dict=None): - """ - :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 - """ self.paths = paths def _load(self, path): @@ -34,7 +33,8 @@ class MatchingLoader(DataSetLoader): def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, - cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, + cut_text: int = None, get_index=True, auto_pad_length: int=None, + auto_pad_token: str='', set_input: Union[list, str, bool]=True, set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: """ :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, @@ -49,6 +49,8 @@ class MatchingLoader(DataSetLoader): :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 :param bool get_index: 是否需要根据词表将文本转为index + :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad + :param str auto_pad_token: 自动pad的内容 :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, 于此同时其他field不会被设置为input。默认值为True。 @@ -169,6 +171,9 @@ class MatchingLoader(DataSetLoader): data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) + if auto_pad_length is not None: + cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) + if cut_text is not None: for data_name, data_set in data_info.datasets.items(): for fields in data_set.get_field_names(): @@ -180,7 +185,7 @@ class MatchingLoader(DataSetLoader): assert len(data_set_list) > 0, f'There are NO data sets in data info!' if bert_tokenizer is None: - words_vocab = Vocabulary() + words_vocab = Vocabulary(padding=auto_pad_token) words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], field_name=[n for n in data_set_list[0].get_field_names() if (Const.INPUT in n)], @@ -202,6 +207,20 @@ class MatchingLoader(DataSetLoader): data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, is_input=auto_set_input, is_target=auto_set_target) + if auto_pad_length is not None: + if seq_len_type == 'seq_len': + raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' + f'so the seq_len_type cannot be `{seq_len_type}`!') + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if Const.INPUT in fields: + data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * + (auto_pad_length - len(x[fields])), new_field_name=fields, + is_input=auto_set_input) + elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): + data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), + new_field_name=fields, is_input=auto_set_input) + for data_name, data_set in data_info.datasets.items(): if isinstance(set_input, list): data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) @@ -267,7 +286,7 @@ class RTELoader(MatchingLoader, CSVLoader): paths = paths if paths is not None else { 'train': 'train.tsv', 'dev': 'dev.tsv', - # 'test': 'test.tsv' # test set has not label + 'test': 'test.tsv' # test set has not label } MatchingLoader.__init__(self, paths=paths) self.fields = { @@ -281,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader): ds = CSVLoader._load(self, path) for k, v in self.fields.items(): - ds.rename_field(k, v) + if v in ds.get_field_names(): + ds.rename_field(k, v) for fields in ds.get_all_fields(): if Const.INPUT in fields: ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) @@ -306,7 +326,7 @@ class QNLILoader(MatchingLoader, CSVLoader): paths = paths if paths is not None else { 'train': 'train.tsv', 'dev': 'dev.tsv', - # 'test': 'test.tsv' # test set has not label + 'test': 'test.tsv' # test set has not label } MatchingLoader.__init__(self, paths=paths) self.fields = { @@ -320,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader): ds = CSVLoader._load(self, path) for k, v in self.fields.items(): - ds.rename_field(k, v) + if v in ds.get_field_names(): + ds.rename_field(k, v) for fields in ds.get_all_fields(): if Const.INPUT in fields: ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) @@ -332,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader): """ 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` - 读取SNLI数据集,读取的DataSet包含fields:: + 读取MNLI数据集,读取的DataSet包含fields:: words1: list(str),第一句文本, premise words2: list(str), 第二句文本, hypothesis @@ -348,6 +369,10 @@ class MNLILoader(MatchingLoader, CSVLoader): 'dev_mismatched': 'dev_mismatched.tsv', 'test_matched': 'test_matched.tsv', 'test_mismatched': 'test_mismatched.tsv', + # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', + # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', + + # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) } MatchingLoader.__init__(self, paths=paths) CSVLoader.__init__(self, sep='\t') @@ -364,6 +389,10 @@ class MNLILoader(MatchingLoader, CSVLoader): if k in ds.get_field_names(): ds.rename_field(k, v) + if Const.TARGET in ds.get_field_names(): + if ds[0][Const.TARGET] == 'hidden': + ds.delete_field(Const.TARGET) + parentheses_table = str.maketrans({'(': None, ')': None}) ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), @@ -376,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader): class QuoraLoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` + + 读取MNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ def __init__(self, paths: dict=None): paths = paths if paths is not None else { diff --git a/reproduction/matching/matching_bert.py b/reproduction/matching/matching_bert.py new file mode 100644 index 00000000..75112d5a --- /dev/null +++ b/reproduction/matching/matching_bert.py @@ -0,0 +1,102 @@ +import random +import numpy as np +import torch + +from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam + +from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ + MNLILoader, QNLILoader, QuoraLoader +from reproduction.matching.model.bert import BertForNLI + + +# define hyper-parameters +class BERTConfig: + + task = 'snli' + batch_size_per_gpu = 6 + n_epochs = 6 + lr = 2e-5 + seq_len_type = 'bert' + seed = 42 + train_dataset_name = 'train' + dev_dataset_name = 'dev' + test_dataset_name = 'test' + save_path = None # 模型存储的位置,None表示不存储模型。 + bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹 + + +arg = BERTConfig() + +# set random seed +random.seed(arg.seed) +np.random.seed(arg.seed) +torch.manual_seed(arg.seed) + +n_gpu = torch.cuda.device_count() +if n_gpu > 0: + torch.cuda.manual_seed_all(arg.seed) + +# load data set +if arg.task == 'snli': + data_info = SNLILoader().process( + paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, + bert_tokenizer=arg.bert_dir, cut_text=512, + get_index=True, concat='bert', + ) +elif arg.task == 'rte': + data_info = RTELoader().process( + paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, + bert_tokenizer=arg.bert_dir, cut_text=512, + get_index=True, concat='bert', + ) +elif arg.task == 'qnli': + data_info = QNLILoader().process( + paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, + bert_tokenizer=arg.bert_dir, cut_text=512, + get_index=True, concat='bert', + ) +elif arg.task == 'mnli': + data_info = MNLILoader().process( + paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, + bert_tokenizer=arg.bert_dir, cut_text=512, + get_index=True, concat='bert', + ) +elif arg.task == 'quora': + data_info = QuoraLoader().process( + paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type, + bert_tokenizer=arg.bert_dir, cut_text=512, + get_index=True, concat='bert', + ) +else: + raise RuntimeError(f'NOT support {arg.task} task yet!') + +# define model +model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir) + +# define trainer +trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, + optimizer=Adam(lr=arg.lr, model_params=model.parameters()), + batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, + n_epochs=arg.n_epochs, print_every=-1, + dev_data=data_info.datasets[arg.dev_dataset_name], + metrics=AccuracyMetric(), metric_key='acc', + device=[i for i in range(torch.cuda.device_count())], + check_code_level=-1, + save_path=arg.save_path) + +# train model +trainer.train(load_best_model=True) + +# define tester +tester = Tester( + data=data_info.datasets[arg.test_dataset_name], + model=model, + metrics=AccuracyMetric(), + batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, + device=[i for i in range(torch.cuda.device_count())], +) + +# test model +tester.test() + + diff --git a/reproduction/matching/matching_esim.py b/reproduction/matching/matching_esim.py index 3da6141f..d878608f 100644 --- a/reproduction/matching/matching_esim.py +++ b/reproduction/matching/matching_esim.py @@ -1,47 +1,103 @@ -import argparse +import random +import numpy as np import torch +from torch.optim import Adamax +from torch.optim.lr_scheduler import StepLR -from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const +from fastNLP.core import Trainer, Tester, AccuracyMetric, Const +from fastNLP.core.callback import GradientClipCallback, LRScheduler from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding -from reproduction.matching.data.MatchingDataLoader import SNLILoader +from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ + MNLILoader, QNLILoader, QuoraLoader from reproduction.matching.model.esim import ESIMModel -argument = argparse.ArgumentParser() -argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove') -argument.add_argument('--batch-size-per-gpu', type=int, default=128) -argument.add_argument('--n-epochs', type=int, default=100) -argument.add_argument('--lr', type=float, default=1e-4) -argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len') -argument.add_argument('--save-dir', type=str, default=None) -arg = argument.parse_args() -bert_dirs = 'path/to/bert/dir' +# define hyper-parameters +class ESIMConfig: + + task = 'snli' + embedding = 'glove' + batch_size_per_gpu = 196 + n_epochs = 30 + lr = 2e-3 + seq_len_type = 'seq_len' + # seq_len表示在process的时候用len(words)来表示长度信息; + # mask表示用0/1掩码矩阵来表示长度信息; + seed = 42 + train_dataset_name = 'train' + dev_dataset_name = 'dev' + test_dataset_name = 'test' + save_path = None # 模型存储的位置,None表示不存储模型。 + + +arg = ESIMConfig() + +# set random seed +random.seed(arg.seed) +np.random.seed(arg.seed) +torch.manual_seed(arg.seed) + +n_gpu = torch.cuda.device_count() +if n_gpu > 0: + torch.cuda.manual_seed_all(arg.seed) # load data set -data_info = SNLILoader().process( - paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, - get_index=True, concat=False, -) +if arg.task == 'snli': + data_info = SNLILoader().process( + paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type, + get_index=True, concat=False, + ) +elif arg.task == 'rte': + data_info = RTELoader().process( + paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type, + get_index=True, concat=False, + ) +elif arg.task == 'qnli': + data_info = QNLILoader().process( + paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type, + get_index=True, concat=False, + ) +elif arg.task == 'mnli': + data_info = MNLILoader().process( + paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type, + get_index=True, concat=False, + ) +elif arg.task == 'quora': + data_info = QuoraLoader().process( + paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type, + get_index=True, concat=False, + ) +else: + raise RuntimeError(f'NOT support {arg.task} task yet!') # load embedding if arg.embedding == 'elmo': embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) elif arg.embedding == 'glove': - embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) + embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False) else: - raise ValueError(f'now we only support elmo or glove embedding for esim model!') + raise RuntimeError(f'NOT support {arg.embedding} embedding yet!') # define model -model = ESIMModel(embedding) +model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET])) + +# define optimizer and callback +optimizer = Adamax(lr=arg.lr, params=model.parameters()) +scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍 + +callbacks = [ + GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10) + LRScheduler(scheduler), +] # define trainer -trainer = Trainer(train_data=data_info.datasets['train'], model=model, - optimizer=Adam(lr=arg.lr, model_params=model.parameters()), +trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, + optimizer=optimizer, batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, n_epochs=arg.n_epochs, print_every=-1, - dev_data=data_info.datasets['dev'], + dev_data=data_info.datasets[arg.dev_dataset_name], metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1, @@ -52,7 +108,7 @@ trainer.train(load_best_model=True) # define tester tester = Tester( - data=data_info.datasets['test'], + data=data_info.datasets[arg.test_dataset_name], model=model, metrics=AccuracyMetric(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py index d55034e7..187e565d 100644 --- a/reproduction/matching/model/esim.py +++ b/reproduction/matching/model/esim.py @@ -81,6 +81,7 @@ class ESIMModel(BaseModel): out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] logits = torch.tanh(self.classifier(out)) + # logits = self.classifier(out) if target is not None: loss_fct = CrossEntropyLoss() @@ -91,7 +92,8 @@ class ESIMModel(BaseModel): return {Const.OUTPUT: logits} def predict(self, **kwargs): - return self.forward(**kwargs) + pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) + return {Const.OUTPUT: pred} # input [batch_size, len , hidden] # mask [batch_size, len] (111...00) @@ -127,7 +129,7 @@ class BiRNN(nn.Module): def forward(self, x, x_mask): # Sort x - lengths = x_mask.data.eq(1).long().sum(1).squeeze() + lengths = x_mask.data.eq(1).long().sum(1) _, idx_sort = torch.sort(lengths, dim=0, descending=True) _, idx_unsort = torch.sort(idx_sort, dim=0) lengths = list(lengths[idx_sort]) diff --git a/reproduction/seqence_labelling/ner/README.md b/reproduction/seqence_labelling/ner/README.md new file mode 100644 index 00000000..d42046b0 --- /dev/null +++ b/reproduction/seqence_labelling/ner/README.md @@ -0,0 +1,13 @@ +# NER任务模型复现 +这里使用fastNLP复现经典的BiLSTM-CNN的NER任务的模型,旨在达到与论文中相符的性能。 + +论文链接[Named Entity Recognition with Bidirectional LSTM-CNNs](https://arxiv.org/pdf/1511.08308.pdf) + +# 数据集及复现结果汇总 + +使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道) + +model name | Conll2003 | Ontonotes +:---: | :---: | :---: +BiLSTM-CNN | 91.17/90.91 | 86.47/86.35 | + diff --git a/reproduction/utils.py b/reproduction/utils.py index 26b2014c..4f0d021e 100644 --- a/reproduction/utils.py +++ b/reproduction/utils.py @@ -13,22 +13,30 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: } 如果paths为不合法的,将直接进行raise相应的错误 - :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt, - test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 + :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 + 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 :return: """ if isinstance(paths, str): if os.path.isfile(paths): return {'train': paths} elif os.path.isdir(paths): - train_fp = os.path.join(paths, 'train.txt') - if not os.path.isfile(train_fp): - raise FileNotFoundError(f"train.txt is not found in folder {paths}.") - files = {'train': train_fp} - for filename in ['dev.txt', 'test.txt']: - fp = os.path.join(paths, filename) - if os.path.isfile(fp): - files[filename.split('.')[0]] = fp + filenames = os.listdir(paths) + files = {} + for filename in filenames: + path_pair = None + if 'train' in filename: + path_pair = ('train', filename) + if 'dev' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) + path_pair = ('dev', filename) + if 'test' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) + path_pair = ('test', filename) + if path_pair: + files[path_pair[0]] = os.path.join(paths, path_pair[1]) return files else: raise FileNotFoundError(f"{paths} is not a valid file path.") diff --git a/requirements.txt b/requirements.txt index 7ea8fdac..f8f7a951 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -numpy -torch>=0.4.0 -tqdm -nltk +numpy>=1.14.2 +torch>=1.0.0 +tqdm>=4.28.1 +nltk>=3.4.1 requests diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py index 7cff3c12..09ad8c83 100644 --- a/test/io/test_dataset_loader.py +++ b/test/io/test_dataset_loader.py @@ -1,7 +1,7 @@ import unittest import os -from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader -from fastNLP.io.dataset_loader import SSTLoader +from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader +from fastNLP.io.data_loader import SSTLoader, SNLILoader from reproduction.text_classification.data.yelpLoader import yelpLoader @@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase): print(info.vocabs) print(info.datasets) os.remove(train), os.remove(test) + + def test_import(self): + import fastNLP + from fastNLP.io import SNLILoader + ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True, + get_index=True, seq_len_type='seq_len') + assert 'train' in ds.datasets + assert len(ds.datasets) == 1 + assert len(ds.datasets['train']) == 3 diff --git a/test/models/test_bert.py b/test/models/test_bert.py index 7177f31b..38a16f9b 100644 --- a/test/models/test_bert.py +++ b/test/models/test_bert.py @@ -8,8 +8,9 @@ from fastNLP.models.bert import * class TestBert(unittest.TestCase): def test_bert_1(self): from fastNLP.core.const import Const + from fastNLP.modules.encoder._bert import BertConfig - model = BertForSequenceClassification(2) + model = BertForSequenceClassification(2, BertConfig(32000)) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) @@ -22,8 +23,9 @@ class TestBert(unittest.TestCase): def test_bert_2(self): from fastNLP.core.const import Const + from fastNLP.modules.encoder._bert import BertConfig - model = BertForMultipleChoice(2) + model = BertForMultipleChoice(2, BertConfig(32000)) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) @@ -36,8 +38,9 @@ class TestBert(unittest.TestCase): def test_bert_3(self): from fastNLP.core.const import Const + from fastNLP.modules.encoder._bert import BertConfig - model = BertForTokenClassification(7) + model = BertForTokenClassification(7, BertConfig(32000)) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) @@ -50,8 +53,9 @@ class TestBert(unittest.TestCase): def test_bert_4(self): from fastNLP.core.const import Const + from fastNLP.modules.encoder._bert import BertConfig - model = BertForQuestionAnswering() + model = BertForQuestionAnswering(BertConfig(32000)) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) diff --git a/test/modules/encoder/test_bert.py b/test/modules/encoder/test_bert.py index 78bcf633..2a799478 100644 --- a/test/modules/encoder/test_bert.py +++ b/test/modules/encoder/test_bert.py @@ -8,8 +8,9 @@ from fastNLP.models.bert import BertModel class TestBert(unittest.TestCase): def test_bert_1(self): - model = BertModel(vocab_size=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + from fastNLP.modules.encoder._bert import BertConfig + config = BertConfig(32000) + model = BertModel(config) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) @@ -18,4 +19,4 @@ class TestBert(unittest.TestCase): all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) for layer in all_encoder_layers: self.assertEqual(tuple(layer.shape), (2, 3, 768)) - self.assertEqual(tuple(pooled_output.shape), (2, 768)) \ No newline at end of file + self.assertEqual(tuple(pooled_output.shape), (2, 768)) diff --git a/test/test_tutorials.py b/test/test_tutorials.py index 87910c3d..6f4a8347 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -79,7 +79,7 @@ class TestTutorial(unittest.TestCase): train_data.rename_field('label', 'label_seq') test_data.rename_field('label', 'label_seq') - loss = CrossEntropyLoss(pred="output", target="label_seq") + loss = CrossEntropyLoss(target="label_seq") metric = AccuracyMetric(target="label_seq") # 实例化Trainer,传入模型和数据,进行训练 @@ -91,7 +91,7 @@ class TestTutorial(unittest.TestCase): # 用train_data训练,在test_data验证 trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, - loss=CrossEntropyLoss(pred="output", target="label_seq"), + loss=CrossEntropyLoss(target="label_seq"), metrics=AccuracyMetric(target="label_seq"), save_path=None, batch_size=32,