diff --git a/README.md b/README.md
index 9d949482..a5ce3c64 100644
--- a/README.md
+++ b/README.md
@@ -6,13 +6,14 @@
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
-fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个命名实体识别(NER)、中文分词或文本分类任务; 也可以使用他构建许多复杂的网络模型,进行科研。它具有如下的特性:
+fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner/)、POS-Tagging等)、中文分词、文本分类、[Matching](reproduction/matching/)、指代消解、摘要等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性:
-- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码。
-- 各种方便的NLP工具,例如预处理embedding加载; 中间数据cache等;
-- 详尽的中文文档以供查阅;
+- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码;
+- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等;
+- 各种方便的NLP工具,例如预处理embedding加载(包括EMLo和BERT); 中间数据cache等;
+- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、教程以供查阅;
- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等;
-- 封装CNNText,Biaffine等模型可供直接使用;
+- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用; [详细链接](reproduction/)
- 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。
@@ -20,13 +21,14 @@ fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地
fastNLP 依赖如下包:
-+ numpy
-+ torch>=0.4.0
-+ tqdm
-+ nltk
++ numpy>=1.14.2
++ torch>=1.0.0
++ tqdm>=4.28.1
++ nltk>=3.4.1
++ requests
-其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 PyTorch 官网 。
-在依赖包安装完成的情况,您可以在命令行执行如下指令完成安装
+其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。
+在依赖包安装完成后,您可以在命令行执行如下指令完成安装
```shell
pip install fastNLP
@@ -77,8 +79,8 @@ fastNLP 在 modules 模块中内置了三种模块的诸多组件,可以帮助
fastNLP 为不同的 NLP 任务实现了许多完整的模型,它们都经过了训练和测试。
你可以在以下两个地方查看相关信息
-- [介绍](reproduction/)
-- [源码](fastNLP/models/)
+- [模型介绍](reproduction/)
+- [模型源码](fastNLP/models/)
## 项目结构
@@ -93,7 +95,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下:
fastNLP.core |
- 实现了核心功能,包括数据处理组件、训练器、测速器等 |
+ 实现了核心功能,包括数据处理组件、训练器、测试器等 |
fastNLP.models |
diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py
new file mode 100644
index 00000000..4a7757d3
--- /dev/null
+++ b/fastNLP/core/_parallel_utils.py
@@ -0,0 +1,88 @@
+
+import threading
+import torch
+from torch.nn.parallel.parallel_apply import get_a_var
+
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
+from torch.nn.parallel.replicate import replicate
+
+
+def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
+ r"""Applies each `module` in :attr:`modules` in parallel on arguments
+ contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
+ on each of :attr:`devices`.
+
+ :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
+ :attr:`devices` (if given) should all have same length. Moreover, each
+ element of :attr:`inputs` can either be a single object as the only argument
+ to a module, or a collection of positional arguments.
+ """
+ assert len(modules) == len(inputs)
+ if kwargs_tup is not None:
+ assert len(modules) == len(kwargs_tup)
+ else:
+ kwargs_tup = ({},) * len(modules)
+ if devices is not None:
+ assert len(modules) == len(devices)
+ else:
+ devices = [None] * len(modules)
+
+ lock = threading.Lock()
+ results = {}
+ grad_enabled = torch.is_grad_enabled()
+
+ def _worker(i, module, input, kwargs, device=None):
+ torch.set_grad_enabled(grad_enabled)
+ if device is None:
+ device = get_a_var(input).get_device()
+ try:
+ with torch.cuda.device(device):
+ # this also avoids accidental slicing of `input` if it is a Tensor
+ if not isinstance(input, (list, tuple)):
+ input = (input,)
+ output = getattr(module, func_name)(*input, **kwargs)
+ with lock:
+ results[i] = output
+ except Exception as e:
+ with lock:
+ results[i] = e
+
+ if len(modules) > 1:
+ threads = [threading.Thread(target=_worker,
+ args=(i, module, input, kwargs, device))
+ for i, (module, input, kwargs, device) in
+ enumerate(zip(modules, inputs, kwargs_tup, devices))]
+
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ else:
+ _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
+
+ outputs = []
+ for i in range(len(inputs)):
+ output = results[i]
+ if isinstance(output, Exception):
+ raise output
+ outputs.append(output)
+ return outputs
+
+
+def _data_parallel_wrapper(func_name, device_ids, output_device):
+ """
+ 这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数
+
+ :param str, func_name: 对network中的这个函数进行多卡运行
+ :param device_ids: nn.DataParallel中的device_ids
+ :param output_device: nn.DataParallel中的output_device
+ :return:
+ """
+ def wrapper(network, *inputs, **kwargs):
+ inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
+ if len(device_ids) == 1:
+ return getattr(network, func_name)(*inputs[0], **kwargs[0])
+ replicas = replicate(network, device_ids[:len(inputs)])
+ outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
+ return gather(outputs, output_device)
+ return wrapper
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 5dfd889b..0b1890f8 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -113,7 +113,7 @@ class Callback(object):
@property
def n_steps(self):
- """Trainer一共会运行多少步"""
+ """Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数"""
return self._trainer.n_steps
@property
@@ -181,7 +181,7 @@ class Callback(object):
:param dict batch_x: DataSet中被设置为input的field的batch。
:param dict batch_y: DataSet中被设置为target的field的batch。
:param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些
- 情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。
+ 情况下可以帮助定位是哪个Sample导致了错误。仅当num_workers=0时有效。
:return:
"""
pass
@@ -399,10 +399,11 @@ class GradientClipCallback(Callback):
self.clip_value = clip_value
def on_backward_end(self):
- if self.parameters is None:
- self.clip_fun(self.model.parameters(), self.clip_value)
- else:
- self.clip_fun(self.parameters, self.clip_value)
+ if self.step%self.update_every==0:
+ if self.parameters is None:
+ self.clip_fun(self.model.parameters(), self.clip_value)
+ else:
+ self.clip_fun(self.parameters, self.clip_value)
class EarlyStopCallback(Callback):
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 526bf37a..14aacef0 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -20,6 +20,7 @@ from collections import defaultdict
import torch
import torch.nn.functional as F
+from ..core.const import Const
from .utils import _CheckError
from .utils import _CheckRes
from .utils import _build_args
@@ -28,6 +29,7 @@ from .utils import _check_function_or_method
from .utils import _get_func_signature
from .utils import seq_len_to_mask
+
class LossBase(object):
"""
所有loss的基类。如果想了解其中的原理,请查看源码。
@@ -95,22 +97,7 @@ class LossBase(object):
# if func_spect.varargs:
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
# f"positional argument.).")
-
- def _fast_param_map(self, pred_dict, target_dict):
- """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
- such as pred_dict has one element, target_dict has one element
- :param pred_dict:
- :param target_dict:
- :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
- """
- fast_param = {}
- if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
- fast_param['pred'] = list(pred_dict.values())[0]
- fast_param['target'] = list(target_dict.values())[0]
- return fast_param
- return fast_param
-
def __call__(self, pred_dict, target_dict, check=False):
"""
:param dict pred_dict: 模型的forward函数返回的dict
@@ -118,11 +105,7 @@ class LossBase(object):
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查
:return:
"""
- fast_param = self._fast_param_map(pred_dict, target_dict)
- if fast_param:
- loss = self.get_loss(**fast_param)
- return loss
-
+
if not self._checked:
# 1. check consistence between signature and _param_map
func_spect = inspect.getfullargspec(self.get_loss)
@@ -212,7 +195,6 @@ class LossFunc(LossBase):
if not isinstance(key_map, dict):
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
self._init_param_map(key_map, **kwargs)
-
class CrossEntropyLoss(LossBase):
@@ -226,6 +208,7 @@ class CrossEntropyLoss(LossBase):
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
传入seq_len.
+ :param str reduction: 支持'mean','sum'和'none'.
Example::
@@ -233,21 +216,25 @@ class CrossEntropyLoss(LossBase):
"""
- def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100):
+ def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.padding_idx = padding_idx
+ assert reduction in ('mean', 'sum', 'none')
+ self.reduction = reduction
def get_loss(self, pred, target, seq_len=None):
- if pred.dim()>2:
- pred = pred.view(-1, pred.size(-1))
- target = target.view(-1)
+ if pred.dim() > 2:
+ if pred.size(1) != target.size(1):
+ pred = pred.transpose(1, 2)
+ pred = pred.reshape(-1, pred.size(-1))
+ target = target.reshape(-1)
if seq_len is not None:
- mask = seq_len_to_mask(seq_len).view(-1).eq(0)
+ mask = seq_len_to_mask(seq_len).reshape(-1).eq(0)
target = target.masked_fill(mask, self.padding_idx)
return F.cross_entropy(input=pred, target=target,
- ignore_index=self.padding_idx)
+ ignore_index=self.padding_idx, reduction=self.reduction)
class L1Loss(LossBase):
@@ -258,15 +245,18 @@ class L1Loss(LossBase):
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target`
+ :param str reduction: 支持'mean','sum'和'none'.
"""
- def __init__(self, pred=None, target=None):
+ def __init__(self, pred=None, target=None, reduction='mean'):
super(L1Loss, self).__init__()
self._init_param_map(pred=pred, target=target)
+ assert reduction in ('mean', 'sum', 'none')
+ self.reduction = reduction
def get_loss(self, pred, target):
- return F.l1_loss(input=pred, target=target)
+ return F.l1_loss(input=pred, target=target, reduction=self.reduction)
class BCELoss(LossBase):
@@ -277,14 +267,17 @@ class BCELoss(LossBase):
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
+ :param str reduction: 支持'mean','sum'和'none'.
"""
- def __init__(self, pred=None, target=None):
+ def __init__(self, pred=None, target=None, reduction='mean'):
super(BCELoss, self).__init__()
self._init_param_map(pred=pred, target=target)
+ assert reduction in ('mean', 'sum', 'none')
+ self.reduction = reduction
def get_loss(self, pred, target):
- return F.binary_cross_entropy(input=pred, target=target)
+ return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction)
class NLLLoss(LossBase):
@@ -295,14 +288,20 @@ class NLLLoss(LossBase):
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
+ :param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替
+ 传入seq_len.
+ :param str reduction: 支持'mean','sum'和'none'.
"""
- def __init__(self, pred=None, target=None):
+ def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'):
super(NLLLoss, self).__init__()
self._init_param_map(pred=pred, target=target)
+ assert reduction in ('mean', 'sum', 'none')
+ self.reduction = reduction
+ self.ignore_idx = ignore_idx
def get_loss(self, pred, target):
- return F.nll_loss(input=pred, target=target)
+ return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction)
class LossInForward(LossBase):
@@ -314,7 +313,7 @@ class LossInForward(LossBase):
:param str loss_key: 在forward函数中loss的键名,默认为loss
"""
- def __init__(self, loss_key='loss'):
+ def __init__(self, loss_key=Const.LOSS):
super().__init__()
if not isinstance(loss_key, str):
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index 0849b35d..1fe035bf 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -9,6 +9,9 @@ __all__ = [
]
import torch
+import math
+import torch
+from torch.optim.optimizer import Optimizer as TorchOptimizer
class Optimizer(object):
@@ -97,3 +100,110 @@ class Adam(Optimizer):
return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings)
else:
return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings)
+
+
+class AdamW(TorchOptimizer):
+ r"""对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.99))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ (default: False)
+ .. _Adam\: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=1e-2, amsgrad=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad)
+ super(AdamW, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(AdamW, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ # Perform stepweight decay
+ p.data.mul_(1 - group['lr'] * group['weight_decay'])
+
+ # Perform optimization step
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+ amsgrad = group['amsgrad']
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ if amsgrad:
+ max_exp_avg_sq = state['max_exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ state['step'] += 1
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max. for normalizing running avg. of gradient
+ denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+ else:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+ step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+ p.data.addcdiv_(-step_size, exp_avg, denom)
+
+ return loss
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index ce016bb6..2d6a7380 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device
class Predictor(object):
"""
- An interface for predicting outputs based on trained models.
+ 一个根据训练模型预测输出的预测器(Predictor)
- It does not care about evaluations of the model, which is different from Tester.
- This is a high-level model wrapper to be called by FastNLP.
- This class does not share any operations with Trainer and Tester.
- Currently, Predictor does not support GPU.
+ 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。
+ 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。
+
+ :param torch.nn.Module network: 用来完成预测任务的模型
"""
def __init__(self, network):
@@ -30,18 +30,19 @@ class Predictor(object):
self.batch_size = 1
self.batch_output = []
- def predict(self, data, seq_len_field_name=None):
- """Perform inference using the trained model.
+ def predict(self, data: DataSet, seq_len_field_name=None):
+ """用已经训练好的模型进行inference.
- :param data: a DataSet object.
- :param str seq_len_field_name: field name indicating sequence lengths
- :return: list of batch outputs
+ :param fastNLP.DataSet data: 待预测的数据集
+ :param str seq_len_field_name: 表示序列长度信息的field名字
+ :return: dict dict里面的内容为模型预测的结果
"""
if not isinstance(data, DataSet):
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
+ prev_training = self.network.training
self.network.eval()
network_device = _get_model_device(self.network)
batch_output = defaultdict(list)
@@ -74,4 +75,5 @@ class Predictor(object):
else:
batch_output[key].append(value)
+ self.network.train(prev_training)
return batch_output
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 4cdd4ffb..7048d0ae 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -32,8 +32,6 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation
"""
-import warnings
-
import torch
import torch.nn as nn
@@ -48,6 +46,8 @@ from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device
+from ._parallel_utils import _data_parallel_wrapper
+from functools import partial
__all__ = [
"Tester"
@@ -104,26 +104,27 @@ class Tester(object):
self.data_iterator = data
else:
raise TypeError("data type {} not support".format(type(data)))
-
- # 如果是DataParallel将没有办法使用predict方法
- if isinstance(self._model, nn.DataParallel):
- if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'):
- warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function,"
- " while DataParallel has no predict() function.")
- self._model = self._model.module
-
+
# check predict
- if hasattr(self._model, 'predict'):
- self._predict_func = self._model.predict
- if not callable(self._predict_func):
- _model_name = model.__class__.__name__
- raise TypeError(f"`{_model_name}.predict` must be callable to be used "
- f"for evaluation, not `{type(self._predict_func)}`.")
+ if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \
+ (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and
+ callable(self._model.module.predict)):
+ if isinstance(self._model, nn.DataParallel):
+ self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
+ self._model.device_ids,
+ self._model.output_device),
+ network=self._model.module)
+ self._predict_func = self._model.module.predict
+ else:
+ self._predict_func = self._model.predict
+ self._predict_func_wrapper = self._model.predict
else:
- if isinstance(model, nn.DataParallel):
+ if isinstance(self._model, nn.DataParallel):
+ self._predict_func_wrapper = self._model.forward
self._predict_func = self._model.module.forward
else:
self._predict_func = self._model.forward
+ self._predict_func_wrapper = self._model.forward
def test(self):
"""开始进行验证,并返回验证结果。
@@ -180,7 +181,7 @@ class Tester(object):
def _data_forward(self, func, x):
"""A forward pass of the model. """
x = _build_args(func, **x)
- y = func(**x)
+ y = self._predict_func_wrapper(**x)
return y
def _format_eval_results(self, results):
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 6edeb4a0..eabda99c 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -454,7 +454,7 @@ class Trainer(object):
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
- metric_key=metric_key, check_level=check_code_level,
+ metric_key=self.metric_key, check_level=check_code_level,
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码
self.model = _move_model_to_device(model, device=device)
@@ -473,7 +473,7 @@ class Trainer(object):
self.best_dev_step = None
self.best_dev_perf = None
self.n_steps = (len(self.train_data) // self.batch_size + int(
- len(self.train_data) % self.batch_size != 0)) * self.n_epochs
+ len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs
if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index d26df966..490f9f8f 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -17,7 +17,6 @@ import numpy as np
import torch
import torch.nn as nn
-
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'])
@@ -278,6 +277,7 @@ def _move_model_to_device(model, device):
return model
+
def _get_model_device(model):
"""
传入一个nn.Module的模型,获取它所在的device
diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index 28f466a8..05d75f43 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -11,21 +11,35 @@
"""
__all__ = [
'EmbedLoader',
-
+
+ 'DataInfo',
'DataSetLoader',
+
'CSVLoader',
'JsonLoader',
'ConllLoader',
- 'SNLILoader',
- 'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
'ModelLoader',
'ModelSaver',
+
+ 'SSTLoader',
+
+ 'MatchingLoader',
+ 'SNLILoader',
+ 'MNLILoader',
+ 'QNLILoader',
+ 'QuoraLoader',
+ 'RTELoader',
]
from .embed_loader import EmbedLoader
-from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \
- SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
+from .base_loader import DataInfo, DataSetLoader
+from .dataset_loader import CSVLoader, JsonLoader, ConllLoader, \
+ PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver
+
+from .data_loader.sst import SSTLoader
+from .data_loader.matching import MatchingLoader, SNLILoader, \
+ MNLILoader, QNLILoader, QuoraLoader, RTELoader
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
index 465fb7e8..8cff1da1 100644
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -10,6 +10,7 @@ from typing import Union, Dict
import os
from ..core.dataset import DataSet
+
class BaseLoader(object):
"""
各个 Loader 的基类,提供了 API 的参考。
@@ -55,8 +56,6 @@ class BaseLoader(object):
return obj
-
-
def _download_from_url(url, path):
try:
from tqdm.auto import tqdm
@@ -115,13 +114,11 @@ class DataInfo:
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
- :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
"""
- def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
+ def __init__(self, vocabs: dict = None, datasets: dict = None):
self.vocabs = vocabs or {}
- self.embeddings = embeddings or {}
self.datasets = datasets or {}
def __repr__(self):
@@ -133,6 +130,7 @@ class DataInfo:
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str
+
class DataSetLoader:
"""
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
@@ -213,7 +211,6 @@ class DataSetLoader:
返回的 :class:`DataInfo` 对象有如下属性:
- vocabs: 由从数据集中获取的词表组成的字典,每个词表
- - embeddings: (可选) 数据集对应的词嵌入
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
:param paths: 原始数据读取的路径
diff --git a/fastNLP/io/data_loader/__init__.py b/fastNLP/io/data_loader/__init__.py
new file mode 100644
index 00000000..6f4dd973
--- /dev/null
+++ b/fastNLP/io/data_loader/__init__.py
@@ -0,0 +1,19 @@
+"""
+用于读数据集的模块, 具体包括:
+
+这些模块的使用方法如下:
+"""
+__all__ = [
+ 'SSTLoader',
+
+ 'MatchingLoader',
+ 'SNLILoader',
+ 'MNLILoader',
+ 'QNLILoader',
+ 'QuoraLoader',
+ 'RTELoader',
+]
+
+from .sst import SSTLoader
+from .matching import MatchingLoader, SNLILoader, \
+ MNLILoader, QNLILoader, QuoraLoader, RTELoader
diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py
new file mode 100644
index 00000000..3d131bcb
--- /dev/null
+++ b/fastNLP/io/data_loader/matching.py
@@ -0,0 +1,430 @@
+import os
+
+from typing import Union, Dict
+
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
+from ..base_loader import DataInfo, DataSetLoader
+from ..dataset_loader import JsonLoader, CSVLoader
+from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
+from ...modules.encoder._bert import BertTokenizer
+
+
+class MatchingLoader(DataSetLoader):
+ """
+ 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+
+ 读取Matching任务的数据集
+
+ :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
+ """
+
+ def __init__(self, paths: dict=None):
+ self.paths = paths
+
+ def _load(self, path):
+ """
+ :param str path: 待读取数据集的路径名
+ :return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子
+ 的原始字符串文本,第三个为标签
+ """
+ raise NotImplementedError
+
+ def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
+ to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
+ cut_text: int = None, get_index=True, auto_pad_length: int=None,
+ auto_pad_token: str='', set_input: Union[list, str, bool]=True,
+ set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+ """
+ :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
+ 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
+ 对应的全路径文件名。
+ :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义
+ 这个数据集的名字,如果不定义则默认为train。
+ :param bool to_lower: 是否将文本自动转为小写。默认值为False。
+ :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` :
+ 提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和
+ attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len
+ :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径
+ :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。
+ :param bool get_index: 是否需要根据词表将文本转为index
+ :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad
+ :param str auto_pad_token: 自动pad的内容
+ :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False
+ 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input,
+ 于此同时其他field不会被设置为input。默认值为True。
+ :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。
+ :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个。
+ 如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果
+ 传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]'].
+ :return:
+ """
+ if isinstance(set_input, str):
+ set_input = [set_input]
+ if isinstance(set_target, str):
+ set_target = [set_target]
+ if isinstance(set_input, bool):
+ auto_set_input = set_input
+ else:
+ auto_set_input = False
+ if isinstance(set_target, bool):
+ auto_set_target = set_target
+ else:
+ auto_set_target = False
+ if isinstance(paths, str):
+ if os.path.isdir(paths):
+ path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
+ else:
+ path = {dataset_name if dataset_name is not None else 'train': paths}
+ else:
+ path = paths
+
+ data_info = DataInfo()
+ for data_name in path.keys():
+ data_info.datasets[data_name] = self._load(path[data_name])
+
+ for data_name, data_set in data_info.datasets.items():
+ if auto_set_input:
+ data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
+ if auto_set_target:
+ if Const.TARGET in data_set.get_field_names():
+ data_set.set_target(Const.TARGET)
+
+ if to_lower:
+ for data_name, data_set in data_info.datasets.items():
+ data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
+ is_input=auto_set_input)
+ data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
+ is_input=auto_set_input)
+
+ if bert_tokenizer is not None:
+ if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
+ PRETRAIN_URL = _get_base_url('bert')
+ model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
+ model_url = PRETRAIN_URL + model_name
+ model_dir = cached_path(model_url)
+ # 检查是否存在
+ elif os.path.isdir(bert_tokenizer):
+ model_dir = bert_tokenizer
+ else:
+ raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
+
+ words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
+ with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
+ lines = f.readlines()
+ lines = [line.strip() for line in lines]
+ words_vocab.add_word_lst(lines)
+ words_vocab.build_vocab()
+
+ tokenizer = BertTokenizer.from_pretrained(model_dir)
+
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
+ is_input=auto_set_input)
+
+ if isinstance(concat, bool):
+ concat = 'default' if concat else None
+ if concat is not None:
+ if isinstance(concat, str):
+ CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
+ 'default': ['', '', '', '']}
+ if concat.lower() in CONCAT_MAP:
+ concat = CONCAT_MAP[concat]
+ else:
+ concat = 4 * [concat]
+ assert len(concat) == 4, \
+ f'Please choose a list with 4 symbols which at the beginning of first sentence ' \
+ f'the end of first sentence, the begin of second sentence, and the end of second' \
+ f'sentence. Your input is {concat}'
+
+ for data_name, data_set in data_info.datasets.items():
+ data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
+ x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
+ data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
+ is_input=auto_set_input)
+
+ if seq_len_type is not None:
+ if seq_len_type == 'seq_len': #
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: len(x[fields]),
+ new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+ is_input=auto_set_input)
+ elif seq_len_type == 'mask':
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: [1] * len(x[fields]),
+ new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+ is_input=auto_set_input)
+ elif seq_len_type == 'bert':
+ for data_name, data_set in data_info.datasets.items():
+ if Const.INPUT not in data_set.get_field_names():
+ raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
+ f'got {data_set.get_field_names()}')
+ data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+ new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
+ data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
+ new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
+
+ if auto_pad_length is not None:
+ cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
+
+ if cut_text is not None:
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
+ data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
+ is_input=auto_set_input)
+
+ data_set_list = [d for n, d in data_info.datasets.items()]
+ assert len(data_set_list) > 0, f'There are NO data sets in data info!'
+
+ if bert_tokenizer is None:
+ words_vocab = Vocabulary(padding=auto_pad_token)
+ words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+ field_name=[n for n in data_set_list[0].get_field_names()
+ if (Const.INPUT in n)],
+ no_create_entry_dataset=[d for n, d in data_info.datasets.items()
+ if 'train' not in n])
+ target_vocab = Vocabulary(padding=None, unknown=None)
+ target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+ field_name=Const.TARGET)
+ data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
+
+ if get_index:
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
+ is_input=auto_set_input)
+
+ if Const.TARGET in data_set.get_field_names():
+ data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
+ is_input=auto_set_input, is_target=auto_set_target)
+
+ if auto_pad_length is not None:
+ if seq_len_type == 'seq_len':
+ raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
+ f'so the seq_len_type cannot be `{seq_len_type}`!')
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
+ (auto_pad_length - len(x[fields])), new_field_name=fields,
+ is_input=auto_set_input)
+ elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
+ data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
+ new_field_name=fields, is_input=auto_set_input)
+
+ for data_name, data_set in data_info.datasets.items():
+ if isinstance(set_input, list):
+ data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
+ if isinstance(set_target, list):
+ data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
+
+ return data_info
+
+
+class SNLILoader(MatchingLoader, JsonLoader):
+ """
+ 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
+
+ 读取SNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+ """
+
+ def __init__(self, paths: dict=None):
+ fields = {
+ 'sentence1_binary_parse': Const.INPUTS(0),
+ 'sentence2_binary_parse': Const.INPUTS(1),
+ 'gold_label': Const.TARGET,
+ }
+ paths = paths if paths is not None else {
+ 'train': 'snli_1.0_train.jsonl',
+ 'dev': 'snli_1.0_dev.jsonl',
+ 'test': 'snli_1.0_test.jsonl'}
+ MatchingLoader.__init__(self, paths=paths)
+ JsonLoader.__init__(self, fields=fields)
+
+ def _load(self, path):
+ ds = JsonLoader._load(self, path)
+
+ parentheses_table = str.maketrans({'(': None, ')': None})
+
+ ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(0))
+ ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(1))
+ ds.drop(lambda x: x[Const.TARGET] == '-')
+ return ds
+
+
+class RTELoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
+
+ 读取RTE数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev': 'dev.tsv',
+ 'test': 'test.tsv' # test set has not label
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ self.fields = {
+ 'sentence1': Const.INPUTS(0),
+ 'sentence2': Const.INPUTS(1),
+ 'label': Const.TARGET,
+ }
+ CSVLoader.__init__(self, sep='\t')
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+
+ for k, v in self.fields.items():
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
+ for fields in ds.get_all_fields():
+ if Const.INPUT in fields:
+ ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+ return ds
+
+
+class QNLILoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
+
+ 读取QNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev': 'dev.tsv',
+ 'test': 'test.tsv' # test set has not label
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ self.fields = {
+ 'question': Const.INPUTS(0),
+ 'sentence': Const.INPUTS(1),
+ 'label': Const.TARGET,
+ }
+ CSVLoader.__init__(self, sep='\t')
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+
+ for k, v in self.fields.items():
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
+ for fields in ds.get_all_fields():
+ if Const.INPUT in fields:
+ ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+ return ds
+
+
+class MNLILoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
+
+ 读取MNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev_matched': 'dev_matched.tsv',
+ 'dev_mismatched': 'dev_mismatched.tsv',
+ 'test_matched': 'test_matched.tsv',
+ 'test_mismatched': 'test_mismatched.tsv',
+ # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
+ # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+
+ # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle)
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ CSVLoader.__init__(self, sep='\t')
+ self.fields = {
+ 'sentence1_binary_parse': Const.INPUTS(0),
+ 'sentence2_binary_parse': Const.INPUTS(1),
+ 'gold_label': Const.TARGET,
+ }
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+
+ for k, v in self.fields.items():
+ if k in ds.get_field_names():
+ ds.rename_field(k, v)
+
+ if Const.TARGET in ds.get_field_names():
+ if ds[0][Const.TARGET] == 'hidden':
+ ds.delete_field(Const.TARGET)
+
+ parentheses_table = str.maketrans({'(': None, ')': None})
+
+ ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(0))
+ ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(1))
+ if Const.TARGET in ds.get_field_names():
+ ds.drop(lambda x: x[Const.TARGET] == '-')
+ return ds
+
+
+class QuoraLoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
+
+ 读取MNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev': 'dev.tsv',
+ 'test': 'test.tsv',
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+ return ds
diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py
index 1e1b8bef..021a79b7 100644
--- a/fastNLP/io/data_loader/sst.py
+++ b/fastNLP/io/data_loader/sst.py
@@ -1,11 +1,14 @@
from typing import Iterable
from nltk import Tree
+import spacy
from ..base_loader import DataInfo, DataSetLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..embed_loader import EmbeddingOption, EmbedLoader
+spacy.prefer_gpu()
+sptk = spacy.load('en')
class SSTLoader(DataSetLoader):
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
@@ -56,8 +59,8 @@ class SSTLoader(DataSetLoader):
def _get_one(data, subtree):
tree = Tree.fromstring(data)
if subtree:
- return [(t.leaves(), t.label()) for t in tree.subtrees()]
- return [(tree.leaves(), tree.label())]
+ return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ]
+ return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())]
def process(self,
paths,
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 558fe20e..2881e6e9 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -16,8 +16,6 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
- 'SNLILoader',
- 'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
]
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader, DataInfo
-from .data_loader.sst import SSTLoader
from ..core.const import Const
from ..modules.encoder._bert import BertTokenizer
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
else:
instance = Instance(words=sent_words)
data_set.append(instance)
- data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
+ data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
return data_set
@@ -249,42 +246,6 @@ class JsonLoader(DataSetLoader):
return ds
-class SNLILoader(JsonLoader):
- """
- 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
-
- 读取SNLI数据集,读取的DataSet包含fields::
-
- words1: list(str),第一句文本, premise
- words2: list(str), 第二句文本, hypothesis
- target: str, 真实标签
-
- 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
- """
-
- def __init__(self):
- fields = {
- 'sentence1_parse': Const.INPUTS(0),
- 'sentence2_parse': Const.INPUTS(1),
- 'gold_label': Const.TARGET,
- }
- super(SNLILoader, self).__init__(fields=fields)
-
- def _load(self, path):
- ds = super(SNLILoader, self)._load(path)
-
- def parse_tree(x):
- t = Tree.fromstring(x)
- return t.leaves()
-
- ds.apply(lambda ins: parse_tree(
- ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
- ds.apply(lambda ins: parse_tree(
- ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
- ds.drop(lambda x: x[Const.TARGET] == '-')
- return ds
-
-
class CSVLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`
diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py
index 34b5d7c0..0ae0a319 100644
--- a/fastNLP/io/file_reader.py
+++ b/fastNLP/io/file_reader.py
@@ -104,7 +104,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
continue
- raise ValueError('invalid instance at line: {}'.format(line_idx))
+ raise ValueError('invalid instance ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
@@ -117,5 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
- print('invalid instance at line: {}'.format(line_idx))
+ print('invalid instance ends at line: {}'.format(line_idx))
raise e
diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index 4846c7fa..fb186ce4 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -8,35 +8,7 @@ from torch import nn
from .base_model import BaseModel
from ..core.const import Const
from ..modules.encoder import BertModel
-
-
-class BertConfig:
-
- def __init__(
- self,
- vocab_size=30522,
- hidden_size=768,
- num_hidden_layers=12,
- num_attention_heads=12,
- intermediate_size=3072,
- hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=512,
- type_vocab_size=2,
- initializer_range=0.02
- ):
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.intermediate_size = intermediate_size
- self.hidden_act = hidden_act
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
- self.max_position_embeddings = max_position_embeddings
- self.type_vocab_size = type_vocab_size
- self.initializer_range = initializer_range
+from ..modules.encoder._bert import BertConfig
class BertForSequenceClassification(BaseModel):
@@ -84,11 +56,17 @@ class BertForSequenceClassification(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
+ @classmethod
+ def from_pretrained(cls, num_labels, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
@@ -151,11 +129,17 @@ class BertForMultipleChoice(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
+ @classmethod
+ def from_pretrained(cls, num_choices, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
@@ -224,11 +208,17 @@ class BertForTokenClassification(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
+ @classmethod
+ def from_pretrained(cls, num_labels, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
sequence_output = self.dropout(sequence_output)
@@ -302,12 +292,18 @@ class BertForQuestionAnswering(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
+ @classmethod
+ def from_pretrained(cls, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
logits = self.qa_outputs(sequence_output)
diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py
index 4c944a54..1aba5a8c 100644
--- a/fastNLP/models/star_transformer.py
+++ b/fastNLP/models/star_transformer.py
@@ -46,7 +46,7 @@ class StarTransEnc(nn.Module):
super(StarTransEnc, self).__init__()
self.embedding = get_embeddings(init_embed)
emb_dim = self.embedding.embedding_dim
- self.emb_fc = nn.Linear(emb_dim, hidden_size)
+ #self.emb_fc = nn.Linear(emb_dim, hidden_size)
self.emb_drop = nn.Dropout(emb_dropout)
self.encoder = StarTransformer(hidden_size=hidden_size,
num_layers=num_layers,
@@ -65,7 +65,7 @@ class StarTransEnc(nn.Module):
[batch, hidden] 全局 relay 节点, 详见论文
"""
x = self.embedding(x)
- x = self.emb_fc(self.emb_drop(x))
+ #x = self.emb_fc(self.emb_drop(x))
nodes, relay = self.encoder(x, mask)
return nodes, relay
@@ -205,7 +205,7 @@ class STSeqCls(nn.Module):
max_len=max_len,
emb_dropout=emb_dropout,
dropout=dropout)
- self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)
+ self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout)
def forward(self, words, seq_len):
"""
diff --git a/fastNLP/modules/decoder/mlp.py b/fastNLP/modules/decoder/mlp.py
index c1579224..418b3a77 100644
--- a/fastNLP/modules/decoder/mlp.py
+++ b/fastNLP/modules/decoder/mlp.py
@@ -15,7 +15,8 @@ class MLP(nn.Module):
多层感知器
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1
- :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和sigmoid,默认值为relu
+ :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和
+ sigmoid,默认值为relu
:param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数
:param str initial_method: 参数初始化方式
:param float dropout: dropout概率,默认值为0
diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py
index 254917e5..61a5d7d1 100644
--- a/fastNLP/modules/encoder/_bert.py
+++ b/fastNLP/modules/encoder/_bert.py
@@ -2,7 +2,8 @@
"""
-这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码
+这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你
+ 有用,也请引用一下他们。
"""
@@ -11,7 +12,6 @@ from ...core.vocabulary import Vocabulary
import collections
import unicodedata
-from ...io.file_utils import _get_base_url, cached_path
import numpy as np
from itertools import chain
import copy
@@ -22,9 +22,106 @@ import os
import torch
from torch import nn
import glob
+import sys
CONFIG_FILE = 'bert_config.json'
-MODEL_WEIGHTS = 'pytorch_model.bin'
+
+
+class BertConfig(object):
+ """Configuration class to store the configuration of a `BertModel`.
+ """
+ def __init__(self,
+ vocab_size_or_config_json_file,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12):
+ """Constructs BertConfig.
+
+ Args:
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
+ hidden_size: Size of the encoder layers and the pooler layer.
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
+ num_attention_heads: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
+ layer in the Transformer encoder.
+ hidden_act: The non-linear activation function (function or string) in the
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
+ hidden_dropout_prob: The dropout probabilitiy for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob: The dropout ratio for the attention
+ probabilities.
+ max_position_embeddings: The maximum sequence length that this model might
+ ever be used with. Typically set this to something large just in case
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
+ `BertModel`.
+ initializer_range: The sttdev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ layer_norm_eps: The epsilon used by LayerNorm.
+ """
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
+ and isinstance(vocab_size_or_config_json_file, unicode)):
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
+ json_config = json.loads(reader.read())
+ for key, value in json_config.items():
+ self.__dict__[key] = value
+ elif isinstance(vocab_size_or_config_json_file, int):
+ self.vocab_size = vocab_size_or_config_json_file
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ else:
+ raise ValueError("First argument must be either a vocabulary size (int)"
+ "or the path to a pretrained model config file (str)")
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
+ config = BertConfig(vocab_size_or_config_json_file=-1)
+ for key, value in json_object.items():
+ config.__dict__[key] = value
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `BertConfig` from a json file of parameters."""
+ with open(json_file, "r", encoding='utf-8') as reader:
+ text = reader.read()
+ return cls.from_dict(json.loads(text))
+
+ def __repr__(self):
+ return str(self.to_json_string())
+
+ def to_dict(self):
+ """Serializes this instance to a Python dictionary."""
+ output = copy.deepcopy(self.__dict__)
+ return output
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path):
+ """ Save this instance to a json file."""
+ with open(json_file_path, "w", encoding='utf-8') as writer:
+ writer.write(self.to_json_string())
def gelu(x):
@@ -40,6 +137,8 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
+ """
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
@@ -53,16 +152,18 @@ class BertLayerNorm(nn.Module):
class BertEmbeddings(nn.Module):
- def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob):
+ """Construct the embeddings from word, position and token_type embeddings.
+ """
+ def __init__(self, config):
super(BertEmbeddings, self).__init__()
- self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
- self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
- self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
- self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
- self.dropout = nn.Dropout(hidden_dropout_prob)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
@@ -82,21 +183,21 @@ class BertEmbeddings(nn.Module):
class BertSelfAttention(nn.Module):
- def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
+ def __init__(self, config):
super(BertSelfAttention, self).__init__()
- if hidden_size % num_attention_heads != 0:
+ if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (hidden_size, num_attention_heads))
- self.num_attention_heads = num_attention_heads
- self.attention_head_size = int(hidden_size / num_attention_heads)
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.query = nn.Linear(hidden_size, self.all_head_size)
- self.key = nn.Linear(hidden_size, self.all_head_size)
- self.value = nn.Linear(hidden_size, self.all_head_size)
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
- self.dropout = nn.Dropout(attention_probs_dropout_prob)
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -133,11 +234,11 @@ class BertSelfAttention(nn.Module):
class BertSelfOutput(nn.Module):
- def __init__(self, hidden_size, hidden_dropout_prob):
+ def __init__(self, config):
super(BertSelfOutput, self).__init__()
- self.dense = nn.Linear(hidden_size, hidden_size)
- self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
- self.dropout = nn.Dropout(hidden_dropout_prob)
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
@@ -147,10 +248,10 @@ class BertSelfOutput(nn.Module):
class BertAttention(nn.Module):
- def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
+ def __init__(self, config):
super(BertAttention, self).__init__()
- self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
- self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
@@ -159,11 +260,13 @@ class BertAttention(nn.Module):
class BertIntermediate(nn.Module):
- def __init__(self, hidden_size, intermediate_size, hidden_act):
+ def __init__(self, config):
super(BertIntermediate, self).__init__()
- self.dense = nn.Linear(hidden_size, intermediate_size)
- self.intermediate_act_fn = ACT2FN[hidden_act] \
- if isinstance(hidden_act, str) else hidden_act
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -172,11 +275,11 @@ class BertIntermediate(nn.Module):
class BertOutput(nn.Module):
- def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
+ def __init__(self, config):
super(BertOutput, self).__init__()
- self.dense = nn.Linear(intermediate_size, hidden_size)
- self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
- self.dropout = nn.Dropout(hidden_dropout_prob)
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
@@ -186,13 +289,11 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):
- def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
- intermediate_size, hidden_act):
+ def __init__(self, config):
super(BertLayer, self).__init__()
- self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
- hidden_dropout_prob)
- self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act)
- self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
+ self.attention = BertAttention(config)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
@@ -202,13 +303,10 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):
- def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob,
- hidden_dropout_prob,
- intermediate_size, hidden_act):
+ def __init__(self, config):
super(BertEncoder, self).__init__()
- layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
- intermediate_size, hidden_act)
- self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])
+ layer = BertLayer(config)
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
@@ -222,9 +320,9 @@ class BertEncoder(nn.Module):
class BertPooler(nn.Module):
- def __init__(self, hidden_size):
+ def __init__(self, config):
super(BertPooler, self).__init__()
- self.dense = nn.Linear(hidden_size, hidden_size)
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
@@ -242,13 +340,19 @@ class BertModel(nn.Module):
如果你想使用预训练好的权重矩阵,请在以下网址下载.
sources::
- 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
- 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
- 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
- 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
- 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
- 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
- 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
+ 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin"
用预训练权重矩阵来建立BERT模型::
@@ -272,34 +376,30 @@ class BertModel(nn.Module):
:param int initializer_range: 初始化权重范围,默认值为0.02
"""
- def __init__(self, vocab_size=30522,
- hidden_size=768,
- num_hidden_layers=12,
- num_attention_heads=12,
- intermediate_size=3072,
- hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=512,
- type_vocab_size=2,
- initializer_range=0.02):
+ def __init__(self, config, *inputs, **kwargs):
super(BertModel, self).__init__()
- self.hidden_size = hidden_size
- self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings,
- type_vocab_size, hidden_dropout_prob)
- self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads,
- attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size,
- hidden_act)
- self.pooler = BertPooler(hidden_size)
- self.initializer_range = initializer_range
-
+ if not isinstance(config, BertConfig):
+ raise ValueError(
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
+ "To create a model from a Google pretrained model use "
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+ self.__class__.__name__, self.__class__.__name__
+ ))
+ super(BertModel, self).__init__()
+ self.config = config
+ self.hidden_size = self.config.hidden_size
+ self.embeddings = BertEmbeddings(config)
+ self.encoder = BertEncoder(config)
+ self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def init_bert_weights(self, module):
+ """ Initialize the weights.
+ """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
- module.weight.data.normal_(mean=0.0, std=self.initializer_range)
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@@ -338,14 +438,19 @@ class BertModel(nn.Module):
return encoded_layers, pooled_output
@classmethod
- def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs):
+ def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs):
+ state_dict = kwargs.get('state_dict', None)
+ kwargs.pop('state_dict', None)
+ cache_dir = kwargs.get('cache_dir', None)
+ kwargs.pop('cache_dir', None)
+ from_tf = kwargs.get('from_tf', False)
+ kwargs.pop('from_tf', None)
# Load config
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE)
- config = json.load(open(config_file, "r"))
- # config = BertConfig.from_json_file(config_file)
+ config = BertConfig.from_json_file(config_file)
# logger.info("Model config {}".format(config))
# Instantiate model.
- model = cls(*inputs, **config, **kwargs)
+ model = cls(config, *inputs, **kwargs)
if state_dict is None:
files = glob.glob(os.path.join(pretrained_model_dir, '*.bin'))
if len(files)==0:
@@ -353,7 +458,7 @@ class BertModel(nn.Module):
elif len(files)>1:
raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}")
weights_path = files[0]
- state_dict = torch.load(weights_path)
+ state_dict = torch.load(weights_path, map_location='cpu')
old_keys = []
new_keys = []
@@ -464,6 +569,7 @@ class WordpieceTokenizer(object):
output_tokens.extend(sub_tokens)
return output_tokens
+
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
@@ -594,6 +700,7 @@ class BasicTokenizer(object):
output.append(char)
return "".join(output)
+
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
@@ -840,6 +947,7 @@ class _WordBertModel(nn.Module):
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]]))
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
attn_masks[i, :len(word_pieces_i)+2].fill_(1)
+ # TODO 截掉长度超过的部分。
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
index c48cb806..005cfe75 100644
--- a/fastNLP/modules/encoder/embedding.py
+++ b/fastNLP/modules/encoder/embedding.py
@@ -202,18 +202,12 @@ class StaticEmbedding(TokenEmbedding):
raise ValueError(f"Cannot recognize {model_dir_or_name}.")
# 读取embedding
- embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
+ embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
normalize=normalize)
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
padding_idx=vocab.padding_idx,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=embedding)
- if vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk
- words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
- for word, idx in vocab:
- if vocab._is_word_no_create_entry(word) and not hit_flags[idx]:
- words_to_words[idx] = vocab.unknown_idx
- self.words_to_words = words_to_words
self._embed_size = self.embedding.weight.size(1)
self.requires_grad = requires_grad
@@ -268,10 +262,8 @@ class StaticEmbedding(TokenEmbedding):
else:
dim = len(parts) - 1
f.seek(0)
- matrix = torch.zeros(len(vocab), dim)
- if init_method is not None:
- init_method(matrix)
- hit_flags = np.zeros(len(vocab), dtype=bool)
+ matrix = {}
+ found_count = 0
for idx, line in enumerate(f, start_idx):
try:
parts = line.strip().split()
@@ -285,28 +277,49 @@ class StaticEmbedding(TokenEmbedding):
if word in vocab:
index = vocab.to_index(word)
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
- hit_flags[index] = True
+ found_count += 1
except Exception as e:
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
print("Error occurred at the {} line.".format(idx))
raise e
- found_count = sum(hit_flags)
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
- if init_method is None:
- if len(vocab)-found_count>0 and found_count>0: # 有的没找到
- found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()]
- mean = found_vecs.mean(dim=0, keepdim=True)
- std = found_vecs.std(dim=0, keepdim=True)
- unfound_vec_num = np.sum(hit_flags==False)
- unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean
- matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs
+ for word, index in vocab:
+ if index not in matrix and not vocab._is_word_no_create_entry(word):
+ if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化
+ matrix[index] = matrix[vocab.unknown_idx]
+ else:
+ matrix[index] = None
+
+ vectors = torch.zeros(len(matrix), dim)
+ if init_method:
+ init_method(vectors)
+ else:
+ nn.init.uniform_(vectors, -np.sqrt(3/dim), np.sqrt(3/dim))
+
+ if vocab._no_create_word_length>0:
+ if vocab.unknown is None: # 创建一个专门的unknown
+ unknown_idx = len(matrix)
+ vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous()
+ else:
+ unknown_idx = vocab.unknown_idx
+ words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
+ requires_grad=False)
+ for order, (index, vec) in enumerate(matrix.items()):
+ if vec is not None:
+ vectors[order] = vec
+ words_to_words[index] = order
+ self.words_to_words = words_to_words
+ else:
+ for index, vec in matrix.items():
+ if vec is not None:
+ vectors[index] = vec
if normalize:
- matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12)
+ vectors /= (torch.norm(vectors, dim=1, keepdim=True) + 1e-12)
- return matrix, hit_flags
+ return vectors
def forward(self, words):
"""
diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py
index 1eec7c13..76b7e922 100644
--- a/fastNLP/modules/encoder/star_transformer.py
+++ b/fastNLP/modules/encoder/star_transformer.py
@@ -35,11 +35,13 @@ class StarTransformer(nn.Module):
self.iters = num_layers
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)])
+ self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
+ self.emb_drop = nn.Dropout(dropout)
self.ring_att = nn.ModuleList(
- [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
+ [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
for _ in range(self.iters)])
self.star_att = nn.ModuleList(
- [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
+ [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
for _ in range(self.iters)])
if max_len is not None:
@@ -66,18 +68,19 @@ class StarTransformer(nn.Module):
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1
- if self.pos_emb:
+ if self.pos_emb and False:
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1
embs = embs + P
-
+ embs = norm_func(self.emb_drop, embs)
nodes = embs
relay = embs.mean(2, keepdim=True)
ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
r_embs = embs.view(B, H, 1, L)
for i in range(self.iters):
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
- nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
+ nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
+ #nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))
nodes = nodes.masked_fill_(ex_mask, 0)
diff --git a/reproduction/README.md b/reproduction/README.md
index bb21c067..b6f61903 100644
--- a/reproduction/README.md
+++ b/reproduction/README.md
@@ -3,6 +3,8 @@
复现的模型有:
- [Star-Transformer](Star_transformer/)
+- [Biaffine](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/biaffine_parser.py#L239)
+- [CNNText](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/cnn_text_classification.py#L12)
- ...
# 任务复现
@@ -11,11 +13,11 @@
## Matching (自然语言推理/句子匹配)
-- still in progress
+- [Matching 任务复现](matching)
## Sequence Labeling (序列标注)
-- still in progress
+- [NER](seqence_labelling/ner)
## Coreference resolution (指代消解)
diff --git a/reproduction/Star_transformer/datasets.py b/reproduction/Star_transformer/datasets.py
index a9257fd4..1173d1a0 100644
--- a/reproduction/Star_transformer/datasets.py
+++ b/reproduction/Star_transformer/datasets.py
@@ -2,7 +2,8 @@ import torch
import json
import os
from fastNLP import Vocabulary
-from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader
+from fastNLP.io.dataset_loader import ConllLoader
+from fastNLP.io.data_loader import SSTLoader, SNLILoader
from fastNLP.core import Const as C
import numpy as np
@@ -50,13 +51,15 @@ def load_sst(path, files):
for sub in [True, False, False]]
ds_list = [loader.load(os.path.join(path, fn))
for fn, loader in zip(files, loaders)]
- word_v = Vocabulary(min_freq=2)
+ word_v = Vocabulary(min_freq=0)
tag_v = Vocabulary(unknown=None, padding=None)
for ds in ds_list:
ds.apply(lambda x: [w.lower()
for w in x['words']], new_field_name='words')
- ds_list[0].drop(lambda x: len(x['words']) < 3)
+ #ds_list[0].drop(lambda x: len(x['words']) < 3)
update_v(word_v, ds_list[0], 'words')
+ update_v(word_v, ds_list[1], 'words')
+ update_v(word_v, ds_list[2], 'words')
ds_list[0].apply(lambda x: tag_v.add_word(
x['target']), new_field_name=None)
@@ -151,7 +154,10 @@ class EmbedLoader:
# some words from vocab are missing in pre-trained embedding
# we normally sample each dimension
vocab_embed = embedding_matrix[np.where(hit_flags)]
- sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+ #sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+ # size=(len(vocab) - np.sum(hit_flags), emb_dim))
+ sampled_vectors = np.random.uniform(-0.01, 0.01,
size=(len(vocab) - np.sum(hit_flags), emb_dim))
+
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
return embedding_matrix
diff --git a/reproduction/Star_transformer/run.sh b/reproduction/Star_transformer/run.sh
index 0972c662..5cd6954b 100644
--- a/reproduction/Star_transformer/run.sh
+++ b/reproduction/Star_transformer/run.sh
@@ -1,5 +1,5 @@
#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 &
#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 &
-#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log &
+python -u train.py --task cls --ds sst --mode train --gpu 0 --lr 1e-4 --w_decay 5e-5 --lr_decay 1.0 --drop 0.4 --ep 20 --bsz 64 > sst_cls.log &
#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log &
-python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &
+#python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &
diff --git a/reproduction/Star_transformer/train.py b/reproduction/Star_transformer/train.py
index 6fb58daf..480748df 100644
--- a/reproduction/Star_transformer/train.py
+++ b/reproduction/Star_transformer/train.py
@@ -1,4 +1,6 @@
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args
+seed = set_rng_seeds(15360)
+print('RNG SEED {}'.format(seed))
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN
import torch.nn as nn
import torch
@@ -7,8 +9,8 @@ import fastNLP as FN
from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls
from fastNLP.core.const import Const as C
import sys
-sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
-
+#sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
+pre_dir = '/home/ec2-user/fast_data/'
g_model_select = {
'pos': STSeqLabel,
@@ -17,8 +19,8 @@ g_model_select = {
'nli': STNLICls,
}
-g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt',
- 'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'}
+g_emb_file_path = {'en': pre_dir + 'glove.840B.300d.txt',
+ 'zh': pre_dir + 'cc.zh.300.vec'}
g_args = None
g_model_cfg = None
@@ -53,7 +55,7 @@ def get_conll2012_ner():
def get_sst():
- path = '/remote-home/yfshao/workdir/datasets/SST'
+ path = pre_dir + 'sst'
files = ['train.txt', 'dev.txt', 'test.txt']
return load_sst(path, files)
@@ -94,6 +96,7 @@ class MyCallback(FN.core.callback.Callback):
nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0)
def on_step_end(self):
+ return
warm_steps = 6000
# learning rate warm-up & decay
if self.step <= warm_steps:
@@ -108,12 +111,11 @@ class MyCallback(FN.core.callback.Callback):
def train():
- seed = set_rng_seeds(1234)
- print('RNG SEED {}'.format(seed))
print('loading data')
ds_list, word_v, tag_v = g_datasets['{}-{}'.format(
g_args.ds, g_args.task)]()
print(ds_list[0][:2])
+ print(len(ds_list[0]), len(ds_list[1]), len(ds_list[2]))
embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en')
g_model_cfg['num_cls'] = len(tag_v)
print(g_model_cfg)
@@ -123,11 +125,14 @@ def train():
def init_model(model):
for p in model.parameters():
if p.size(0) != len(word_v):
- nn.init.normal_(p, 0.0, 0.05)
+ if len(p.size())<2:
+ nn.init.constant_(p, 0.0)
+ else:
+ nn.init.normal_(p, 0.0, 0.05)
init_model(model)
train_data = ds_list[0]
- dev_data = ds_list[2]
- test_data = ds_list[1]
+ dev_data = ds_list[1]
+ test_data = ds_list[2]
print(tag_v.word2idx)
if g_args.task in ['pos', 'ner']:
@@ -145,14 +150,26 @@ def train():
}
metric_key, metric = metrics[g_args.task]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
- ex_param = [x for x in model.parameters(
- ) if x.requires_grad and x.size(0) != len(word_v)]
- optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
- {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
- trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss,
- batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric,
- metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False,
- device=device, callbacks=[MyCallback()])
+
+ params = [(x,y) for x,y in list(model.named_parameters()) if y.requires_grad and y.size(0) != len(word_v)]
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+ print([n for n,p in params])
+ optim_cfg = [
+ #{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
+ {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 1.0*g_args.w_decay},
+ {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 0.0*g_args.w_decay}
+ ]
+
+ print(model)
+ trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
+ loss=loss, metrics=metric, metric_key=metric_key,
+ optimizer=torch.optim.Adam(optim_cfg),
+ n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=100, validate_every=1000,
+ device=device,
+ use_tqdm=False, prefetch=False,
+ save_path=g_args.log,
+ sampler=FN.BucketSampler(100, g_args.bsz, C.INPUT_LEN),
+ callbacks=[MyCallback()])
trainer.train()
tester = FN.Tester(data=test_data, model=model, metrics=metric,
@@ -195,12 +212,12 @@ def main():
'init_embed': (None, 300),
'num_cls': None,
'hidden_size': g_args.hidden,
- 'num_layers': 4,
+ 'num_layers': 2,
'num_head': g_args.nhead,
'head_dim': g_args.hdim,
'max_len': MAX_LEN,
- 'cls_hidden_size': 600,
- 'emb_dropout': 0.3,
+ 'cls_hidden_size': 200,
+ 'emb_dropout': g_args.drop,
'dropout': g_args.drop,
}
run_select[g_args.mode.lower()]()
diff --git a/reproduction/coreference_resolution/__init__.py b/reproduction/coreference_resolution/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/data_load/__init__.py b/reproduction/coreference_resolution/data_load/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/data_load/cr_loader.py b/reproduction/coreference_resolution/data_load/cr_loader.py
new file mode 100644
index 00000000..986afcd5
--- /dev/null
+++ b/reproduction/coreference_resolution/data_load/cr_loader.py
@@ -0,0 +1,68 @@
+from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
+from fastNLP.io.file_reader import _read_json
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.io.base_loader import DataInfo
+from reproduction.coreference_resolution.model.config import Config
+import reproduction.coreference_resolution.model.preprocess as preprocess
+
+
+class CRLoader(JsonLoader):
+ def __init__(self, fields=None, dropna=False):
+ super().__init__(fields, dropna)
+
+ def _load(self, path):
+ """
+ 加载数据
+ :param path:
+ :return:
+ """
+ dataset = DataSet()
+ for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
+ if self.fields:
+ ins = {self.fields[k]: v for k, v in d.items()}
+ else:
+ ins = d
+ dataset.append(Instance(**ins))
+ return dataset
+
+ def process(self, paths, **kwargs):
+ data_info = DataInfo()
+ for name in ['train', 'test', 'dev']:
+ data_info.datasets[name] = self.load(paths[name])
+
+ config = Config()
+ vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences')
+ vocab.build_vocab()
+ word2id = vocab.word2idx
+
+ char_dict = preprocess.get_char_dict(config.char_path)
+ data_info.vocabs = vocab
+
+ genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
+
+ for name, ds in data_info.datasets.items():
+ ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
+ config.max_sentences, is_train=name=='train')[0],
+ new_field_name='doc_np')
+ ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
+ config.max_sentences, is_train=name=='train')[1],
+ new_field_name='char_index')
+ ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
+ config.max_sentences, is_train=name=='train')[2],
+ new_field_name='seq_len')
+ ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'),
+ new_field_name='speaker_ids_np')
+ ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')
+
+ ds.set_ignore_type('clusters')
+ ds.set_padder('clusters', None)
+ ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
+ ds.set_target("clusters")
+
+ # train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False)
+ # train, dev = train_dev.split(343 / (2802 + 343), shuffle=False)
+
+ return data_info
+
+
+
diff --git a/reproduction/coreference_resolution/model/__init__.py b/reproduction/coreference_resolution/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/model/config.py b/reproduction/coreference_resolution/model/config.py
new file mode 100644
index 00000000..6011257b
--- /dev/null
+++ b/reproduction/coreference_resolution/model/config.py
@@ -0,0 +1,54 @@
+class Config():
+ def __init__(self):
+ self.is_training = True
+ # path
+ self.glove = 'data/glove.840B.300d.txt.filtered'
+ self.turian = 'data/turian.50d.txt'
+ self.train_path = "data/train.english.jsonlines"
+ self.dev_path = "data/dev.english.jsonlines"
+ self.test_path = "data/test.english.jsonlines"
+ self.char_path = "data/char_vocab.english.txt"
+
+ self.cuda = "0"
+ self.max_word = 1500
+ self.epoch = 200
+
+ # config
+ # self.use_glove = True
+ # self.use_turian = True #No
+ self.use_elmo = False
+ self.use_CNN = True
+ self.model_heads = True #Yes
+ self.use_width = True # Yes
+ self.use_distance = True #Yes
+ self.use_metadata = True #Yes
+
+ self.mention_ratio = 0.4
+ self.max_sentences = 50
+ self.span_width = 10
+ self.feature_size = 20 #宽度信息emb的size
+ self.lr = 0.001
+ self.lr_decay = 1e-3
+ self.max_antecedents = 100 # 这个参数在mention detection中没有用
+ self.atten_hidden_size = 150
+ self.mention_hidden_size = 150
+ self.sa_hidden_size = 150
+
+ self.char_emb_size = 8
+ self.filter = [3,4,5]
+
+
+ # decay = 1e-5
+
+ def __str__(self):
+ d = self.__dict__
+ out = 'config==============\n'
+ for i in list(d):
+ out += i+":"
+ out += str(d[i])+"\n"
+ out+="config==============\n"
+ return out
+
+if __name__=="__main__":
+ config = Config()
+ print(config)
diff --git a/reproduction/coreference_resolution/model/metric.py b/reproduction/coreference_resolution/model/metric.py
new file mode 100644
index 00000000..2c924660
--- /dev/null
+++ b/reproduction/coreference_resolution/model/metric.py
@@ -0,0 +1,163 @@
+from fastNLP.core.metrics import MetricBase
+
+import numpy as np
+
+from collections import Counter
+from sklearn.utils.linear_assignment_ import linear_assignment
+
+"""
+Mostly borrowed from https://github.com/clarkkev/deep-coref/blob/master/evaluation.py
+"""
+
+
+
+class CRMetric(MetricBase):
+ def __init__(self):
+ super().__init__()
+ self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]
+
+ # TODO 改名为evaluate,输入也
+ def evaluate(self, predicted, mention_to_predicted,clusters):
+ for e in self.evaluators:
+ e.update(predicted,mention_to_predicted, clusters)
+
+ def get_f1(self):
+ return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)
+
+ def get_recall(self):
+ return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)
+
+ def get_precision(self):
+ return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)
+
+ # TODO 原本的getprf
+ def get_metric(self,reset=False):
+ res = {"pre":self.get_precision(), "rec":self.get_recall(), "f":self.get_f1()}
+ self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]
+ return res
+
+
+
+
+
+
+class Evaluator():
+ def __init__(self, metric, beta=1):
+ self.p_num = 0
+ self.p_den = 0
+ self.r_num = 0
+ self.r_den = 0
+ self.metric = metric
+ self.beta = beta
+
+ def update(self, predicted,mention_to_predicted,gold):
+ gold = gold[0].tolist()
+ gold = [tuple(tuple(m) for m in gc) for gc in gold]
+ mention_to_gold = {}
+ for gc in gold:
+ for mention in gc:
+ mention_to_gold[mention] = gc
+
+ if self.metric == ceafe:
+ pn, pd, rn, rd = self.metric(predicted, gold)
+ else:
+ pn, pd = self.metric(predicted, mention_to_gold)
+ rn, rd = self.metric(gold, mention_to_predicted)
+ self.p_num += pn
+ self.p_den += pd
+ self.r_num += rn
+ self.r_den += rd
+
+ def get_f1(self):
+ return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
+
+ def get_recall(self):
+ return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
+
+ def get_precision(self):
+ return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
+
+ def get_prf(self):
+ return self.get_precision(), self.get_recall(), self.get_f1()
+
+ def get_counts(self):
+ return self.p_num, self.p_den, self.r_num, self.r_den
+
+
+
+def b_cubed(clusters, mention_to_gold):
+ num, dem = 0, 0
+
+ for c in clusters:
+ if len(c) == 1:
+ continue
+
+ gold_counts = Counter()
+ correct = 0
+ for m in c:
+ if m in mention_to_gold:
+ gold_counts[tuple(mention_to_gold[m])] += 1
+ for c2, count in gold_counts.items():
+ if len(c2) != 1:
+ correct += count * count
+
+ num += correct / float(len(c))
+ dem += len(c)
+
+ return num, dem
+
+
+def muc(clusters, mention_to_gold):
+ tp, p = 0, 0
+ for c in clusters:
+ p += len(c) - 1
+ tp += len(c)
+ linked = set()
+ for m in c:
+ if m in mention_to_gold:
+ linked.add(mention_to_gold[m])
+ else:
+ tp -= 1
+ tp -= len(linked)
+ return tp, p
+
+
+def phi4(c1, c2):
+ return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))
+
+
+def ceafe(clusters, gold_clusters):
+ clusters = [c for c in clusters if len(c) != 1]
+ scores = np.zeros((len(gold_clusters), len(clusters)))
+ for i in range(len(gold_clusters)):
+ for j in range(len(clusters)):
+ scores[i, j] = phi4(gold_clusters[i], clusters[j])
+ matching = linear_assignment(-scores)
+ similarity = sum(scores[matching[:, 0], matching[:, 1]])
+ return similarity, len(clusters), similarity, len(gold_clusters)
+
+
+def lea(clusters, mention_to_gold):
+ num, dem = 0, 0
+
+ for c in clusters:
+ if len(c) == 1:
+ continue
+
+ common_links = 0
+ all_links = len(c) * (len(c) - 1) / 2.0
+ for i, m in enumerate(c):
+ if m in mention_to_gold:
+ for m2 in c[i + 1:]:
+ if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:
+ common_links += 1
+
+ num += len(c) * common_links / float(all_links)
+ dem += len(c)
+
+ return num, dem
+
+def f1(p_num, p_den, r_num, r_den, beta=1):
+ p = 0 if p_den == 0 else p_num / float(p_den)
+ r = 0 if r_den == 0 else r_num / float(r_den)
+ return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
diff --git a/reproduction/coreference_resolution/model/model_re.py b/reproduction/coreference_resolution/model/model_re.py
new file mode 100644
index 00000000..9dd90ec4
--- /dev/null
+++ b/reproduction/coreference_resolution/model/model_re.py
@@ -0,0 +1,576 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from allennlp.commands.elmo import ElmoEmbedder
+from fastNLP.models.base_model import BaseModel
+from fastNLP.modules.encoder.variational_rnn import VarLSTM
+from reproduction.coreference_resolution.model import preprocess
+from fastNLP.io.embed_loader import EmbedLoader
+import random
+
+# 设置seed
+torch.manual_seed(0) # cpu
+torch.cuda.manual_seed(0) # gpu
+np.random.seed(0) # numpy
+random.seed(0)
+
+
+class ffnn(nn.Module):
+ def __init__(self, input_size, hidden_size, output_size):
+ super(ffnn, self).__init__()
+
+ self.f = nn.Sequential(
+ # 多少层数
+ nn.Linear(input_size, hidden_size),
+ nn.ReLU(inplace=True),
+ nn.Dropout(p=0.2),
+ nn.Linear(hidden_size, hidden_size),
+ nn.ReLU(inplace=True),
+ nn.Dropout(p=0.2),
+ nn.Linear(hidden_size, output_size)
+ )
+ self.reset_param()
+
+ def reset_param(self):
+ for name, param in self.named_parameters():
+ if param.dim() > 1:
+ nn.init.xavier_normal_(param)
+ # param.data = torch.tensor(np.random.randn(*param.shape)).float()
+ else:
+ nn.init.zeros_(param)
+
+ def forward(self, input):
+ return self.f(input).squeeze()
+
+
+class Model(BaseModel):
+ def __init__(self, vocab, config):
+ word2id = vocab.word2idx
+ super(Model, self).__init__()
+ vocab_num = len(word2id)
+ self.word2id = word2id
+ self.config = config
+ self.char_dict = preprocess.get_char_dict('data/char_vocab.english.txt')
+ self.genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
+ self.device = torch.device("cuda:" + config.cuda)
+
+ self.emb = nn.Embedding(vocab_num, 350)
+
+ emb1 = EmbedLoader().load_with_vocab(config.glove, vocab,normalize=False)
+ emb2 = EmbedLoader().load_with_vocab(config.turian, vocab ,normalize=False)
+ pre_emb = np.concatenate((emb1, emb2), axis=1)
+ pre_emb /= (np.linalg.norm(pre_emb, axis=1, keepdims=True) + 1e-12)
+
+ if pre_emb is not None:
+ self.emb.weight = nn.Parameter(torch.from_numpy(pre_emb).float())
+ for param in self.emb.parameters():
+ param.requires_grad = False
+ self.emb_dropout = nn.Dropout(inplace=True)
+
+
+ if config.use_elmo:
+ self.elmo = ElmoEmbedder(options_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_options.json',
+ weight_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5',
+ cuda_device=int(config.cuda))
+ print("elmo load over.")
+ self.elmo_args = torch.randn((3), requires_grad=True).to(self.device)
+
+ self.char_emb = nn.Embedding(len(self.char_dict), config.char_emb_size)
+ self.conv1 = nn.Conv1d(config.char_emb_size, 50, 3)
+ self.conv2 = nn.Conv1d(config.char_emb_size, 50, 4)
+ self.conv3 = nn.Conv1d(config.char_emb_size, 50, 5)
+
+ self.feature_emb = nn.Embedding(config.span_width, config.feature_size)
+ self.feature_emb_dropout = nn.Dropout(p=0.2, inplace=True)
+
+ self.mention_distance_emb = nn.Embedding(10, config.feature_size)
+ self.distance_drop = nn.Dropout(p=0.2, inplace=True)
+
+ self.genre_emb = nn.Embedding(7, config.feature_size)
+ self.speaker_emb = nn.Embedding(2, config.feature_size)
+
+ self.bilstm = VarLSTM(input_size=350+150*config.use_CNN+config.use_elmo*1024,hidden_size=200,bidirectional=True,batch_first=True,hidden_dropout=0.2)
+ # self.bilstm = nn.LSTM(input_size=500, hidden_size=200, bidirectional=True, batch_first=True)
+ self.h0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device)
+ self.c0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device)
+ self.bilstm_drop = nn.Dropout(p=0.2, inplace=True)
+
+ self.atten = ffnn(input_size=400, hidden_size=config.atten_hidden_size, output_size=1)
+ self.mention_score = ffnn(input_size=1320, hidden_size=config.mention_hidden_size, output_size=1)
+ self.sa = ffnn(input_size=3980+40*config.use_metadata, hidden_size=config.sa_hidden_size, output_size=1)
+ self.mention_start_np = None
+ self.mention_end_np = None
+
+ def _reorder_lstm(self, word_emb, seq_lens):
+ sort_ind = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
+ seq_lens_re = [seq_lens[i] for i in sort_ind]
+ emb_seq = self.reorder_sequence(word_emb, sort_ind, batch_first=True)
+ packed_seq = nn.utils.rnn.pack_padded_sequence(emb_seq, seq_lens_re, batch_first=True)
+
+ h0 = self.h0.repeat(1, len(seq_lens), 1)
+ c0 = self.c0.repeat(1, len(seq_lens), 1)
+ packed_out, final_states = self.bilstm(packed_seq, (h0, c0))
+
+ lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
+ back_map = {ind: i for i, ind in enumerate(sort_ind)}
+ reorder_ind = [back_map[i] for i in range(len(seq_lens_re))]
+ lstm_out = self.reorder_sequence(lstm_out, reorder_ind, batch_first=True)
+ return lstm_out
+
+ def reorder_sequence(self, sequence_emb, order, batch_first=True):
+ """
+ sequence_emb: [T, B, D] if not batch_first
+ order: list of sequence length
+ """
+ batch_dim = 0 if batch_first else 1
+ assert len(order) == sequence_emb.size()[batch_dim]
+
+ order = torch.LongTensor(order)
+ order = order.to(sequence_emb).long()
+
+ sorted_ = sequence_emb.index_select(index=order, dim=batch_dim)
+
+ del order
+ return sorted_
+
+ def flat_lstm(self, lstm_out, seq_lens):
+ batch = lstm_out.shape[0]
+ seq = lstm_out.shape[1]
+ dim = lstm_out.shape[2]
+ l = [j + i * seq for i, seq_len in enumerate(seq_lens) for j in range(seq_len)]
+ flatted = torch.index_select(lstm_out.view(batch * seq, dim), 0, torch.LongTensor(l).to(self.device))
+ return flatted
+
+ def potential_mention_index(self, word_index, max_sent_len):
+ # get mention index [3,2]:the first sentence is 3 and secend 2
+ # [0,0,0,1,1] --> [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [3, 3], [3, 4], [4, 4]] (max =2)
+ potential_mention = []
+ for i in range(len(word_index)):
+ for j in range(i, i + max_sent_len):
+ if (j < len(word_index) and word_index[i] == word_index[j]):
+ potential_mention.append([i, j])
+ return potential_mention
+
+ def get_mention_start_end(self, seq_lens):
+ # 序列长度转换成mention
+ # [3,2] --> [0,0,0,1,1]
+ word_index = [0] * sum(seq_lens)
+ sent_index = 0
+ index = 0
+ for length in seq_lens:
+ for l in range(length):
+ word_index[index] = sent_index
+ index += 1
+ sent_index += 1
+
+ # [0,0,0,1,1]-->[[0,0],[0,1],[0,2]....]
+ mention_id = self.potential_mention_index(word_index, self.config.span_width)
+ mention_start = np.array(mention_id, dtype=int)[:, 0]
+ mention_end = np.array(mention_id, dtype=int)[:, 1]
+ return mention_start, mention_end
+
+ def get_mention_emb(self, flatten_lstm, mention_start, mention_end):
+ mention_start_tensor = torch.from_numpy(mention_start).to(self.device)
+ mention_end_tensor = torch.from_numpy(mention_end).to(self.device)
+ emb_start = flatten_lstm.index_select(dim=0, index=mention_start_tensor) # [mention_num,embed]
+ emb_end = flatten_lstm.index_select(dim=0, index=mention_end_tensor) # [mention_num,embed]
+ return emb_start, emb_end
+
+ def get_mask(self, mention_start, mention_end):
+ # big mask for attention
+ mention_num = mention_start.shape[0]
+ mask = np.zeros((mention_num, self.config.span_width)) # [mention_num,span_width]
+ for i in range(mention_num):
+ start = mention_start[i]
+ end = mention_end[i]
+ # 实际上是宽度
+ for j in range(end - start + 1):
+ mask[i][j] = 1
+ mask = torch.from_numpy(mask) # [mention_num,max_mention]
+ # 0-->-inf 1-->0
+ log_mask = torch.log(mask)
+ return log_mask
+
+ def get_mention_index(self, mention_start, max_mention):
+ # TODO 后面可能要改
+ assert len(mention_start.shape) == 1
+ mention_start_tensor = torch.from_numpy(mention_start)
+ num_mention = mention_start_tensor.shape[0]
+ mention_index = mention_start_tensor.expand(max_mention, num_mention).transpose(0,
+ 1) # [num_mention,max_mention]
+ assert mention_index.shape[0] == num_mention
+ assert mention_index.shape[1] == max_mention
+ range_add = torch.arange(0, max_mention).expand(num_mention, max_mention).long() # [num_mention,max_mention]
+ mention_index = mention_index + range_add
+ mention_index = torch.min(mention_index, torch.LongTensor([mention_start[-1]]).expand(num_mention, max_mention))
+ return mention_index.to(self.device)
+
+ def sort_mention(self, mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_lens):
+ # 排序记录,高分段在前面
+ mention_score, mention_ids = torch.sort(candidate_mention_score, descending=True)
+ preserve_mention_num = int(self.config.mention_ratio * sum(seq_lens))
+ mention_ids = mention_ids[0:preserve_mention_num]
+ mention_score = mention_score[0:preserve_mention_num]
+
+ mention_start_tensor = torch.from_numpy(mention_start).to(self.device).index_select(dim=0,
+ index=mention_ids) # [lamda*word_num]
+ mention_end_tensor = torch.from_numpy(mention_end).to(self.device).index_select(dim=0,
+ index=mention_ids) # [lamda*word_num]
+ mention_emb = candidate_mention_emb.index_select(index=mention_ids, dim=0) # [lamda*word_num,emb]
+ assert mention_score.shape[0] == preserve_mention_num
+ assert mention_start_tensor.shape[0] == preserve_mention_num
+ assert mention_end_tensor.shape[0] == preserve_mention_num
+ assert mention_emb.shape[0] == preserve_mention_num
+ # TODO 不交叉没做处理
+
+ # 对start进行再排序,实际位置在前面
+ # TODO 这里只考虑了start没有考虑end
+ mention_start_tensor, temp_index = torch.sort(mention_start_tensor)
+ mention_end_tensor = mention_end_tensor.index_select(dim=0, index=temp_index)
+ mention_emb = mention_emb.index_select(dim=0, index=temp_index)
+ mention_score = mention_score.index_select(dim=0, index=temp_index)
+ return mention_start_tensor, mention_end_tensor, mention_score, mention_emb
+
+ def get_antecedents(self, mention_starts, max_antecedents):
+ num_mention = mention_starts.shape[0]
+ max_antecedents = min(max_antecedents, num_mention)
+ # mention和它是第几个mention之间的对应关系
+ antecedents = np.zeros((num_mention, max_antecedents), dtype=int) # [num_mention,max_an]
+ # 记录长度
+ antecedents_len = [0] * num_mention
+ for i in range(num_mention):
+ ante_count = 0
+ for j in range(max(0, i - max_antecedents), i):
+ antecedents[i, ante_count] = j
+ ante_count += 1
+ # 补位操作
+ for j in range(ante_count, max_antecedents):
+ antecedents[i, j] = 0
+ antecedents_len[i] = ante_count
+ assert antecedents.shape[1] == max_antecedents
+ return antecedents, antecedents_len
+
+ def get_antecedents_score(self, span_represent, mention_score, antecedents, antecedents_len, mention_speakers_ids,
+ genre):
+ num_mention = mention_score.shape[0]
+ max_antecedent = antecedents.shape[1]
+
+ pair_emb = self.get_pair_emb(span_represent, antecedents, mention_speakers_ids, genre) # [span_num,max_ant,emb]
+ antecedent_scores = self.sa(pair_emb)
+ mask01 = self.sequence_mask(antecedents_len, max_antecedent)
+ maskinf = torch.log(mask01).to(self.device)
+ assert maskinf.shape[1] <= max_antecedent
+ assert antecedent_scores.shape[0] == num_mention
+ antecedent_scores = antecedent_scores + maskinf
+ antecedents = torch.from_numpy(antecedents).to(self.device)
+ mention_scoreij = mention_score.unsqueeze(1) + torch.gather(
+ mention_score.unsqueeze(0).expand(num_mention, num_mention), dim=1, index=antecedents)
+ antecedent_scores += mention_scoreij
+
+ antecedent_scores = torch.cat([torch.zeros([mention_score.shape[0], 1]).to(self.device), antecedent_scores],
+ 1) # [num_mentions, max_ant + 1]
+ return antecedent_scores
+
+ ##############################
+ def distance_bin(self, mention_distance):
+ bins = torch.zeros(mention_distance.size()).byte().to(self.device)
+ rg = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 7], [8, 15], [16, 31], [32, 63], [64, 300]]
+ for t, k in enumerate(rg):
+ i, j = k[0], k[1]
+ b = torch.LongTensor([i]).unsqueeze(-1).expand(mention_distance.size()).to(self.device)
+ m1 = torch.ge(mention_distance, b)
+ e = torch.LongTensor([j]).unsqueeze(-1).expand(mention_distance.size()).to(self.device)
+ m2 = torch.le(mention_distance, e)
+ bins = bins + (t + 1) * (m1 & m2)
+ return bins.long()
+
+ def get_distance_emb(self, antecedents_tensor):
+ num_mention = antecedents_tensor.shape[0]
+ max_ant = antecedents_tensor.shape[1]
+
+ assert max_ant <= self.config.max_antecedents
+ source = torch.arange(0, num_mention).expand(max_ant, num_mention).transpose(0,1).to(self.device) # [num_mention,max_ant]
+ mention_distance = source - antecedents_tensor
+ mention_distance_bin = self.distance_bin(mention_distance)
+ distance_emb = self.mention_distance_emb(mention_distance_bin)
+ distance_emb = self.distance_drop(distance_emb)
+ return distance_emb
+
+ def get_pair_emb(self, span_emb, antecedents, mention_speakers_ids, genre):
+ emb_dim = span_emb.shape[1]
+ num_span = span_emb.shape[0]
+ max_ant = antecedents.shape[1]
+ assert span_emb.shape[0] == antecedents.shape[0]
+ antecedents = torch.from_numpy(antecedents).to(self.device)
+
+ # [num_span,max_ant,emb]
+ antecedent_emb = torch.gather(span_emb.unsqueeze(0).expand(num_span, num_span, emb_dim), dim=1,
+ index=antecedents.unsqueeze(2).expand(num_span, max_ant, emb_dim))
+ # [num_span,max_ant,emb]
+ target_emb_tiled = span_emb.expand((max_ant, num_span, emb_dim))
+ target_emb_tiled = target_emb_tiled.transpose(0, 1)
+
+ similarity_emb = antecedent_emb * target_emb_tiled
+
+ pair_emb_list = [target_emb_tiled, antecedent_emb, similarity_emb]
+
+ # get speakers and genre
+ if self.config.use_metadata:
+ antecedent_speaker_ids = mention_speakers_ids.unsqueeze(0).expand(num_span, num_span).gather(dim=1,
+ index=antecedents)
+ same_speaker = torch.eq(mention_speakers_ids.unsqueeze(1).expand(num_span, max_ant),
+ antecedent_speaker_ids) # [num_mention,max_ant]
+ speaker_embedding = self.speaker_emb(same_speaker.long().to(self.device)) # [mention_num.max_ant,emb]
+ genre_embedding = self.genre_emb(
+ torch.LongTensor([genre]).expand(num_span, max_ant).to(self.device)) # [mention_num,max_ant,emb]
+ pair_emb_list.append(speaker_embedding)
+ pair_emb_list.append(genre_embedding)
+
+ # get distance emb
+ if self.config.use_distance:
+ distance_emb = self.get_distance_emb(antecedents)
+ pair_emb_list.append(distance_emb)
+
+ pair_emb = torch.cat(pair_emb_list, 2)
+ return pair_emb
+
+ def sequence_mask(self, len_list, max_len):
+ x = np.zeros((len(len_list), max_len))
+ for i in range(len(len_list)):
+ l = len_list[i]
+ for j in range(l):
+ x[i][j] = 1
+ return torch.from_numpy(x).float()
+
+ def logsumexp(self, value, dim=None, keepdim=False):
+ """Numerically stable implementation of the operation
+
+ value.exp().sum(dim, keepdim).log()
+ """
+ # TODO: torch.max(value, dim=None) threw an error at time of writing
+ if dim is not None:
+ m, _ = torch.max(value, dim=dim, keepdim=True)
+ value0 = value - m
+ if keepdim is False:
+ m = m.squeeze(dim)
+ return m + torch.log(torch.sum(torch.exp(value0),
+ dim=dim, keepdim=keepdim))
+ else:
+ m = torch.max(value)
+ sum_exp = torch.sum(torch.exp(value - m))
+
+ return m + torch.log(sum_exp)
+
+ def softmax_loss(self, antecedent_scores, antecedent_labels):
+ antecedent_labels = torch.from_numpy(antecedent_labels * 1).to(self.device)
+ gold_scores = antecedent_scores + torch.log(antecedent_labels.float()) # [num_mentions, max_ant + 1]
+ marginalized_gold_scores = self.logsumexp(gold_scores, 1) # [num_mentions]
+ log_norm = self.logsumexp(antecedent_scores, 1) # [num_mentions]
+ return torch.sum(log_norm - marginalized_gold_scores) # [num_mentions]reduce_logsumexp
+
+ def get_predicted_antecedents(self, antecedents, antecedent_scores):
+ predicted_antecedents = []
+ for i, index in enumerate(np.argmax(antecedent_scores.detach(), axis=1) - 1):
+ if index < 0:
+ predicted_antecedents.append(-1)
+ else:
+ predicted_antecedents.append(antecedents[i, index])
+ return predicted_antecedents
+
+ def get_predicted_clusters(self, mention_starts, mention_ends, predicted_antecedents):
+ mention_to_predicted = {}
+ predicted_clusters = []
+ for i, predicted_index in enumerate(predicted_antecedents):
+ if predicted_index < 0:
+ continue
+ assert i > predicted_index
+ predicted_antecedent = (int(mention_starts[predicted_index]), int(mention_ends[predicted_index]))
+ if predicted_antecedent in mention_to_predicted:
+ predicted_cluster = mention_to_predicted[predicted_antecedent]
+ else:
+ predicted_cluster = len(predicted_clusters)
+ predicted_clusters.append([predicted_antecedent])
+ mention_to_predicted[predicted_antecedent] = predicted_cluster
+
+ mention = (int(mention_starts[i]), int(mention_ends[i]))
+ predicted_clusters[predicted_cluster].append(mention)
+ mention_to_predicted[mention] = predicted_cluster
+
+ predicted_clusters = [tuple(pc) for pc in predicted_clusters]
+ mention_to_predicted = {m: predicted_clusters[i] for m, i in mention_to_predicted.items()}
+
+ return predicted_clusters, mention_to_predicted
+
+ def evaluate_coref(self, mention_starts, mention_ends, predicted_antecedents, gold_clusters, evaluator):
+ gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters]
+ mention_to_gold = {}
+ for gc in gold_clusters:
+ for mention in gc:
+ mention_to_gold[mention] = gc
+ predicted_clusters, mention_to_predicted = self.get_predicted_clusters(mention_starts, mention_ends,
+ predicted_antecedents)
+ evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)
+ return predicted_clusters
+
+
+ def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
+ """
+ 实际输入都是tensor
+ :param sentences: 句子,被fastNLP转化成了numpy,
+ :param doc_np: 被fastNLP转化成了Tensor
+ :param speaker_ids_np: 被fastNLP转化成了Tensor
+ :param genre: 被fastNLP转化成了Tensor
+ :param char_index: 被fastNLP转化成了Tensor
+ :param seq_len: 被fastNLP转化成了Tensor
+ :return:
+ """
+ # change for fastNLP
+ sentences = sentences[0].tolist()
+ doc_tensor = doc_np[0]
+ speakers_tensor = speaker_ids_np[0]
+ genre = genre[0].item()
+ char_index = char_index[0]
+ seq_len = seq_len[0].cpu().numpy()
+
+ # 类型
+
+ # doc_tensor = torch.from_numpy(doc_np).to(self.device)
+ # speakers_tensor = torch.from_numpy(speaker_ids_np).to(self.device)
+ mention_emb_list = []
+
+ word_emb = self.emb(doc_tensor)
+ word_emb_list = [word_emb]
+ if self.config.use_CNN:
+ # [batch, length, char_length, char_dim]
+ char = self.char_emb(char_index)
+ char_size = char.size()
+ # first transform to [batch *length, char_length, char_dim]
+ # then transpose to [batch * length, char_dim, char_length]
+ char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2)
+
+ # put into cnn [batch*length, char_filters, char_length]
+ # then put into maxpooling [batch * length, char_filters]
+ char_over_cnn, _ = self.conv1(char).max(dim=2)
+ # reshape to [batch, length, char_filters]
+ char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1)
+ word_emb_list.append(char_over_cnn)
+
+ char_over_cnn, _ = self.conv2(char).max(dim=2)
+ char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1)
+ word_emb_list.append(char_over_cnn)
+
+ char_over_cnn, _ = self.conv3(char).max(dim=2)
+ char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1)
+ word_emb_list.append(char_over_cnn)
+
+ # word_emb = torch.cat(word_emb_list, dim=2)
+
+ # use elmo or not
+ if self.config.use_elmo:
+ # 如果确实被截断了
+ if doc_tensor.shape[0] == 50 and len(sentences) > 50:
+ sentences = sentences[0:50]
+ elmo_embedding, elmo_mask = self.elmo.batch_to_embeddings(sentences)
+ elmo_embedding = elmo_embedding.to(
+ self.device) # [sentence_num,max_sent_len,3,1024]--[sentence_num,max_sent,1024]
+ elmo_embedding = elmo_embedding[:, 0, :, :] * self.elmo_args[0] + elmo_embedding[:, 1, :, :] * \
+ self.elmo_args[1] + elmo_embedding[:, 2, :, :] * self.elmo_args[2]
+ word_emb_list.append(elmo_embedding)
+ # print(word_emb_list[0].shape)
+ # print(word_emb_list[1].shape)
+ # print(word_emb_list[2].shape)
+ # print(word_emb_list[3].shape)
+ # print(word_emb_list[4].shape)
+
+ word_emb = torch.cat(word_emb_list, dim=2)
+
+ word_emb = self.emb_dropout(word_emb)
+ # word_emb_elmo = self.emb_dropout(word_emb_elmo)
+ lstm_out = self._reorder_lstm(word_emb, seq_len)
+ flatten_lstm = self.flat_lstm(lstm_out, seq_len) # [word_num,emb]
+ flatten_lstm = self.bilstm_drop(flatten_lstm)
+ # TODO 没有按照论文写
+ flatten_word_emb = self.flat_lstm(word_emb, seq_len) # [word_num,emb]
+
+ mention_start, mention_end = self.get_mention_start_end(seq_len) # [mention_num]
+ self.mention_start_np = mention_start # [mention_num] np
+ self.mention_end_np = mention_end
+ mention_num = mention_start.shape[0]
+ emb_start, emb_end = self.get_mention_emb(flatten_lstm, mention_start, mention_end) # [mention_num,emb]
+
+ # list
+ mention_emb_list.append(emb_start)
+ mention_emb_list.append(emb_end)
+
+ if self.config.use_width:
+ mention_width_index = mention_end - mention_start
+ mention_width_tensor = torch.from_numpy(mention_width_index).to(self.device) # [mention_num]
+ mention_width_emb = self.feature_emb(mention_width_tensor)
+ mention_width_emb = self.feature_emb_dropout(mention_width_emb)
+ mention_emb_list.append(mention_width_emb)
+
+ if self.config.model_heads:
+ mention_index = self.get_mention_index(mention_start, self.config.span_width) # [mention_num,max_mention]
+ log_mask_tensor = self.get_mask(mention_start, mention_end).float().to(
+ self.device) # [mention_num,max_mention]
+ alpha = self.atten(flatten_lstm).to(self.device) # [word_num]
+
+ # 得到attention
+ mention_head_score = torch.gather(alpha.expand(mention_num, -1), 1,
+ mention_index).float().to(self.device) # [mention_num,max_mention]
+ mention_attention = F.softmax(mention_head_score + log_mask_tensor, dim=1) # [mention_num,max_mention]
+
+ # TODO flatte lstm
+ word_num = flatten_lstm.shape[0]
+ lstm_emb = flatten_lstm.shape[1]
+ emb_num = flatten_word_emb.shape[1]
+
+ # [num_mentions, max_mention_width, emb]
+ mention_text_emb = torch.gather(
+ flatten_word_emb.unsqueeze(1).expand(word_num, self.config.span_width, emb_num),
+ 0, mention_index.unsqueeze(2).expand(mention_num, self.config.span_width,
+ emb_num))
+ # [mention_num,emb]
+ mention_head_emb = torch.sum(
+ mention_attention.unsqueeze(2).expand(mention_num, self.config.span_width, emb_num) * mention_text_emb,
+ dim=1)
+ mention_emb_list.append(mention_head_emb)
+
+ candidate_mention_emb = torch.cat(mention_emb_list, 1) # [candidate_mention_num,emb]
+ candidate_mention_score = self.mention_score(candidate_mention_emb) # [candidate_mention_num]
+
+ antecedent_scores, antecedents, mention_start_tensor, mention_end_tensor = (None, None, None, None)
+ mention_start_tensor, mention_end_tensor, mention_score, mention_emb = \
+ self.sort_mention(mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_len)
+ mention_speakers_ids = speakers_tensor.index_select(dim=0, index=mention_start_tensor) # num_mention
+
+ antecedents, antecedents_len = self.get_antecedents(mention_start_tensor, self.config.max_antecedents)
+ antecedent_scores = self.get_antecedents_score(mention_emb, mention_score, antecedents, antecedents_len,
+ mention_speakers_ids, genre)
+
+ ans = {"candidate_mention_score": candidate_mention_score, "antecedent_scores": antecedent_scores,
+ "antecedents": antecedents, "mention_start_tensor": mention_start_tensor,
+ "mention_end_tensor": mention_end_tensor}
+
+ return ans
+
+ def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
+ ans = self(sentences,
+ doc_np,
+ speaker_ids_np,
+ genre,
+ char_index,
+ seq_len)
+
+ predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"])
+ predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"],
+ ans["mention_end_tensor"],
+ predicted_antecedents)
+
+ return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted}
+
+
+if __name__ == '__main__':
+ pass
diff --git a/reproduction/coreference_resolution/model/preprocess.py b/reproduction/coreference_resolution/model/preprocess.py
new file mode 100644
index 00000000..d97fcb4d
--- /dev/null
+++ b/reproduction/coreference_resolution/model/preprocess.py
@@ -0,0 +1,225 @@
+import json
+import numpy as np
+from . import util
+import collections
+
+def load(path):
+ """
+ load the file from jsonline
+ :param path:
+ :return: examples with many example(dict): {"clusters":[[[mention],[mention]],[another cluster]],
+ "doc_key":"str","speakers":[[,,,],[]...],"sentence":[[][]]}
+ """
+ with open(path) as f:
+ train_examples = [json.loads(jsonline) for jsonline in f.readlines()]
+ return train_examples
+
+def get_vocab():
+ """
+ 从所有的句子中得到最终的字典,被main调用,不止是train,还有dev和test
+ :param examples:
+ :return: word2id & id2word
+ """
+ word2id = {'PAD':0,'UNK':1}
+ id2word = {0:'PAD',1:'UNK'}
+ index = 2
+ data = [load("../data/train.english.jsonlines"),load("../data/dev.english.jsonlines"),load("../data/test.english.jsonlines")]
+ for examples in data:
+ for example in examples:
+ for sent in example["sentences"]:
+ for word in sent:
+ if(word not in word2id):
+ word2id[word]=index
+ id2word[index] = word
+ index += 1
+ return word2id,id2word
+
+def normalize(v):
+ norm = np.linalg.norm(v)
+ if norm > 0:
+ return v / norm
+ else:
+ return v
+
+# 加载glove得到embedding
+def get_emb(id2word,embedding_size):
+ glove_oov = 0
+ turian_oov = 0
+ both = 0
+ glove_emb_path = "../data/glove.840B.300d.txt.filtered"
+ turian_emb_path = "../data/turian.50d.txt"
+ word_num = len(id2word)
+ emb = np.zeros((word_num,embedding_size))
+ glove_emb_dict = util.load_embedding_dict(glove_emb_path,300,"txt")
+ turian_emb_dict = util.load_embedding_dict(turian_emb_path,50,"txt")
+ for i in range(word_num):
+ if id2word[i] in glove_emb_dict:
+ word_embedding = glove_emb_dict.get(id2word[i])
+ emb[i][0:300] = np.array(word_embedding)
+ else:
+ # print(id2word[i])
+ glove_oov += 1
+ if id2word[i] in turian_emb_dict:
+ word_embedding = turian_emb_dict.get(id2word[i])
+ emb[i][300:350] = np.array(word_embedding)
+ else:
+ # print(id2word[i])
+ turian_oov += 1
+ if id2word[i] not in glove_emb_dict and id2word[i] not in turian_emb_dict:
+ both += 1
+ emb[i] = normalize(emb[i])
+ print("embedding num:"+str(word_num))
+ print("glove num:"+str(glove_oov))
+ print("glove oov rate:"+str(glove_oov/word_num))
+ print("turian num:"+str(turian_oov))
+ print("turian oov rate:"+str(turian_oov/word_num))
+ print("both num:"+str(both))
+ return emb
+
+
+def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train):
+ max_len = 0
+ max_word_length = 0
+ docvex = []
+ length = []
+ if is_train:
+ sent_num = min(max_sentences,len(doc))
+ else:
+ sent_num = len(doc)
+
+ for i in range(sent_num):
+ sent = doc[i]
+ length.append(len(sent))
+ if (len(sent) > max_len):
+ max_len = len(sent)
+ sent_vec =[]
+ for j,word in enumerate(sent):
+ if len(word)>max_word_length:
+ max_word_length = len(word)
+ if word in word2id:
+ sent_vec.append(word2id[word])
+ else:
+ sent_vec.append(word2id["UNK"])
+ docvex.append(sent_vec)
+
+ char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int)
+ for i in range(sent_num):
+ sent = doc[i]
+ for j,word in enumerate(sent):
+ char_index[i, j, :len(word)] = [char_dict[c] for c in word]
+
+ return docvex,char_index,length,max_len
+
+# TODO 修改了接口,确认所有该修改的地方都修改好
+def doc2numpy(doc,word2id,chardict,max_filter,max_sentences,is_train):
+ docvec, char_index, length, max_len = _doc2vec(doc,word2id,chardict,max_filter,max_sentences,is_train)
+ assert max(length) == max_len
+ assert char_index.shape[0]==len(length)
+ assert char_index.shape[1]==max_len
+ doc_np = np.zeros((len(docvec), max_len), int)
+ for i in range(len(docvec)):
+ for j in range(len(docvec[i])):
+ doc_np[i][j] = docvec[i][j]
+ return doc_np,char_index,length
+
+# TODO 没有测试
+def speaker2numpy(speakers_raw,max_sentences,is_train):
+ if is_train and len(speakers_raw)> max_sentences:
+ speakers_raw = speakers_raw[0:max_sentences]
+ speakers = flatten(speakers_raw)
+ speaker_dict = {s: i for i, s in enumerate(set(speakers))}
+ speaker_ids = np.array([speaker_dict[s] for s in speakers])
+ return speaker_ids
+
+
+def flat_cluster(clusters):
+ flatted = []
+ for cluster in clusters:
+ for item in cluster:
+ flatted.append(item)
+ return flatted
+
+def get_right_mention(clusters,mention_start_np,mention_end_np):
+ flatted = flat_cluster(clusters)
+ cluster_num = len(flatted)
+ mention_num = mention_start_np.shape[0]
+ right_mention = np.zeros(mention_num,dtype=int)
+ for i in range(mention_num):
+ if [mention_start_np[i],mention_end_np[i]] in flatted:
+ right_mention[i]=1
+ return right_mention,cluster_num
+
+def handle_cluster(clusters):
+ gold_mentions = sorted(tuple(m) for m in flatten(clusters))
+ gold_mention_map = {m: i for i, m in enumerate(gold_mentions)}
+ cluster_ids = np.zeros(len(gold_mentions), dtype=int)
+ for cluster_id, cluster in enumerate(clusters):
+ for mention in cluster:
+ cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id
+ gold_starts, gold_ends = tensorize_mentions(gold_mentions)
+ return cluster_ids, gold_starts, gold_ends
+
+# 展平
+def flatten(l):
+ return [item for sublist in l for item in sublist]
+
+# 把mention分成start end
+def tensorize_mentions(mentions):
+ if len(mentions) > 0:
+ starts, ends = zip(*mentions)
+ else:
+ starts, ends = [], []
+ return np.array(starts), np.array(ends)
+
+def get_char_dict(path):
+ vocab = [""]
+ with open(path) as f:
+ vocab.extend(c.strip() for c in f.readlines())
+ char_dict = collections.defaultdict(int)
+ char_dict.update({c: i for i, c in enumerate(vocab)})
+ return char_dict
+
+def get_labels(clusters,mention_starts,mention_ends,max_antecedents):
+ cluster_ids, gold_starts, gold_ends = handle_cluster(clusters)
+ num_mention = mention_starts.shape[0]
+ num_gold = gold_starts.shape[0]
+ max_antecedents = min(max_antecedents, num_mention)
+ mention_indices = {}
+
+ for i in range(num_mention):
+ mention_indices[(mention_starts[i].detach().item(), mention_ends[i].detach().item())] = i
+ # 用来记录哪些mention是对的,-1表示错误,正数代表这个mention实际上对应哪个gold cluster的id
+ mention_cluster_ids = [-1] * num_mention
+ # test
+ right_mention_count = 0
+ for i in range(num_gold):
+ right_mention = mention_indices.get((gold_starts[i], gold_ends[i]))
+ if (right_mention != None):
+ right_mention_count += 1
+ mention_cluster_ids[right_mention] = cluster_ids[i]
+
+ # i j 是否属于同一个cluster
+ labels = np.zeros((num_mention, max_antecedents + 1), dtype=bool) # [num_mention,max_an+1]
+ for i in range(num_mention):
+ ante_count = 0
+ null_label = True
+ for j in range(max(0, i - max_antecedents), i):
+ if (mention_cluster_ids[i] >= 0 and mention_cluster_ids[i] == mention_cluster_ids[j]):
+ labels[i, ante_count + 1] = True
+ null_label = False
+ else:
+ labels[i, ante_count + 1] = False
+ ante_count += 1
+ for j in range(ante_count, max_antecedents):
+ labels[i, j + 1] = False
+ labels[i, 0] = null_label
+ return labels
+
+# test===========================
+
+
+if __name__=="__main__":
+ word2id,id2word = get_vocab()
+ get_emb(id2word,350)
+
+
diff --git a/reproduction/coreference_resolution/model/softmax_loss.py b/reproduction/coreference_resolution/model/softmax_loss.py
new file mode 100644
index 00000000..c75a31d6
--- /dev/null
+++ b/reproduction/coreference_resolution/model/softmax_loss.py
@@ -0,0 +1,32 @@
+from fastNLP.core.losses import LossBase
+
+from reproduction.coreference_resolution.model.preprocess import get_labels
+from reproduction.coreference_resolution.model.config import Config
+import torch
+
+
+class SoftmaxLoss(LossBase):
+ """
+ 交叉熵loss
+ 允许多标签分类
+ """
+
+ def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None):
+ """
+
+ :param pred:
+ :param target:
+ """
+ super().__init__()
+ self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters,
+ mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor)
+
+ def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor):
+ antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor,
+ Config().max_antecedents)
+
+ antecedent_labels = torch.from_numpy(antecedent_labels*1).to(torch.device("cuda:" + Config().cuda))
+ gold_scores = antecedent_scores + torch.log(antecedent_labels.float()).to(torch.device("cuda:" + Config().cuda)) # [num_mentions, max_ant + 1]
+ marginalized_gold_scores = gold_scores.logsumexp(dim=1) # [num_mentions]
+ log_norm = antecedent_scores.logsumexp(dim=1) # [num_mentions]
+ return torch.sum(log_norm - marginalized_gold_scores)
diff --git a/reproduction/coreference_resolution/model/util.py b/reproduction/coreference_resolution/model/util.py
new file mode 100644
index 00000000..42cd09fe
--- /dev/null
+++ b/reproduction/coreference_resolution/model/util.py
@@ -0,0 +1,101 @@
+import os
+import errno
+import collections
+import torch
+import numpy as np
+import pyhocon
+
+
+
+# flatten the list
+def flatten(l):
+ return [item for sublist in l for item in sublist]
+
+
+def get_config(filename):
+ return pyhocon.ConfigFactory.parse_file(filename)
+
+
+# safe make directions
+def mkdirs(path):
+ try:
+ os.makedirs(path)
+ except OSError as exception:
+ if exception.errno != errno.EEXIST:
+ raise
+ return path
+
+
+def load_char_dict(char_vocab_path):
+ vocab = [""]
+ with open(char_vocab_path) as f:
+ vocab.extend(c.strip() for c in f.readlines())
+ char_dict = collections.defaultdict(int)
+ char_dict.update({c: i for i, c in enumerate(vocab)})
+ return char_dict
+
+# 加载embedding
+def load_embedding_dict(embedding_path, embedding_size, embedding_format):
+ print("Loading word embeddings from {}...".format(embedding_path))
+ default_embedding = np.zeros(embedding_size)
+ embedding_dict = collections.defaultdict(lambda: default_embedding)
+ skip_first = embedding_format == "vec"
+ with open(embedding_path) as f:
+ for i, line in enumerate(f.readlines()):
+ if skip_first and i == 0:
+ continue
+ splits = line.split()
+ assert len(splits) == embedding_size + 1
+ word = splits[0]
+ embedding = np.array([float(s) for s in splits[1:]])
+ embedding_dict[word] = embedding
+ print("Done loading word embeddings.")
+ return embedding_dict
+
+
+# safe devide
+def maybe_divide(x, y):
+ return 0 if y == 0 else x / float(y)
+
+
+def shape(x, dim):
+ return x.get_shape()[dim].value or torch.shape(x)[dim]
+
+
+def normalize(v):
+ norm = np.linalg.norm(v)
+ if norm > 0:
+ return v / norm
+ else:
+ return v
+
+
+class RetrievalEvaluator(object):
+ def __init__(self):
+ self._num_correct = 0
+ self._num_gold = 0
+ self._num_predicted = 0
+
+ def update(self, gold_set, predicted_set):
+ self._num_correct += len(gold_set & predicted_set)
+ self._num_gold += len(gold_set)
+ self._num_predicted += len(predicted_set)
+
+ def recall(self):
+ return maybe_divide(self._num_correct, self._num_gold)
+
+ def precision(self):
+ return maybe_divide(self._num_correct, self._num_predicted)
+
+ def metrics(self):
+ recall = self.recall()
+ precision = self.precision()
+ f1 = maybe_divide(2 * recall * precision, precision + recall)
+ return recall, precision, f1
+
+
+
+if __name__=="__main__":
+ print(load_char_dict("../data/char_vocab.english.txt"))
+ embedding_dict = load_embedding_dict("../data/glove.840B.300d.txt.filtered",300,"txt")
+ print("hello")
diff --git a/reproduction/coreference_resolution/readme.md b/reproduction/coreference_resolution/readme.md
new file mode 100644
index 00000000..67d8cdc7
--- /dev/null
+++ b/reproduction/coreference_resolution/readme.md
@@ -0,0 +1,49 @@
+# 共指消解复现
+## 介绍
+Coreference resolution是查找文本中指向同一现实实体的所有表达式的任务。
+对于涉及自然语言理解的许多更高级别的NLP任务来说,
+这是一个重要的步骤,例如文档摘要,问题回答和信息提取。
+代码的实现主要基于[ End-to-End Coreference Resolution (Lee et al, 2017)](https://arxiv.org/pdf/1707.07045).
+
+
+## 数据获取与预处理
+论文在[OntoNote5.0](https://allennlp.org/models)数据集上取得了当时的sota结果。
+由于版权问题,本文无法提供数据集的下载,请自行下载。
+原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。
+
+代码实现采用了论文作者Lee的预处理方法,具体细节参加[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。
+处理之后的数据集为json格式,例子:
+```
+{
+ "clusters": [],
+ "doc_key": "nw",
+ "sentences": [["This", "is", "the", "first", "sentence", "."], ["This", "is", "the", "second", "."]],
+ "speakers": [["spk1", "spk1", "spk1", "spk1", "spk1", "spk1"], ["spk2", "spk2", "spk2", "spk2", "spk2"]]
+}
+```
+
+### embedding 数据集下载
+[turian emdedding](https://lil.cs.washington.edu/coref/turian.50d.txt)
+
+[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip)
+
+
+
+## 运行
+```python
+# 训练代码
+CUDA_VISIBLE_DEVICES=0 python train.py
+# 测试代码
+CUDA_VISIBLE_DEVICES=0 python valid.py
+```
+
+## 结果
+原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。
+其中allenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。
+
+在与allenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。
+
+
+## 问题
+如果您有什么问题或者反馈,请提issue或者邮件联系我:
+yexu_i@qq.com
diff --git a/reproduction/coreference_resolution/test/__init__.py b/reproduction/coreference_resolution/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/test/test_dataloader.py b/reproduction/coreference_resolution/test/test_dataloader.py
new file mode 100644
index 00000000..0d9dae52
--- /dev/null
+++ b/reproduction/coreference_resolution/test/test_dataloader.py
@@ -0,0 +1,14 @@
+import unittest
+from ..data_load.cr_loader import CRLoader
+
+class Test_CRLoader(unittest.TestCase):
+ def test_cr_loader(self):
+ train_path = 'data/train.english.jsonlines.mini'
+ dev_path = 'data/dev.english.jsonlines.minid'
+ test_path = 'data/test.english.jsonlines'
+ cr = CRLoader()
+ data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path})
+
+ print(data_info.datasets['train'][0])
+ print(data_info.datasets['dev'][0])
+ print(data_info.datasets['test'][0])
diff --git a/reproduction/coreference_resolution/train.py b/reproduction/coreference_resolution/train.py
new file mode 100644
index 00000000..a231a575
--- /dev/null
+++ b/reproduction/coreference_resolution/train.py
@@ -0,0 +1,69 @@
+import sys
+sys.path.append('../..')
+
+import torch
+from torch.optim import Adam
+
+from fastNLP.core.callback import Callback, GradientClipCallback
+from fastNLP.core.trainer import Trainer
+
+from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from reproduction.coreference_resolution.model.config import Config
+from reproduction.coreference_resolution.model.model_re import Model
+from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
+from reproduction.coreference_resolution.model.metric import CRMetric
+from fastNLP import SequentialSampler
+from fastNLP import cache_results
+
+
+# torch.backends.cudnn.benchmark = False
+# torch.backends.cudnn.deterministic = True
+
+class LRCallback(Callback):
+ def __init__(self, parameters, decay_rate=1e-3):
+ super().__init__()
+ self.paras = parameters
+ self.decay_rate = decay_rate
+
+ def on_step_end(self):
+ if self.step % 100 == 0:
+ for para in self.paras:
+ para['lr'] = para['lr'] * (1 - self.decay_rate)
+
+
+if __name__ == "__main__":
+ config = Config()
+
+ print(config)
+
+ @cache_results('cache.pkl')
+ def cache():
+ cr_train_dev_test = CRLoader()
+
+ data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path,
+ 'test': config.test_path})
+ return data_info
+ data_info = cache()
+ print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])),
+ "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
+ # print(data_info)
+ model = Model(data_info.vocabs, config)
+ print(model)
+
+ loss = SoftmaxLoss()
+
+ metric = CRMetric()
+
+ optim = Adam(model.parameters(), lr=config.lr)
+
+ lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)
+
+ trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"],
+ loss=loss, metrics=metric, check_code_level=-1,sampler=None,
+ batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
+ optimizer=optim,
+ save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
+ callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
+ print()
+
+ trainer.train()
diff --git a/reproduction/coreference_resolution/valid.py b/reproduction/coreference_resolution/valid.py
new file mode 100644
index 00000000..826332c6
--- /dev/null
+++ b/reproduction/coreference_resolution/valid.py
@@ -0,0 +1,24 @@
+import torch
+from reproduction.coreference_resolution.model.config import Config
+from reproduction.coreference_resolution.model.metric import CRMetric
+from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from fastNLP import Tester
+import argparse
+
+
+if __name__=='__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--path')
+ args = parser.parse_args()
+
+ cr_loader = CRLoader()
+ config = Config()
+ data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
+ 'test': config.test_path})
+ metirc = CRMetric()
+ model = torch.load(args.path)
+ tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
+ tester.test()
+ print('test over')
+
+
diff --git a/reproduction/joint_cws_parse/__init__.py b/reproduction/joint_cws_parse/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/joint_cws_parse/data/__init__.py b/reproduction/joint_cws_parse/data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/joint_cws_parse/data/data_loader.py b/reproduction/joint_cws_parse/data/data_loader.py
new file mode 100644
index 00000000..7802ea09
--- /dev/null
+++ b/reproduction/joint_cws_parse/data/data_loader.py
@@ -0,0 +1,284 @@
+
+
+from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.dataset_loader import ConllLoader
+import numpy as np
+
+from itertools import chain
+from fastNLP import DataSet, Vocabulary
+from functools import partial
+import os
+from typing import Union, Dict
+from reproduction.utils import check_dataloader_paths
+
+
+class CTBxJointLoader(DataSetLoader):
+ """
+ 文件夹下应该具有以下的文件结构
+ -train.conllx
+ -dev.conllx
+ -test.conllx
+ 每个文件中的内容如下(空格隔开不同的句子, 共有)
+ 1 费孝通 _ NR NR _ 3 nsubjpass _ _
+ 2 被 _ SB SB _ 3 pass _ _
+ 3 授予 _ VV VV _ 0 root _ _
+ 4 麦格赛赛 _ NR NR _ 5 nn _ _
+ 5 奖 _ NN NN _ 3 dobj _ _
+
+ 1 新华社 _ NR NR _ 7 dep _ _
+ 2 马尼拉 _ NR NR _ 7 dep _ _
+ 3 8月 _ NT NT _ 7 dep _ _
+ 4 31日 _ NT NT _ 7 dep _ _
+ ...
+
+ """
+ def __init__(self):
+ self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7])
+
+ def load(self, path:str):
+ """
+ 给定一个文件路径,将数据读取为DataSet格式。DataSet中包含以下的内容
+ words: list[str]
+ pos_tags: list[str]
+ heads: list[int]
+ labels: list[str]
+
+ :param path:
+ :return:
+ """
+ dataset = self._loader.load(path)
+ dataset.heads.int()
+ return dataset
+
+ def process(self, paths):
+ """
+
+ :param paths:
+ :return:
+ Dataset包含以下的field
+ chars:
+ bigrams:
+ trigrams:
+ pre_chars:
+ pre_bigrams:
+ pre_trigrams:
+ seg_targets:
+ seg_masks:
+ seq_lens:
+ char_labels:
+ char_heads:
+ gold_word_pairs:
+ seg_targets:
+ seg_masks:
+ char_labels:
+ char_heads:
+ pun_masks:
+ gold_label_word_pairs:
+ """
+ paths = check_dataloader_paths(paths)
+ data = DataInfo()
+
+ for name, path in paths.items():
+ dataset = self.load(path)
+ data.datasets[name] = dataset
+
+ char_labels_vocab = Vocabulary(padding=None, unknown=None)
+
+ def process(dataset, char_label_vocab):
+ dataset.apply(add_word_lst, new_field_name='word_lst')
+ dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars')
+ dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams')
+ dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams')
+ dataset.apply(add_char_heads, new_field_name='char_heads')
+ dataset.apply(add_char_labels, new_field_name='char_labels')
+ dataset.apply(add_segs, new_field_name='seg_targets')
+ dataset.apply(add_mask, new_field_name='seg_masks')
+ dataset.add_seq_len('chars', new_field_name='seq_lens')
+ dataset.apply(add_pun_masks, new_field_name='pun_masks')
+ if len(char_label_vocab.word_count)==0:
+ char_label_vocab.from_dataset(dataset, field_name='char_labels')
+ char_label_vocab.index_dataset(dataset, field_name='char_labels')
+ new_dataset = add_root(dataset)
+ new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True)
+ global add_label_word_pairs
+ add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab)
+ new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True)
+
+ new_dataset.set_pad_val('char_labels', -1)
+ new_dataset.set_pad_val('char_heads', -1)
+
+ return new_dataset
+
+ for name in list(paths.keys()):
+ dataset = data.datasets[name]
+ dataset = process(dataset, char_labels_vocab)
+ data.datasets[name] = dataset
+
+ data.vocabs['char_labels'] = char_labels_vocab
+
+ char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars')
+ bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams')
+ trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams')
+
+ for name in ['chars', 'bigrams', 'trigrams']:
+ vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values()))
+ vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name)
+ data.vocabs['pre_{}'.format(name)] = vocab
+
+ for name, vocab in zip(['chars', 'bigrams', 'trigrams'],
+ [char_vocab, bigram_vocab, trigram_vocab]):
+ vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name)
+ data.vocabs[name] = vocab
+
+ for name, dataset in data.datasets.items():
+ dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars',
+ 'pre_bigrams', 'pre_trigrams')
+ dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels',
+ 'char_heads',
+ 'pun_masks', 'gold_label_word_pairs')
+
+ return data
+
+
+def add_label_word_pairs(instance, label_vocab):
+ # List[List[((head_start, head_end], (dep_start, dep_end]), ...]]
+ word_end_indexes = np.array(list(map(len, instance['word_lst'])))
+ word_end_indexes = np.cumsum(word_end_indexes).tolist()
+ word_end_indexes.insert(0, 0)
+ word_pairs = []
+ labels = instance['labels']
+ pos_tags = instance['pos_tags']
+ for idx, head in enumerate(instance['heads']):
+ if pos_tags[idx]=='PU': # 如果是标点符号,就不记录
+ continue
+ label = label_vocab.to_index(labels[idx])
+ if head==0:
+ word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1]))))
+ else:
+ word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label,
+ (word_end_indexes[idx], word_end_indexes[idx + 1])))
+ return word_pairs
+
+def add_word_pairs(instance):
+ # List[List[((head_start, head_end], (dep_start, dep_end]), ...]]
+ word_end_indexes = np.array(list(map(len, instance['word_lst'])))
+ word_end_indexes = np.cumsum(word_end_indexes).tolist()
+ word_end_indexes.insert(0, 0)
+ word_pairs = []
+ pos_tags = instance['pos_tags']
+ for idx, head in enumerate(instance['heads']):
+ if pos_tags[idx]=='PU': # 如果是标点符号,就不记录
+ continue
+ if head==0:
+ word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1]))))
+ else:
+ word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]),
+ (word_end_indexes[idx], word_end_indexes[idx + 1])))
+ return word_pairs
+
+def add_root(dataset):
+ new_dataset = DataSet()
+ for sample in dataset:
+ chars = ['char_root'] + sample['chars']
+ bigrams = ['bigram_root'] + sample['bigrams']
+ trigrams = ['trigram_root'] + sample['trigrams']
+ seq_lens = sample['seq_lens']+1
+ char_labels = [0] + sample['char_labels']
+ char_heads = [0] + sample['char_heads']
+ sample['chars'] = chars
+ sample['bigrams'] = bigrams
+ sample['trigrams'] = trigrams
+ sample['seq_lens'] = seq_lens
+ sample['char_labels'] = char_labels
+ sample['char_heads'] = char_heads
+ new_dataset.append(sample)
+ return new_dataset
+
+def add_pun_masks(instance):
+ tags = instance['pos_tags']
+ pun_masks = []
+ for word, tag in zip(instance['words'], tags):
+ if tag=='PU':
+ pun_masks.extend([1]*len(word))
+ else:
+ pun_masks.extend([0]*len(word))
+ return pun_masks
+
+def add_word_lst(instance):
+ words = instance['words']
+ word_lst = [list(word) for word in words]
+ return word_lst
+
+def add_bigram(instance):
+ chars = instance['chars']
+ length = len(chars)
+ chars = chars + ['']
+ bigrams = []
+ for i in range(length):
+ bigrams.append(''.join(chars[i:i + 2]))
+ return bigrams
+
+def add_trigram(instance):
+ chars = instance['chars']
+ length = len(chars)
+ chars = chars + [''] * 2
+ trigrams = []
+ for i in range(length):
+ trigrams.append(''.join(chars[i:i + 3]))
+ return trigrams
+
+def add_char_heads(instance):
+ words = instance['word_lst']
+ heads = instance['heads']
+ char_heads = []
+ char_index = 1 # 因此存在root节点所以需要从1开始
+ head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0] # 因为root是0,0-1=-1
+ for word, head in zip(words, heads):
+ char_head = []
+ if len(word)>1:
+ char_head.append(char_index+1)
+ char_index += 1
+ for _ in range(len(word)-2):
+ char_index += 1
+ char_head.append(char_index)
+ char_index += 1
+ char_head.append(head_end_indexes[head-1])
+ char_heads.extend(char_head)
+ return char_heads
+
+def add_char_labels(instance):
+ """
+ 将word_lst中的数据按照下面的方式设置label
+ 比如"复旦大学 位于 ", 对应的分词是"B M M E B E", 则对应的dependency是"复(dep)->旦(head)", "旦(dep)->大(head)"..
+ 对应的label是'app', 'app', 'app', , 而学的label就是复旦大学这个词的dependency label
+ :param instance:
+ :return:
+ """
+ words = instance['word_lst']
+ labels = instance['labels']
+ char_labels = []
+ for word, label in zip(words, labels):
+ for _ in range(len(word)-1):
+ char_labels.append('APP')
+ char_labels.append(label)
+ return char_labels
+
+# add seg_targets
+def add_segs(instance):
+ words = instance['word_lst']
+ segs = [0]*len(instance['chars'])
+ index = 0
+ for word in words:
+ index = index + len(word) - 1
+ segs[index] = len(word)-1
+ index = index + 1
+ return segs
+
+# add target_masks
+def add_mask(instance):
+ words = instance['word_lst']
+ mask = []
+ for word in words:
+ mask.extend([0] * (len(word) - 1))
+ mask.append(1)
+ return mask
diff --git a/reproduction/joint_cws_parse/models/CharParser.py b/reproduction/joint_cws_parse/models/CharParser.py
new file mode 100644
index 00000000..1ed5ea2d
--- /dev/null
+++ b/reproduction/joint_cws_parse/models/CharParser.py
@@ -0,0 +1,311 @@
+
+
+
+from fastNLP.models.biaffine_parser import BiaffineParser
+from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from fastNLP.modules.dropout import TimestepDropout
+from fastNLP.modules.encoder.variational_rnn import VarLSTM
+from fastNLP import seq_len_to_mask
+from fastNLP.modules import Embedding
+
+
+def drop_input_independent(word_embeddings, dropout_emb):
+ batch_size, seq_length, _ = word_embeddings.size()
+ word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb)
+ word_masks = torch.bernoulli(word_masks)
+ word_masks = word_masks.unsqueeze(dim=2)
+ word_embeddings = word_embeddings * word_masks
+
+ return word_embeddings
+
+
+class CharBiaffineParser(BiaffineParser):
+ def __init__(self, char_vocab_size,
+ emb_dim,
+ bigram_vocab_size,
+ trigram_vocab_size,
+ num_label,
+ rnn_layers=3,
+ rnn_hidden_size=800, #单向的数量
+ arc_mlp_size=500,
+ label_mlp_size=100,
+ dropout=0.3,
+ encoder='lstm',
+ use_greedy_infer=False,
+ app_index = 0,
+ pre_chars_embed=None,
+ pre_bigrams_embed=None,
+ pre_trigrams_embed=None):
+
+
+ super(BiaffineParser, self).__init__()
+ rnn_out_size = 2 * rnn_hidden_size
+ self.char_embed = Embedding((char_vocab_size, emb_dim))
+ self.bigram_embed = Embedding((bigram_vocab_size, emb_dim))
+ self.trigram_embed = Embedding((trigram_vocab_size, emb_dim))
+ if pre_chars_embed:
+ self.pre_char_embed = Embedding(pre_chars_embed)
+ self.pre_char_embed.requires_grad = False
+ if pre_bigrams_embed:
+ self.pre_bigram_embed = Embedding(pre_bigrams_embed)
+ self.pre_bigram_embed.requires_grad = False
+ if pre_trigrams_embed:
+ self.pre_trigram_embed = Embedding(pre_trigrams_embed)
+ self.pre_trigram_embed.requires_grad = False
+ self.timestep_drop = TimestepDropout(dropout)
+ self.encoder_name = encoder
+
+ if encoder == 'var-lstm':
+ self.encoder = VarLSTM(input_size=emb_dim*3,
+ hidden_size=rnn_hidden_size,
+ num_layers=rnn_layers,
+ bias=True,
+ batch_first=True,
+ input_dropout=dropout,
+ hidden_dropout=dropout,
+ bidirectional=True)
+ elif encoder == 'lstm':
+ self.encoder = nn.LSTM(input_size=emb_dim*3,
+ hidden_size=rnn_hidden_size,
+ num_layers=rnn_layers,
+ bias=True,
+ batch_first=True,
+ dropout=dropout,
+ bidirectional=True)
+
+ else:
+ raise ValueError('unsupported encoder type: {}'.format(encoder))
+
+ self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
+ nn.LeakyReLU(0.1),
+ TimestepDropout(p=dropout),)
+ self.arc_mlp_size = arc_mlp_size
+ self.label_mlp_size = label_mlp_size
+ self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
+ self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
+ self.use_greedy_infer = use_greedy_infer
+ self.reset_parameters()
+ self.dropout = dropout
+
+ self.app_index = app_index
+ self.num_label = num_label
+ if self.app_index != 0:
+ raise ValueError("现在app_index必须等于0")
+
+ def reset_parameters(self):
+ for name, m in self.named_modules():
+ if 'embed' in name:
+ pass
+ elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'):
+ pass
+ else:
+ for p in m.parameters():
+ if len(p.size())>1:
+ nn.init.xavier_normal_(p, gain=0.1)
+ else:
+ nn.init.uniform_(p, -0.1, 0.1)
+
+ def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None,
+ pre_trigrams=None):
+ """
+ max_len是包含root的
+ :param chars: batch_size x max_len
+ :param ngrams: batch_size x max_len*ngram_per_char
+ :param seq_lens: batch_size
+ :param gold_heads: batch_size x max_len
+ :param pre_chars: batch_size x max_len
+ :param pre_ngrams: batch_size x max_len*ngram_per_char
+ :return dict: parsing results
+ arc_pred: [batch_size, seq_len, seq_len]
+ label_pred: [batch_size, seq_len, seq_len]
+ mask: [batch_size, seq_len]
+ head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
+ """
+ # prepare embeddings
+ batch_size, seq_len = chars.shape
+ # print('forward {} {}'.format(batch_size, seq_len))
+
+ # get sequence mask
+ mask = seq_len_to_mask(seq_lens).long()
+
+ chars = self.char_embed(chars) # [N,L] -> [N,L,C_0]
+ bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1]
+ trigrams = self.trigram_embed(trigrams)
+
+ if pre_chars is not None:
+ pre_chars = self.pre_char_embed(pre_chars)
+ # pre_chars = self.pre_char_fc(pre_chars)
+ chars = pre_chars + chars
+ if pre_bigrams is not None:
+ pre_bigrams = self.pre_bigram_embed(pre_bigrams)
+ # pre_bigrams = self.pre_bigram_fc(pre_bigrams)
+ bigrams = bigrams + pre_bigrams
+ if pre_trigrams is not None:
+ pre_trigrams = self.pre_trigram_embed(pre_trigrams)
+ # pre_trigrams = self.pre_trigram_fc(pre_trigrams)
+ trigrams = trigrams + pre_trigrams
+
+ x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C]
+
+ # encoder, extract features
+ if self.training:
+ x = drop_input_independent(x, self.dropout)
+ sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
+ x = x[sort_idx]
+ x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
+ feat, _ = self.encoder(x) # -> [N,L,C]
+ feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
+ _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
+ feat = feat[unsort_idx]
+ feat = self.timestep_drop(feat)
+
+ # for arc biaffine
+ # mlp, reduce dim
+ feat = self.mlp(feat)
+ arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
+ arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
+ label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]
+
+ # biaffine arc classifier
+ arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
+
+ # use gold or predicted arc to predict label
+ if gold_heads is None or not self.training:
+ # use greedy decoding in training
+ if self.training or self.use_greedy_infer:
+ heads = self.greedy_decoder(arc_pred, mask)
+ else:
+ heads = self.mst_decoder(arc_pred, mask)
+ head_pred = heads
+ else:
+ assert self.training # must be training mode
+ if gold_heads is None:
+ heads = self.greedy_decoder(arc_pred, mask)
+ head_pred = heads
+ else:
+ head_pred = None
+ heads = gold_heads
+ # heads: batch_size x max_len
+
+ batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1)
+ label_head = label_head[batch_range, heads].contiguous()
+ label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label]
+ # 这里限制一下,只有当head为下一个时,才能预测app这个label
+ arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\
+ .repeat(batch_size, 1) # batch_size x max_len
+ app_masks = heads.ne(arange_index) # batch_size x max_len, 为1的位置不可以预测app
+ app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
+ app_masks[:, :, 1:] = 0
+ label_pred = label_pred.masked_fill(app_masks, -np.inf)
+
+ res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
+ if head_pred is not None:
+ res_dict['head_pred'] = head_pred
+ return res_dict
+
+ @staticmethod
+ def loss(arc_pred, label_pred, arc_true, label_true, mask):
+ """
+ Compute loss.
+
+ :param arc_pred: [batch_size, seq_len, seq_len]
+ :param label_pred: [batch_size, seq_len, n_tags]
+ :param arc_true: [batch_size, seq_len]
+ :param label_true: [batch_size, seq_len]
+ :param mask: [batch_size, seq_len]
+ :return: loss value
+ """
+
+ batch_size, seq_len, _ = arc_pred.shape
+ flip_mask = (mask == 0)
+ _arc_pred = arc_pred.clone()
+ _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
+
+ arc_true[:, 0].fill_(-1)
+ label_true[:, 0].fill_(-1)
+
+ arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
+ label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)
+
+ return arc_nll + label_nll
+
+ def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams):
+ """
+
+ max_len是包含root的
+
+ :param chars: batch_size x max_len
+ :param ngrams: batch_size x max_len*ngram_per_char
+ :param seq_lens: batch_size
+ :param pre_chars: batch_size x max_len
+ :param pre_ngrams: batch_size x max_len*ngram_per_cha
+ :return:
+ """
+ res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams,
+ pre_trigrams=pre_trigrams, gold_heads=None)
+ output = {}
+ output['arc_pred'] = res.pop('head_pred')
+ _, label_pred = res.pop('label_pred').max(2)
+ output['label_pred'] = label_pred
+ return output
+
+class CharParser(nn.Module):
+ def __init__(self, char_vocab_size,
+ emb_dim,
+ bigram_vocab_size,
+ trigram_vocab_size,
+ num_label,
+ rnn_layers=3,
+ rnn_hidden_size=400, #单向的数量
+ arc_mlp_size=500,
+ label_mlp_size=100,
+ dropout=0.3,
+ encoder='var-lstm',
+ use_greedy_infer=False,
+ app_index = 0,
+ pre_chars_embed=None,
+ pre_bigrams_embed=None,
+ pre_trigrams_embed=None):
+ super().__init__()
+
+ self.parser = CharBiaffineParser(char_vocab_size,
+ emb_dim,
+ bigram_vocab_size,
+ trigram_vocab_size,
+ num_label,
+ rnn_layers,
+ rnn_hidden_size, #单向的数量
+ arc_mlp_size,
+ label_mlp_size,
+ dropout,
+ encoder,
+ use_greedy_infer,
+ app_index,
+ pre_chars_embed=pre_chars_embed,
+ pre_bigrams_embed=pre_bigrams_embed,
+ pre_trigrams_embed=pre_trigrams_embed)
+
+ def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None,
+ pre_trigrams=None):
+ res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars,
+ pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
+ arc_pred = res_dict['arc_pred']
+ label_pred = res_dict['label_pred']
+ masks = res_dict['mask']
+ loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks)
+ return {'loss': loss}
+
+ def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None):
+ res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars,
+ pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
+ output = {}
+ output['head_preds'] = res.pop('head_pred')
+ _, label_pred = res.pop('label_pred').max(2)
+ output['label_preds'] = label_pred
+ return output
diff --git a/reproduction/joint_cws_parse/models/__init__.py b/reproduction/joint_cws_parse/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/joint_cws_parse/models/callbacks.py b/reproduction/joint_cws_parse/models/callbacks.py
new file mode 100644
index 00000000..8de01109
--- /dev/null
+++ b/reproduction/joint_cws_parse/models/callbacks.py
@@ -0,0 +1,65 @@
+
+from fastNLP.core.callback import Callback
+import torch
+from torch import nn
+
+class OptimizerCallback(Callback):
+ def __init__(self, optimizer, scheduler, update_every=4):
+ super().__init__()
+
+ self._optimizer = optimizer
+ self.scheduler = scheduler
+ self._update_every = update_every
+
+ def on_backward_end(self):
+ if self.step % self._update_every==0:
+ # nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5)
+ # self._optimizer.step()
+ self.scheduler.step()
+ # self.model.zero_grad()
+
+
+class DevCallback(Callback):
+ def __init__(self, tester, metric_key='u_f1'):
+ super().__init__()
+ self.tester = tester
+ setattr(tester, 'verbose', 0)
+
+ self.metric_key = metric_key
+
+ self.record_best = False
+ self.best_eval_value = 0
+ self.best_eval_res = None
+
+ self.best_dev_res = None # 存取dev的表现
+
+ def on_valid_begin(self):
+ eval_res = self.tester.test()
+ metric_name = self.tester.metrics[0].__class__.__name__
+ metric_value = eval_res[metric_name][self.metric_key]
+ if metric_value>self.best_eval_value:
+ self.best_eval_value = metric_value
+ self.best_epoch = self.trainer.epoch
+ self.record_best = True
+ self.best_eval_res = eval_res
+ self.test_eval_res = eval_res
+ eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \
+ self.tester._format_eval_results(eval_res)
+ self.pbar.write(eval_str)
+
+ def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
+ if self.record_best:
+ self.best_dev_res = eval_result
+ self.record_best = False
+ if is_better_eval:
+ self.best_dev_res_on_dev = eval_result
+ self.best_test_res_on_dev = self.test_eval_res
+ self.dev_epoch = self.epoch
+
+ def on_train_end(self):
+ print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch,
+ self.tester._format_eval_results(self.best_eval_res),
+ self.tester._format_eval_results(self.best_dev_res)))
+ print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch,
+ self.tester._format_eval_results(self.best_test_res_on_dev),
+ self.tester._format_eval_results(self.best_dev_res_on_dev)))
\ No newline at end of file
diff --git a/reproduction/joint_cws_parse/models/metrics.py b/reproduction/joint_cws_parse/models/metrics.py
new file mode 100644
index 00000000..bf0f0622
--- /dev/null
+++ b/reproduction/joint_cws_parse/models/metrics.py
@@ -0,0 +1,184 @@
+from fastNLP.core.metrics import MetricBase
+from fastNLP.core.utils import seq_len_to_mask
+import torch
+
+
+class SegAppCharParseF1Metric(MetricBase):
+ #
+ def __init__(self, app_index):
+ super().__init__()
+ self.app_index = app_index
+
+ self.parse_head_tp = 0
+ self.parse_label_tp = 0
+ self.rec_tol = 0
+ self.pre_tol = 0
+
+ def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens,
+ pun_masks):
+ """
+
+ max_len是不包含root的character的长度
+ :param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size
+ :param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size
+ :param head_preds: batch_size x max_len
+ :param label_preds: batch_size x max_len
+ :param seq_lens:
+ :param pun_masks: batch_size x
+ :return:
+ """
+ # 去掉root
+ head_preds = head_preds[:, 1:].tolist()
+ label_preds = label_preds[:, 1:].tolist()
+ seq_lens = (seq_lens - 1).tolist()
+
+ # 先解码出words,POS,heads, labels, 对应的character范围
+ for b in range(len(head_preds)):
+ seq_len = seq_lens[b]
+ head_pred = head_preds[b][:seq_len]
+ label_pred = label_preds[b][:seq_len]
+
+ words = [] # 存放[word_start, word_end),相对起始位置,不考虑root
+ heads = []
+ labels = []
+ ranges = [] # 对应该char是第几个word,长度是seq_len+1
+ word_idx = 0
+ word_start_idx = 0
+ for idx, (label, head) in enumerate(zip(label_pred, head_pred)):
+ ranges.append(word_idx)
+ if label == self.app_index:
+ pass
+ else:
+ labels.append(label)
+ heads.append(head)
+ words.append((word_start_idx, idx+1))
+ word_start_idx = idx+1
+ word_idx += 1
+
+ head_dep_tuple = [] # head在前面
+ head_label_dep_tuple = []
+ for idx, head in enumerate(heads):
+ span = words[idx]
+ if span[0]==span[1]-1 and pun_masks[b, span[0]]:
+ continue # exclude punctuations
+ if head == 0:
+ head_dep_tuple.append((('root', words[idx])))
+ head_label_dep_tuple.append(('root', labels[idx], words[idx]))
+ else:
+ head_word_idx = ranges[head-1]
+ head_word_span = words[head_word_idx]
+ head_dep_tuple.append(((head_word_span, words[idx])))
+ head_label_dep_tuple.append((head_word_span, labels[idx], words[idx]))
+
+ gold_head_dep_tuple = set(gold_word_pairs[b])
+ gold_head_label_dep_tuple = set(gold_label_word_pairs[b])
+
+ for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple):
+ if head_dep in gold_head_dep_tuple:
+ self.parse_head_tp += 1
+ if head_label_dep in gold_head_label_dep_tuple:
+ self.parse_label_tp += 1
+ self.pre_tol += len(head_dep_tuple)
+ self.rec_tol += len(gold_head_dep_tuple)
+
+ def get_metric(self, reset=True):
+ u_p = self.parse_head_tp / self.pre_tol
+ u_r = self.parse_head_tp / self.rec_tol
+ u_f = 2*u_p*u_r/(1e-6 + u_p + u_r)
+ l_p = self.parse_label_tp / self.pre_tol
+ l_r = self.parse_label_tp / self.rec_tol
+ l_f = 2*l_p*l_r/(1e-6 + l_p + l_r)
+
+ if reset:
+ self.parse_head_tp = 0
+ self.parse_label_tp = 0
+ self.rec_tol = 0
+ self.pre_tol = 0
+
+ return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4),
+ 'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)}
+
+
+class CWSMetric(MetricBase):
+ def __init__(self, app_index):
+ super().__init__()
+ self.app_index = app_index
+ self.pre = 0
+ self.rec = 0
+ self.tp = 0
+
+ def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens):
+ """
+
+ :param seg_targets: batch_size x max_len, 每个位置预测的是该word的长度-1,在word结束的地方。
+ :param seg_masks: batch_size x max_len,只有在word结束的地方为1
+ :param label_preds: batch_size x max_len
+ :param seq_lens: batch_size
+ :return:
+ """
+
+ pred_masks = torch.zeros_like(seg_masks)
+ pred_segs = torch.zeros_like(seg_targets)
+
+ seq_lens = (seq_lens - 1).tolist()
+ for idx, label_pred in enumerate(label_preds[:, 1:].tolist()):
+ seq_len = seq_lens[idx]
+ label_pred = label_pred[:seq_len]
+ word_len = 0
+ for l_i, label in enumerate(label_pred):
+ if label==self.app_index and l_i!=len(label_pred)-1:
+ word_len += 1
+ else:
+ pred_segs[idx, l_i] = word_len # 这个词的长度为word_len
+ pred_masks[idx, l_i] = 1
+ word_len = 0
+
+ right_mask = seg_targets.eq(pred_segs) # 对长度的预测一致
+ self.rec += seg_masks.sum().item()
+ self.pre += pred_masks.sum().item()
+ # 且pred和target在同一个地方有值
+ self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item()
+
+ def get_metric(self, reset=True):
+ res = {}
+ res['rec'] = round(self.tp/(self.rec+1e-6), 4)
+ res['pre'] = round(self.tp/(self.pre+1e-6), 4)
+ res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4)
+
+ if reset:
+ self.pre = 0
+ self.rec = 0
+ self.tp = 0
+
+ return res
+
+
+class ParserMetric(MetricBase):
+ def __init__(self, ):
+ super().__init__()
+ self.num_arc = 0
+ self.num_label = 0
+ self.num_sample = 0
+
+ def get_metric(self, reset=True):
+ res = {'UAS': round(self.num_arc*1.0 / self.num_sample, 4),
+ 'LAS': round(self.num_label*1.0 / self.num_sample, 4)}
+ if reset:
+ self.num_sample = self.num_label = self.num_arc = 0
+ return res
+
+ def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None):
+ """Evaluate the performance of prediction.
+ """
+ if seq_lens is None:
+ seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.byte)
+ else:
+ seq_mask = seq_len_to_mask(seq_lens.long(), float=False)
+ # mask out tag
+ seq_mask[:, 0] = 0
+ head_pred_correct = (head_preds == heads).__and__(seq_mask)
+ label_pred_correct = (label_preds == labels).__and__(head_pred_correct)
+ self.num_arc += head_pred_correct.float().sum().item()
+ self.num_label += label_pred_correct.float().sum().item()
+ self.num_sample += seq_mask.sum().item()
+
diff --git a/reproduction/joint_cws_parse/readme.md b/reproduction/joint_cws_parse/readme.md
new file mode 100644
index 00000000..7fe77b47
--- /dev/null
+++ b/reproduction/joint_cws_parse/readme.md
@@ -0,0 +1,16 @@
+Code for paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697)
+
+### 准备数据
+1. 数据应该为conll格式,1, 3, 6, 7列应该对应为'words', 'pos_tags', 'heads', 'labels'.
+2. 将train, dev, test放在同一个folder下,并将该folder路径填入train.py中的data_folder变量里。
+3. 从[百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA)(提取:ua53)下载预训练vector,放到同一个folder下,并将train.py中vector_folder变量正确设置。
+
+
+### 运行代码
+```
+python train.py
+```
+
+### 其它
+ctb5上跑出论文中报道的结果使用以上的默认参数应该就可以了(应该会更高一些); ctb7上使用默认参数会低0.1%左右,需要调节
+learning rate scheduler.
\ No newline at end of file
diff --git a/reproduction/joint_cws_parse/train.py b/reproduction/joint_cws_parse/train.py
new file mode 100644
index 00000000..2f8b0d04
--- /dev/null
+++ b/reproduction/joint_cws_parse/train.py
@@ -0,0 +1,124 @@
+import sys
+sys.path.append('../..')
+
+from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader
+from fastNLP.modules.encoder.embedding import StaticEmbedding
+from torch import nn
+from functools import partial
+from reproduction.joint_cws_parse.models.CharParser import CharParser
+from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric
+from fastNLP import cache_results, BucketSampler, Trainer
+from torch import optim
+from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback
+from torch.optim.lr_scheduler import LambdaLR, StepLR
+from fastNLP import Tester
+from fastNLP import GradientClipCallback, LRScheduler
+import os
+
+def set_random_seed(random_seed=666):
+ import random, numpy, torch
+ random.seed(random_seed)
+ numpy.random.seed(random_seed)
+ torch.cuda.manual_seed(random_seed)
+ torch.random.manual_seed(random_seed)
+
+uniform_init = partial(nn.init.normal_, std=0.02)
+
+###################################################
+# 需要变动的超参放到这里
+lr = 0.002 # 0.01~0.001
+dropout = 0.33 # 0.3~0.6
+weight_decay = 0 # 1e-5, 1e-6, 0
+arc_mlp_size = 500 # 200, 300
+rnn_hidden_size = 400 # 200, 300, 400
+rnn_layers = 3 # 2, 3
+encoder = 'var-lstm' # var-lstm, lstm
+emb_size = 100 # 64 , 100
+label_mlp_size = 100
+
+batch_size = 32
+update_every = 4
+n_epochs = 100
+data_folder = '' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件
+vector_folder = '' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
+####################################################
+
+set_random_seed(1234)
+device = 0
+
+# @cache_results('caches/{}.pkl'.format(data_name))
+# def get_data():
+data = CTBxJointLoader().process(data_folder)
+
+char_labels_vocab = data.vocabs['char_labels']
+
+pre_chars_vocab = data.vocabs['pre_chars']
+pre_bigrams_vocab = data.vocabs['pre_bigrams']
+pre_trigrams_vocab = data.vocabs['pre_trigrams']
+
+chars_vocab = data.vocabs['chars']
+bigrams_vocab = data.vocabs['bigrams']
+trigrams_vocab = data.vocabs['trigrams']
+
+pre_chars_embed = StaticEmbedding(pre_chars_vocab,
+ model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
+ init_method=uniform_init, normalize=False)
+pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std()
+pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
+ model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
+ init_method=uniform_init, normalize=False)
+pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std()
+pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
+ model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
+ init_method=uniform_init, normalize=False)
+pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std()
+
+ # return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
+
+# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
+
+print(data)
+model = CharParser(char_vocab_size=len(chars_vocab),
+ emb_dim=emb_size,
+ bigram_vocab_size=len(bigrams_vocab),
+ trigram_vocab_size=len(trigrams_vocab),
+ num_label=len(char_labels_vocab),
+ rnn_layers=rnn_layers,
+ rnn_hidden_size=rnn_hidden_size,
+ arc_mlp_size=arc_mlp_size,
+ label_mlp_size=label_mlp_size,
+ dropout=dropout,
+ encoder=encoder,
+ use_greedy_infer=False,
+ app_index=char_labels_vocab['APP'],
+ pre_chars_embed=pre_chars_embed,
+ pre_bigrams_embed=pre_bigrams_embed,
+ pre_trigrams_embed=pre_trigrams_embed)
+
+metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP'])
+metric2 = CWSMetric(char_labels_vocab['APP'])
+metrics = [metric1, metric2]
+
+optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr,
+ weight_decay=weight_decay, betas=[0.9, 0.9])
+
+sampler = BucketSampler(seq_len_field_name='seq_lens')
+callbacks = []
+# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
+scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
+# optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
+# callbacks.append(optim_callback)
+scheduler_callback = LRScheduler(scheduler)
+callbacks.append(scheduler_callback)
+callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
+
+tester = Tester(data=data.datasets['test'], model=model, metrics=metrics,
+ batch_size=64, device=device, verbose=0)
+dev_callback = DevCallback(tester)
+callbacks.append(dev_callback)
+
+trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3,
+ validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
+ check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
+ device=device, callbacks=callbacks, update_every=update_every)
+trainer.train()
\ No newline at end of file
diff --git a/reproduction/matching/README.md b/reproduction/matching/README.md
new file mode 100644
index 00000000..056b0212
--- /dev/null
+++ b/reproduction/matching/README.md
@@ -0,0 +1,100 @@
+# Matching任务模型复现
+这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%).
+
+复现的模型有(按论文发表时间顺序排序):
+- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[]().
+论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
+- ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py).
+论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
+- DIIN:模型代码(still in progress)[](); 训练代码(still in progress)[]().
+论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf).
+- MwAN:模型代码(still in progress)[](); 训练代码(still in progress)[]().
+论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf).
+- BERT:[模型代码](model/bert.py); [训练代码](matching_bert.py).
+论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf).
+
+# 数据集及复现结果汇总
+
+使用fastNLP复现的结果vs论文汇报结果,在前面的表示使用fastNLP复现的结果
+
+'\-'表示我们仍未复现或者论文原文没有汇报
+
+model name | SNLI | MNLI | RTE | QNLI | Quora
+:---: | :---: | :---: | :---: | :---: | :---:
+CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
+ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
+DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
+MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
+BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - |
+
+*ESIM模型由MNLI官方复现的结果为72.4/72.1,ESIM原论文当中没有汇报MNLI数据集的结果。
+
+# 数据集复现结果及其他主要模型对比
+## SNLI
+[Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/)
+
+Performance on Test set:
+
+model name | ESIM | DIIN | MwAN | [GPT1.0](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) | [BERT-Large+SRL](https://arxiv.org/pdf/1809.02794.pdf) | [MT-DNN](https://arxiv.org/pdf/1901.11504.pdf)
+:---: | :---: | :---: | :---: | :---: | :---: | :---:
+__performance__ | 88.0 | 88.0 | 88.3 | 89.9 | 91.3 | 91.6 |
+
+### 基于fastNLP复现的结果
+Performance on Test set:
+
+model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
+:---: | :---: | :---: | :---: | :---: | :---: | :---:
+__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16
+
+## MNLI
+[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
+
+Performance on Test set(matched/mismatched):
+
+model name | ESIM | DIIN | MwAN | GPT1.0 | BERT-Base | MT-DNN
+:---: | :---: | :---: | :---: | :---: | :---: | :---:
+__performance__ | 72.4/72.1 | 78.8/77.8 | 78.5/77.7 | 82.1/81.4 | 84.6/83.4 | 87.9/87.4 |
+
+### 基于fastNLP复现的结果
+Performance on Test set(matched/mismatched):
+
+model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
+:---: | :---: | :---: | :---: | :---: | :---: |
+__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |
+
+
+## RTE
+
+Still in progress.
+
+## QNLI
+
+### From GLUE baselines
+[Link to GLUE leaderboard](https://gluebenchmark.com/leaderboard)
+
+Performance on Test set:
+#### LSTM-based
+model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo
+:---: | :---: | :---: | :---: | :---: |
+__performance__ | 74.6 | 74.3 | 75.5 | 79.8 |
+
+*这些LSTM-based的baseline是由QNLI官方实现并测试的。
+
+#### Transformer-based
+model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN
+:---: | :---: | :---: | :---: | :---: |
+__performance__ | 87.4 | 90.5 | 92.7 | 96.0 |
+
+
+
+### 基于fastNLP复现的结果
+Performance on __Dev__ set:
+
+model name | CNTN | ESIM | DIIN | MwAN | BERT
+:---: | :---: | :---: | :---: | :---: | :---:
+__performance__ | - | 76.97 | - | 74.6 | -
+
+## Quora
+
+Still in progress.
+
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
index 749b16c8..7c32899c 100644
--- a/reproduction/matching/data/MatchingDataLoader.py
+++ b/reproduction/matching/data/MatchingDataLoader.py
@@ -5,8 +5,8 @@ from typing import Union, Dict
from fastNLP.core.const import Const
from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
-from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader
+from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from fastNLP.modules.encoder._bert import BertTokenizer
@@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader):
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
读取Matching任务的数据集
+
+ :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
"""
def __init__(self, paths: dict=None):
- """
- :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
- """
self.paths = paths
def _load(self, path):
@@ -34,7 +33,8 @@ class MatchingLoader(DataSetLoader):
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
- cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True,
+ cut_text: int = None, get_index=True, auto_pad_length: int=None,
+ auto_pad_token: str='', set_input: Union[list, str, bool]=True,
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
"""
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
@@ -49,6 +49,8 @@ class MatchingLoader(DataSetLoader):
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。
:param bool get_index: 是否需要根据词表将文本转为index
+ :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad
+ :param str auto_pad_token: 自动pad的内容
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input,
于此同时其他field不会被设置为input。默认值为True。
@@ -169,6 +171,9 @@ class MatchingLoader(DataSetLoader):
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
+ if auto_pad_length is not None:
+ cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
+
if cut_text is not None:
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
@@ -180,7 +185,7 @@ class MatchingLoader(DataSetLoader):
assert len(data_set_list) > 0, f'There are NO data sets in data info!'
if bert_tokenizer is None:
- words_vocab = Vocabulary()
+ words_vocab = Vocabulary(padding=auto_pad_token)
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
field_name=[n for n in data_set_list[0].get_field_names()
if (Const.INPUT in n)],
@@ -202,6 +207,20 @@ class MatchingLoader(DataSetLoader):
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
is_input=auto_set_input, is_target=auto_set_target)
+ if auto_pad_length is not None:
+ if seq_len_type == 'seq_len':
+ raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
+ f'so the seq_len_type cannot be `{seq_len_type}`!')
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
+ (auto_pad_length - len(x[fields])), new_field_name=fields,
+ is_input=auto_set_input)
+ elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
+ data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
+ new_field_name=fields, is_input=auto_set_input)
+
for data_name, data_set in data_info.datasets.items():
if isinstance(set_input, list):
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
@@ -267,7 +286,7 @@ class RTELoader(MatchingLoader, CSVLoader):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
- # 'test': 'test.tsv' # test set has not label
+ 'test': 'test.tsv' # test set has not label
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
@@ -281,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
- ds.rename_field(k, v)
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -306,7 +326,7 @@ class QNLILoader(MatchingLoader, CSVLoader):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
- # 'test': 'test.tsv' # test set has not label
+ 'test': 'test.tsv' # test set has not label
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
@@ -320,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
- ds.rename_field(k, v)
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -332,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader):
"""
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
- 读取SNLI数据集,读取的DataSet包含fields::
+ 读取MNLI数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise
words2: list(str), 第二句文本, hypothesis
@@ -348,6 +369,10 @@ class MNLILoader(MatchingLoader, CSVLoader):
'dev_mismatched': 'dev_mismatched.tsv',
'test_matched': 'test_matched.tsv',
'test_mismatched': 'test_mismatched.tsv',
+ # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
+ # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+
+ # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle)
}
MatchingLoader.__init__(self, paths=paths)
CSVLoader.__init__(self, sep='\t')
@@ -364,6 +389,10 @@ class MNLILoader(MatchingLoader, CSVLoader):
if k in ds.get_field_names():
ds.rename_field(k, v)
+ if Const.TARGET in ds.get_field_names():
+ if ds[0][Const.TARGET] == 'hidden':
+ ds.delete_field(Const.TARGET)
+
parentheses_table = str.maketrans({'(': None, ')': None})
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
@@ -376,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader):
class QuoraLoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
+
+ 读取MNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
diff --git a/reproduction/matching/matching_bert.py b/reproduction/matching/matching_bert.py
new file mode 100644
index 00000000..75112d5a
--- /dev/null
+++ b/reproduction/matching/matching_bert.py
@@ -0,0 +1,102 @@
+import random
+import numpy as np
+import torch
+
+from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam
+
+from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
+ MNLILoader, QNLILoader, QuoraLoader
+from reproduction.matching.model.bert import BertForNLI
+
+
+# define hyper-parameters
+class BERTConfig:
+
+ task = 'snli'
+ batch_size_per_gpu = 6
+ n_epochs = 6
+ lr = 2e-5
+ seq_len_type = 'bert'
+ seed = 42
+ train_dataset_name = 'train'
+ dev_dataset_name = 'dev'
+ test_dataset_name = 'test'
+ save_path = None # 模型存储的位置,None表示不存储模型。
+ bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹
+
+
+arg = BERTConfig()
+
+# set random seed
+random.seed(arg.seed)
+np.random.seed(arg.seed)
+torch.manual_seed(arg.seed)
+
+n_gpu = torch.cuda.device_count()
+if n_gpu > 0:
+ torch.cuda.manual_seed_all(arg.seed)
+
+# load data set
+if arg.task == 'snli':
+ data_info = SNLILoader().process(
+ paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
+ bert_tokenizer=arg.bert_dir, cut_text=512,
+ get_index=True, concat='bert',
+ )
+elif arg.task == 'rte':
+ data_info = RTELoader().process(
+ paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
+ bert_tokenizer=arg.bert_dir, cut_text=512,
+ get_index=True, concat='bert',
+ )
+elif arg.task == 'qnli':
+ data_info = QNLILoader().process(
+ paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
+ bert_tokenizer=arg.bert_dir, cut_text=512,
+ get_index=True, concat='bert',
+ )
+elif arg.task == 'mnli':
+ data_info = MNLILoader().process(
+ paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
+ bert_tokenizer=arg.bert_dir, cut_text=512,
+ get_index=True, concat='bert',
+ )
+elif arg.task == 'quora':
+ data_info = QuoraLoader().process(
+ paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type,
+ bert_tokenizer=arg.bert_dir, cut_text=512,
+ get_index=True, concat='bert',
+ )
+else:
+ raise RuntimeError(f'NOT support {arg.task} task yet!')
+
+# define model
+model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)
+
+# define trainer
+trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
+ optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
+ batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
+ n_epochs=arg.n_epochs, print_every=-1,
+ dev_data=data_info.datasets[arg.dev_dataset_name],
+ metrics=AccuracyMetric(), metric_key='acc',
+ device=[i for i in range(torch.cuda.device_count())],
+ check_code_level=-1,
+ save_path=arg.save_path)
+
+# train model
+trainer.train(load_best_model=True)
+
+# define tester
+tester = Tester(
+ data=data_info.datasets[arg.test_dataset_name],
+ model=model,
+ metrics=AccuracyMetric(),
+ batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
+ device=[i for i in range(torch.cuda.device_count())],
+)
+
+# test model
+tester.test()
+
+
diff --git a/reproduction/matching/matching_esim.py b/reproduction/matching/matching_esim.py
index 3da6141f..d878608f 100644
--- a/reproduction/matching/matching_esim.py
+++ b/reproduction/matching/matching_esim.py
@@ -1,47 +1,103 @@
-import argparse
+import random
+import numpy as np
import torch
+from torch.optim import Adamax
+from torch.optim.lr_scheduler import StepLR
-from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
+from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
+from fastNLP.core.callback import GradientClipCallback, LRScheduler
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding
-from reproduction.matching.data.MatchingDataLoader import SNLILoader
+from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
+ MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.esim import ESIMModel
-argument = argparse.ArgumentParser()
-argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove')
-argument.add_argument('--batch-size-per-gpu', type=int, default=128)
-argument.add_argument('--n-epochs', type=int, default=100)
-argument.add_argument('--lr', type=float, default=1e-4)
-argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len')
-argument.add_argument('--save-dir', type=str, default=None)
-arg = argument.parse_args()
-bert_dirs = 'path/to/bert/dir'
+# define hyper-parameters
+class ESIMConfig:
+
+ task = 'snli'
+ embedding = 'glove'
+ batch_size_per_gpu = 196
+ n_epochs = 30
+ lr = 2e-3
+ seq_len_type = 'seq_len'
+ # seq_len表示在process的时候用len(words)来表示长度信息;
+ # mask表示用0/1掩码矩阵来表示长度信息;
+ seed = 42
+ train_dataset_name = 'train'
+ dev_dataset_name = 'dev'
+ test_dataset_name = 'test'
+ save_path = None # 模型存储的位置,None表示不存储模型。
+
+
+arg = ESIMConfig()
+
+# set random seed
+random.seed(arg.seed)
+np.random.seed(arg.seed)
+torch.manual_seed(arg.seed)
+
+n_gpu = torch.cuda.device_count()
+if n_gpu > 0:
+ torch.cuda.manual_seed_all(arg.seed)
# load data set
-data_info = SNLILoader().process(
- paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
- get_index=True, concat=False,
-)
+if arg.task == 'snli':
+ data_info = SNLILoader().process(
+ paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type,
+ get_index=True, concat=False,
+ )
+elif arg.task == 'rte':
+ data_info = RTELoader().process(
+ paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type,
+ get_index=True, concat=False,
+ )
+elif arg.task == 'qnli':
+ data_info = QNLILoader().process(
+ paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
+ get_index=True, concat=False,
+ )
+elif arg.task == 'mnli':
+ data_info = MNLILoader().process(
+ paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
+ get_index=True, concat=False,
+ )
+elif arg.task == 'quora':
+ data_info = QuoraLoader().process(
+ paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type,
+ get_index=True, concat=False,
+ )
+else:
+ raise RuntimeError(f'NOT support {arg.task} task yet!')
# load embedding
if arg.embedding == 'elmo':
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
elif arg.embedding == 'glove':
- embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
+ embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False)
else:
- raise ValueError(f'now we only support elmo or glove embedding for esim model!')
+ raise RuntimeError(f'NOT support {arg.embedding} embedding yet!')
# define model
-model = ESIMModel(embedding)
+model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET]))
+
+# define optimizer and callback
+optimizer = Adamax(lr=arg.lr, params=model.parameters())
+scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍
+
+callbacks = [
+ GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10)
+ LRScheduler(scheduler),
+]
# define trainer
-trainer = Trainer(train_data=data_info.datasets['train'], model=model,
- optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
+trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
+ optimizer=optimizer,
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1,
- dev_data=data_info.datasets['dev'],
+ dev_data=data_info.datasets[arg.dev_dataset_name],
metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1,
@@ -52,7 +108,7 @@ trainer.train(load_best_model=True)
# define tester
tester = Tester(
- data=data_info.datasets['test'],
+ data=data_info.datasets[arg.test_dataset_name],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py
index d55034e7..187e565d 100644
--- a/reproduction/matching/model/esim.py
+++ b/reproduction/matching/model/esim.py
@@ -81,6 +81,7 @@ class ESIMModel(BaseModel):
out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H]
logits = torch.tanh(self.classifier(out))
+ # logits = self.classifier(out)
if target is not None:
loss_fct = CrossEntropyLoss()
@@ -91,7 +92,8 @@ class ESIMModel(BaseModel):
return {Const.OUTPUT: logits}
def predict(self, **kwargs):
- return self.forward(**kwargs)
+ pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1)
+ return {Const.OUTPUT: pred}
# input [batch_size, len , hidden]
# mask [batch_size, len] (111...00)
@@ -127,7 +129,7 @@ class BiRNN(nn.Module):
def forward(self, x, x_mask):
# Sort x
- lengths = x_mask.data.eq(1).long().sum(1).squeeze()
+ lengths = x_mask.data.eq(1).long().sum(1)
_, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)
lengths = list(lengths[idx_sort])
diff --git a/reproduction/seqence_labelling/ner/README.md b/reproduction/seqence_labelling/ner/README.md
new file mode 100644
index 00000000..d42046b0
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/README.md
@@ -0,0 +1,13 @@
+# NER任务模型复现
+这里使用fastNLP复现经典的BiLSTM-CNN的NER任务的模型,旨在达到与论文中相符的性能。
+
+论文链接[Named Entity Recognition with Bidirectional LSTM-CNNs](https://arxiv.org/pdf/1511.08308.pdf)
+
+# 数据集及复现结果汇总
+
+使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道)
+
+model name | Conll2003 | Ontonotes
+:---: | :---: | :---:
+BiLSTM-CNN | 91.17/90.91 | 86.47/86.35 |
+
diff --git a/reproduction/utils.py b/reproduction/utils.py
index 26b2014c..4f0d021e 100644
--- a/reproduction/utils.py
+++ b/reproduction/utils.py
@@ -13,22 +13,30 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
}
如果paths为不合法的,将直接进行raise相应的错误
- :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt,
- test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
+ :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
+ 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
if isinstance(paths, str):
if os.path.isfile(paths):
return {'train': paths}
elif os.path.isdir(paths):
- train_fp = os.path.join(paths, 'train.txt')
- if not os.path.isfile(train_fp):
- raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
- files = {'train': train_fp}
- for filename in ['dev.txt', 'test.txt']:
- fp = os.path.join(paths, filename)
- if os.path.isfile(fp):
- files[filename.split('.')[0]] = fp
+ filenames = os.listdir(paths)
+ files = {}
+ for filename in filenames:
+ path_pair = None
+ if 'train' in filename:
+ path_pair = ('train', filename)
+ if 'dev' in filename:
+ if path_pair:
+ raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
+ path_pair = ('dev', filename)
+ if 'test' in filename:
+ if path_pair:
+ raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
+ path_pair = ('test', filename)
+ if path_pair:
+ files[path_pair[0]] = os.path.join(paths, path_pair[1])
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
diff --git a/requirements.txt b/requirements.txt
index 7ea8fdac..f8f7a951 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-numpy
-torch>=0.4.0
-tqdm
-nltk
+numpy>=1.14.2
+torch>=1.0.0
+tqdm>=4.28.1
+nltk>=3.4.1
requests
diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py
index 7cff3c12..09ad8c83 100644
--- a/test/io/test_dataset_loader.py
+++ b/test/io/test_dataset_loader.py
@@ -1,7 +1,7 @@
import unittest
import os
-from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader
-from fastNLP.io.dataset_loader import SSTLoader
+from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader
+from fastNLP.io.data_loader import SSTLoader, SNLILoader
from reproduction.text_classification.data.yelpLoader import yelpLoader
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase):
print(info.vocabs)
print(info.datasets)
os.remove(train), os.remove(test)
+
+ def test_import(self):
+ import fastNLP
+ from fastNLP.io import SNLILoader
+ ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
+ get_index=True, seq_len_type='seq_len')
+ assert 'train' in ds.datasets
+ assert len(ds.datasets) == 1
+ assert len(ds.datasets['train']) == 3
diff --git a/test/models/test_bert.py b/test/models/test_bert.py
index 7177f31b..38a16f9b 100644
--- a/test/models/test_bert.py
+++ b/test/models/test_bert.py
@@ -8,8 +8,9 @@ from fastNLP.models.bert import *
class TestBert(unittest.TestCase):
def test_bert_1(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForSequenceClassification(2)
+ model = BertForSequenceClassification(2, BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -22,8 +23,9 @@ class TestBert(unittest.TestCase):
def test_bert_2(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForMultipleChoice(2)
+ model = BertForMultipleChoice(2, BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -36,8 +38,9 @@ class TestBert(unittest.TestCase):
def test_bert_3(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForTokenClassification(7)
+ model = BertForTokenClassification(7, BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -50,8 +53,9 @@ class TestBert(unittest.TestCase):
def test_bert_4(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForQuestionAnswering()
+ model = BertForQuestionAnswering(BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
diff --git a/test/modules/encoder/test_bert.py b/test/modules/encoder/test_bert.py
index 78bcf633..2a799478 100644
--- a/test/modules/encoder/test_bert.py
+++ b/test/modules/encoder/test_bert.py
@@ -8,8 +8,9 @@ from fastNLP.models.bert import BertModel
class TestBert(unittest.TestCase):
def test_bert_1(self):
- model = BertModel(vocab_size=32000, hidden_size=768,
- num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+ from fastNLP.modules.encoder._bert import BertConfig
+ config = BertConfig(32000)
+ model = BertModel(config)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -18,4 +19,4 @@ class TestBert(unittest.TestCase):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
for layer in all_encoder_layers:
self.assertEqual(tuple(layer.shape), (2, 3, 768))
- self.assertEqual(tuple(pooled_output.shape), (2, 768))
\ No newline at end of file
+ self.assertEqual(tuple(pooled_output.shape), (2, 768))
diff --git a/test/test_tutorials.py b/test/test_tutorials.py
index 87910c3d..6f4a8347 100644
--- a/test/test_tutorials.py
+++ b/test/test_tutorials.py
@@ -79,7 +79,7 @@ class TestTutorial(unittest.TestCase):
train_data.rename_field('label', 'label_seq')
test_data.rename_field('label', 'label_seq')
- loss = CrossEntropyLoss(pred="output", target="label_seq")
+ loss = CrossEntropyLoss(target="label_seq")
metric = AccuracyMetric(target="label_seq")
# 实例化Trainer,传入模型和数据,进行训练
@@ -91,7 +91,7 @@ class TestTutorial(unittest.TestCase):
# 用train_data训练,在test_data验证
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
- loss=CrossEntropyLoss(pred="output", target="label_seq"),
+ loss=CrossEntropyLoss(target="label_seq"),
metrics=AccuracyMetric(target="label_seq"),
save_path=None,
batch_size=32,