@@ -6,48 +6,59 @@ | |||||
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg) | ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg) | ||||
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest) | [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest) | ||||
fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个命名实体识别(NER)、中文分词或文本分类任务; 也可以使用他构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||||
fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner)、POS-Tagging等)、中文分词、[文本分类](reproduction/text_classification)、[Matching](reproduction/matching)、[指代消解](reproduction/coreference_resolution)、[摘要](reproduction/Summarization)等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码。 | |||||
- 各种方便的NLP工具,例如预处理embedding加载; 中间数据cache等; | |||||
- 详尽的中文文档以供查阅; | |||||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码; | |||||
- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等; | |||||
- 各种方便的NLP工具,例如预处理embedding加载(包括ELMo和BERT); 中间数据cache等; | |||||
- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、[教程](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)以供查阅; | |||||
- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等; | - 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等; | ||||
- 封装CNNText,Biaffine等模型可供直接使用; | |||||
- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用,详细内容见 [reproduction](reproduction) 部分; | |||||
- 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。 | - 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。 | ||||
## 安装指南 | ## 安装指南 | ||||
fastNLP 依赖如下包: | |||||
fastNLP 依赖以下包: | |||||
+ numpy | |||||
+ torch>=0.4.0 | |||||
+ tqdm | |||||
+ nltk | |||||
+ numpy>=1.14.2 | |||||
+ torch>=1.0.0 | |||||
+ tqdm>=4.28.1 | |||||
+ nltk>=3.4.1 | |||||
+ requests | |||||
+ spacy | |||||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 PyTorch 官网 。 | |||||
在依赖包安装完成的情况,您可以在命令行执行如下指令完成安装 | |||||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。 | |||||
在依赖包安装完成后,您可以在命令行执行如下指令完成安装 | |||||
```shell | ```shell | ||||
pip install fastNLP | pip install fastNLP | ||||
python -m spacy download en | |||||
``` | ``` | ||||
## 参考资源 | |||||
## fastNLP教程 | |||||
- [文档](https://fastnlp.readthedocs.io/zh/latest/) | |||||
- [源码](https://github.com/fastnlp/fastNLP) | |||||
- [1. 使用DataSet预处理文本](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html) | |||||
- [2. 使用DataSetLoader加载数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_load_dataset.html) | |||||
- [3. 使用Embedding模块将文本转成向量](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) | |||||
- [4. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_loss_optimizer.html) | |||||
- [5. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_datasetiter.html) | |||||
- [6. 快速实现序列标注模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_seq_labeling.html) | |||||
- [7. 使用Modules和Models快速搭建自定义模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_modules_models.html) | |||||
- [8. 使用Metric快速评测你的模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_metrics.html) | |||||
- [9. 使用Callback自定义你的训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_callback.html) | |||||
## 内置组件 | ## 内置组件 | ||||
大部分用于的 NLP 任务神经网络都可以看做由编码(encoder)、聚合(aggregator)、解码(decoder)三种模块组成。 | |||||
大部分用于 NLP 任务的神经网络都可以看做由编码器(encoder)、解码器(decoder)两种模块组成。
![](./docs/source/figures/text_classification.png) | ![](./docs/source/figures/text_classification.png) | ||||
fastNLP 在 modules 模块中内置了三种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 三种模块的功能和常见组件如下: | |||||
fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。两种模块的功能和常见组件如下(表格后附一个简单的组装示例):
<table> | <table> | ||||
<tr> | <tr> | ||||
@@ -57,29 +68,17 @@ fastNLP 在 modules 模块中内置了三种模块的诸多组件,可以帮助 | |||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td> encoder </td> | <td> encoder </td> | ||||
<td> 将输入编码为具有具 有表示能力的向量 </td> | |||||
<td> 将输入编码为具有表示能力的向量 </td>
<td> embedding, RNN, CNN, transformer | <td> embedding, RNN, CNN, transformer | ||||
</tr> | </tr> | ||||
<tr> | |||||
<td> aggregator </td> | |||||
<td> 从多个向量中聚合信息 </td> | |||||
<td> self-attention, max-pooling </td> | |||||
</tr> | |||||
<tr> | <tr> | ||||
<td> decoder </td> | <td> decoder </td> | ||||
<td> 将具有某种表示意义的 向量解码为需要的输出 形式 </td> | |||||
<td> 将具有某种表示意义的向量解码为需要的输出形式 </td>
<td> MLP, CRF </td> | <td> MLP, CRF </td> | ||||
</tr> | </tr> | ||||
</table> | </table> | ||||
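下面用一个极简的文本分类模型示意这两类模块的分工(模型结构、变量名与超参数均为示意写法,并非 fastNLP 固定的用法):

```python
import torch
from torch import nn

class TinyTextClassifier(nn.Module):
    """示意模型:embedding + LSTM 充当 encoder,线性层(MLP)充当 decoder。"""

    def __init__(self, vocab_size, embed_dim=100, hidden_dim=128, num_classes=2):
        super().__init__()
        # encoder:将输入编码为具有表示能力的向量(对应表格中的 embedding / RNN)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # decoder:将句子表示解码为需要的输出形式(对应表格中的 MLP)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, words):
        x = self.embed(words)            # [batch, seq_len, embed_dim]
        _, (h, _) = self.lstm(x)         # 取最后一层最后时刻的 hidden state 作为句子表示
        return {'pred': self.fc(h[-1])}  # fastNLP 约定 forward 返回 dict

model = TinyTextClassifier(vocab_size=1000)
print(model(torch.randint(0, 1000, (4, 20)))['pred'].shape)  # torch.Size([4, 2])
```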
## 完整模型 | |||||
fastNLP 为不同的 NLP 任务实现了许多完整的模型,它们都经过了训练和测试。 | |||||
你可以在以下两个地方查看相关信息 | |||||
- [介绍](reproduction/) | |||||
- [源码](fastNLP/models/) | |||||
## 项目结构 | ## 项目结构 | ||||
![](./docs/source/figures/workflow.png) | ![](./docs/source/figures/workflow.png) | ||||
@@ -93,7 +92,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下: | |||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.core </b></td> | <td><b> fastNLP.core </b></td> | ||||
<td> 实现了核心功能,包括数据处理组件、训练器、测速器等 </td> | |||||
<td> 实现了核心功能,包括数据处理组件、训练器、测试器等 </td> | |||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.models </b></td> | <td><b> fastNLP.models </b></td> | ||||
@@ -37,7 +37,7 @@ __all__ = [ | |||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"SQuADMetric", | |||||
"ExtractiveQAMetric", | |||||
"Optimizer", | "Optimizer", | ||||
"SGD", | "SGD", | ||||
@@ -61,3 +61,4 @@ __version__ = '0.4.0' | |||||
from .core import * | from .core import * | ||||
from . import models | from . import models | ||||
from . import modules | from . import modules | ||||
from .io import data_loader |
@@ -21,7 +21,7 @@ from .dataset import DataSet | |||||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
from .instance import Instance | from .instance import Instance | ||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | ||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric | |||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric | |||||
from .optimizer import Optimizer, SGD, Adam | from .optimizer import Optimizer, SGD, Adam | ||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | ||||
from .tester import Tester | from .tester import Tester | ||||
@@ -3,7 +3,6 @@ batch 模块实现了 fastNLP 所需的 Batch 类。 | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"BatchIter", | |||||
"DataSetIter", | "DataSetIter", | ||||
"TorchLoaderIter", | "TorchLoaderIter", | ||||
] | ] | ||||
@@ -50,6 +49,7 @@ class DataSetGetter: | |||||
return len(self.dataset) | return len(self.dataset) | ||||
def collate_fn(self, batch: list): | def collate_fn(self, batch: list): | ||||
# TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 | |||||
batch_x = {n:[] for n in self.inputs.keys()} | batch_x = {n:[] for n in self.inputs.keys()} | ||||
batch_y = {n:[] for n in self.targets.keys()} | batch_y = {n:[] for n in self.targets.keys()} | ||||
indices = [] | indices = [] | ||||
@@ -136,6 +136,31 @@ class BatchIter: | |||||
class DataSetIter(BatchIter): | class DataSetIter(BatchIter): | ||||
""" | |||||
别名::class:`fastNLP.DataSetIter` :class:`fastNLP.core.batch.DataSetIter` | |||||
DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | |||||
组成 `x` 和 `y`:: | |||||
batch = DataSetIter(data_set, batch_size=16, sampler=SequentialSampler()) | |||||
num_batch = len(batch) | |||||
for batch_x, batch_y in batch: | |||||
# do stuff ... | |||||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||||
:param int batch_size: 取出的batch大小 | |||||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | |||||
Default: ``None`` | |||||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||||
Default: ``False`` | |||||
:param int num_workers: 使用多少个进程来预处理数据 | |||||
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 | |||||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||||
:param timeout: 生成一个batch的timeout值(秒),0表示不设置超时。
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||||
""" | |||||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | ||||
num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
timeout=0, worker_init_fn=None): | timeout=0, worker_init_fn=None): | ||||
@@ -66,6 +66,8 @@ import os | |||||
import torch | import torch | ||||
from copy import deepcopy | from copy import deepcopy | ||||
import sys | |||||
from .utils import _save_model | |||||
try: | try: | ||||
from tensorboardX import SummaryWriter | from tensorboardX import SummaryWriter | ||||
@@ -399,10 +401,11 @@ class GradientClipCallback(Callback): | |||||
self.clip_value = clip_value | self.clip_value = clip_value | ||||
def on_backward_end(self): | def on_backward_end(self): | ||||
if self.parameters is None: | |||||
self.clip_fun(self.model.parameters(), self.clip_value) | |||||
else: | |||||
self.clip_fun(self.parameters, self.clip_value) | |||||
if self.step%self.update_every==0: | |||||
if self.parameters is None: | |||||
self.clip_fun(self.model.parameters(), self.clip_value) | |||||
else: | |||||
self.clip_fun(self.parameters, self.clip_value) | |||||
class EarlyStopCallback(Callback): | class EarlyStopCallback(Callback): | ||||
@@ -736,6 +739,132 @@ class TensorboardCallback(Callback): | |||||
del self._summary_writer | del self._summary_writer | ||||
class WarmupCallback(Callback): | |||||
""" | |||||
在训练过程中按一定的策略调整 learning rate 的大小(warmup)。
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||||
:param str schedule: 以哪种方式调整。linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后 | |||||
warmup的step下降到0; constant: 前warmup的step上升到指定learning rate,后面的step保持该learning rate不变.
""" | |||||
def __init__(self, warmup=0.1, schedule='constant'): | |||||
super().__init__() | |||||
self.warmup = max(warmup, 0.) | |||||
self.initial_lrs = [] # 存放param_group的learning rate | |||||
if schedule == 'constant': | |||||
self.get_lr = self._get_constant_lr | |||||
elif schedule == 'linear': | |||||
self.get_lr = self._get_linear_lr | |||||
else: | |||||
raise RuntimeError("Only support 'linear', 'constant'.") | |||||
def _get_constant_lr(self, progress): | |||||
if progress<self.warmup: | |||||
return progress/self.warmup | |||||
return 1 | |||||
def _get_linear_lr(self, progress): | |||||
if progress<self.warmup: | |||||
return progress/self.warmup | |||||
return max((progress - 1.) / (self.warmup - 1.), 0.) | |||||
def on_train_begin(self): | |||||
self.t_steps = (len(self.trainer.train_data) // (self.batch_size*self.update_every) + | |||||
int(len(self.trainer.train_data) % (self.batch_size*self.update_every)!= 0)) * self.n_epochs | |||||
if self.warmup>1: | |||||
self.warmup = self.warmup/self.t_steps | |||||
self.t_steps = max(2, self.t_steps) # 不能小于2 | |||||
# 获取param_group的初始learning rate | |||||
for group in self.optimizer.param_groups: | |||||
self.initial_lrs.append(group['lr']) | |||||
def on_backward_end(self): | |||||
if self.step%self.update_every==0: | |||||
progress = (self.step/self.update_every)/self.t_steps | |||||
for lr, group in zip(self.initial_lrs, self.optimizer.param_groups): | |||||
group['lr'] = lr * self.get_lr(progress) | |||||
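WarmupCallback 的一个示意用法与 linear 策略对应的乘数计算如下(数字仅为示例;callback 需通过 Trainer 的 callbacks 参数传入才会生效):

```python
from fastNLP.core.callback import WarmupCallback

# 前10%的step内学习率从0线性升到optimizer中设定的lr,之后再线性降回0
warmup_cb = WarmupCallback(warmup=0.1, schedule='linear')

# linear 策略对应的学习率乘数(progress 为已完成step占总step的比例)
def linear_multiplier(progress, warmup=0.1):
    if progress < warmup:
        return progress / warmup
    return max((progress - 1.) / (warmup - 1.), 0.)

print(linear_multiplier(0.05))  # 0.5,处于升温阶段
print(linear_multiplier(0.55))  # 0.5,处于降温阶段
```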
class SaveModelCallback(Callback): | |||||
""" | |||||
由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | |||||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型 | |||||
-save_dir | |||||
-2019-07-03-15-06-36 | |||||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能 | |||||
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt | |||||
-2019-07-03-15-10-00 | |||||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | |||||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | |||||
:param bool only_param: 是否只保存模型的权重。
:param save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}. | |||||
""" | |||||
def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False): | |||||
super().__init__() | |||||
if not os.path.isdir(save_dir): | |||||
raise NotADirectoryError("{} is not a directory.".format(save_dir))
self.save_dir = save_dir | |||||
if top < 0: | |||||
self.top = sys.maxsize | |||||
else: | |||||
self.top = top | |||||
self._ordered_save_models = [] # List[Tuple], Tuple[0]是metric, Tuple[1]是path。metric是依次变好的,所以从头删 | |||||
self.only_param = only_param | |||||
self.save_on_exception = save_on_exception | |||||
def on_train_begin(self): | |||||
self.save_dir = os.path.join(self.save_dir, self.trainer.start_time) | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
metric_value = list(eval_result.values())[0][metric_key] | |||||
self._save_this_model(metric_value) | |||||
def _insert_into_ordered_save_models(self, pair): | |||||
# pair:(metric_value, model_name) | |||||
# 返回save的模型pair与删除的模型pair. pair中第一个元素是metric的值,第二个元素是模型的名称 | |||||
index = -1 | |||||
for _pair in self._ordered_save_models: | |||||
if _pair[0]>=pair[0] and self.trainer.increase_better: | |||||
break | |||||
if not self.trainer.increase_better and _pair[0]<=pair[0]: | |||||
break | |||||
index += 1 | |||||
save_pair = None | |||||
if len(self._ordered_save_models)<self.top or (len(self._ordered_save_models)>=self.top and index!=-1): | |||||
save_pair = pair | |||||
self._ordered_save_models.insert(index+1, pair) | |||||
delete_pair = None | |||||
if len(self._ordered_save_models)>self.top: | |||||
delete_pair = self._ordered_save_models.pop(0) | |||||
return save_pair, delete_pair | |||||
def _save_this_model(self, metric_value): | |||||
name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) | |||||
save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name)) | |||||
if save_pair: | |||||
try: | |||||
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | |||||
except Exception as e: | |||||
print(f"The following exception:{e} happens when save model to {self.save_dir}.") | |||||
if delete_pair: | |||||
try: | |||||
delete_model_path = os.path.join(self.save_dir, delete_pair[1]) | |||||
if os.path.exists(delete_model_path): | |||||
os.remove(delete_model_path) | |||||
except Exception as e: | |||||
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.") | |||||
def on_exception(self, exception): | |||||
if self.save_on_exception: | |||||
name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__) | |||||
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | |||||
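SaveModelCallback 的一个示意构造方式如下(目录名仅为示例;同样需要通过 Trainer 的 callbacks 参数传入):

```python
import os
from fastNLP.core.callback import SaveModelCallback

os.makedirs('./save_models', exist_ok=True)  # save_dir 必须事先存在,否则构造时会抛出异常
save_cb = SaveModelCallback('./save_models', top=3, save_on_exception=True)
# 传入 Trainer(..., callbacks=[save_cb]) 后,每次validation结束会按metric保留表现最好的3个模型
```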
class CallbackException(BaseException): | class CallbackException(BaseException): | ||||
""" | """ | ||||
当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 | 当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 | ||||
@@ -20,6 +20,7 @@ from collections import defaultdict | |||||
import torch | import torch | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from ..core.const import Const | |||||
from .utils import _CheckError | from .utils import _CheckError | ||||
from .utils import _CheckRes | from .utils import _CheckRes | ||||
from .utils import _build_args | from .utils import _build_args | ||||
@@ -28,6 +29,7 @@ from .utils import _check_function_or_method | |||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .utils import seq_len_to_mask | from .utils import seq_len_to_mask | ||||
class LossBase(object): | class LossBase(object): | ||||
""" | """ | ||||
所有loss的基类。如果想了解其中的原理,请查看源码。 | 所有loss的基类。如果想了解其中的原理,请查看源码。 | ||||
@@ -95,22 +97,7 @@ class LossBase(object): | |||||
# if func_spect.varargs: | # if func_spect.varargs: | ||||
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | ||||
# f"positional argument.).") | # f"positional argument.).") | ||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||||
such as pred_dict has one element, target_dict has one element | |||||
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||||
""" | |||||
fast_param = {} | |||||
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict, check=False): | def __call__(self, pred_dict, target_dict, check=False): | ||||
""" | """ | ||||
:param dict pred_dict: 模型的forward函数返回的dict | :param dict pred_dict: 模型的forward函数返回的dict | ||||
@@ -118,11 +105,7 @@ class LossBase(object): | |||||
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | :param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | ||||
:return: | :return: | ||||
""" | """ | ||||
fast_param = self._fast_param_map(pred_dict, target_dict) | |||||
if fast_param: | |||||
loss = self.get_loss(**fast_param) | |||||
return loss | |||||
if not self._checked: | if not self._checked: | ||||
# 1. check consistence between signature and _param_map | # 1. check consistence between signature and _param_map | ||||
func_spect = inspect.getfullargspec(self.get_loss) | func_spect = inspect.getfullargspec(self.get_loss) | ||||
@@ -212,7 +195,6 @@ class LossFunc(LossBase): | |||||
if not isinstance(key_map, dict): | if not isinstance(key_map, dict): | ||||
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | ||||
self._init_param_map(key_map, **kwargs) | self._init_param_map(key_map, **kwargs) | ||||
class CrossEntropyLoss(LossBase): | class CrossEntropyLoss(LossBase): | ||||
@@ -226,7 +208,7 @@ class CrossEntropyLoss(LossBase): | |||||
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | :param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | ||||
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | ||||
传入seq_len. | 传入seq_len. | ||||
:param str reduction: 支持'elementwise_mean'和'sum'. | |||||
:param str reduction: 支持'mean','sum'和'none'. | |||||
Example:: | Example:: | ||||
@@ -234,16 +216,16 @@ class CrossEntropyLoss(LossBase): | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='elementwise_mean'): | |||||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'): | |||||
super(CrossEntropyLoss, self).__init__() | super(CrossEntropyLoss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | self._init_param_map(pred=pred, target=target, seq_len=seq_len) | ||||
self.padding_idx = padding_idx | self.padding_idx = padding_idx | ||||
assert reduction in ('elementwise_mean', 'sum') | |||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | self.reduction = reduction | ||||
def get_loss(self, pred, target, seq_len=None): | def get_loss(self, pred, target, seq_len=None): | ||||
if pred.dim()>2: | |||||
if pred.size(1)!=target.size(1): | |||||
if pred.dim() > 2: | |||||
if pred.size(1) != target.size(1): | |||||
pred = pred.transpose(1, 2) | pred = pred.transpose(1, 2) | ||||
pred = pred.reshape(-1, pred.size(-1)) | pred = pred.reshape(-1, pred.size(-1)) | ||||
target = target.reshape(-1) | target = target.reshape(-1) | ||||
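一个直接调用该 loss 的小例子(数据为随机张量,仅作演示;在 Trainer 中通常只需把该对象作为 loss 参数传入,由框架完成字段映射):

```python
import torch
from fastNLP import CrossEntropyLoss

loss_fn = CrossEntropyLoss(pred='pred', target='target', reduction='mean')
pred = torch.randn(4, 10)              # [batch, num_classes]
target = torch.randint(0, 10, (4,))    # [batch]
print(loss_fn.get_loss(pred, target))  # 标量;reduction='sum'/'none'时分别返回求和/逐样本的loss
```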
@@ -263,15 +245,18 @@ class L1Loss(LossBase): | |||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target` | :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target` | ||||
:param str reduction: 支持'mean','sum'和'none'. | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None): | |||||
def __init__(self, pred=None, target=None, reduction='mean'): | |||||
super(L1Loss, self).__init__() | super(L1Loss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target) | self._init_param_map(pred=pred, target=target) | ||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | |||||
def get_loss(self, pred, target): | def get_loss(self, pred, target): | ||||
return F.l1_loss(input=pred, target=target) | |||||
return F.l1_loss(input=pred, target=target, reduction=self.reduction) | |||||
class BCELoss(LossBase): | class BCELoss(LossBase): | ||||
@@ -282,14 +267,17 @@ class BCELoss(LossBase): | |||||
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | ||||
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | ||||
:param str reduction: 支持'mean','sum'和'none'. | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None): | |||||
def __init__(self, pred=None, target=None, reduction='mean'): | |||||
super(BCELoss, self).__init__() | super(BCELoss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target) | self._init_param_map(pred=pred, target=target) | ||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | |||||
def get_loss(self, pred, target): | def get_loss(self, pred, target): | ||||
return F.binary_cross_entropy(input=pred, target=target) | |||||
return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction) | |||||
class NLLLoss(LossBase): | class NLLLoss(LossBase): | ||||
@@ -300,14 +288,20 @@ class NLLLoss(LossBase): | |||||
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` | ||||
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` | ||||
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容。
:param str reduction: 支持'mean','sum'和'none'. | |||||
""" | """ | ||||
def __init__(self, pred=None, target=None): | |||||
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | |||||
super(NLLLoss, self).__init__() | super(NLLLoss, self).__init__() | ||||
self._init_param_map(pred=pred, target=target) | self._init_param_map(pred=pred, target=target) | ||||
assert reduction in ('mean', 'sum', 'none') | |||||
self.reduction = reduction | |||||
self.ignore_idx = ignore_idx | |||||
def get_loss(self, pred, target): | def get_loss(self, pred, target): | ||||
return F.nll_loss(input=pred, target=target) | |||||
return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction) | |||||
class LossInForward(LossBase): | class LossInForward(LossBase): | ||||
@@ -319,7 +313,7 @@ class LossInForward(LossBase): | |||||
:param str loss_key: 在forward函数中loss的键名,默认为loss | :param str loss_key: 在forward函数中loss的键名,默认为loss | ||||
""" | """ | ||||
def __init__(self, loss_key='loss'): | |||||
def __init__(self, loss_key=Const.LOSS): | |||||
super().__init__() | super().__init__() | ||||
if not isinstance(loss_key, str): | if not isinstance(loss_key, str): | ||||
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | ||||
@@ -6,7 +6,7 @@ __all__ = [ | |||||
"MetricBase", | "MetricBase", | ||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"SQuADMetric" | |||||
"ExtractiveQAMetric" | |||||
] | ] | ||||
import inspect | import inspect | ||||
@@ -24,6 +24,7 @@ from .utils import seq_len_to_mask | |||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
from abc import abstractmethod | from abc import abstractmethod | ||||
class MetricBase(object): | class MetricBase(object): | ||||
""" | """ | ||||
所有metrics的基类,所有传入到Trainer、Tester的Metric需要继承自该对象,并需要覆写evaluate(), get_metric()方法。
@@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1): | |||||
return y_pred_topk, y_prob_topk | return y_pred_topk, y_prob_topk | ||||
class SQuADMetric(MetricBase): | |||||
class ExtractiveQAMetric(MetricBase): | |||||
r""" | r""" | ||||
别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | |||||
别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` | |||||
SQuAD数据集metric | |||||
抽取式QA(如SQuAD)的metric. | |||||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | :param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | ||||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | :param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | ||||
@@ -755,7 +756,7 @@ class SQuADMetric(MetricBase): | |||||
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | ||||
beta=1, right_open=True, print_predict_stat=False): | beta=1, right_open=True, print_predict_stat=False): | ||||
super(SQuADMetric, self).__init__() | |||||
super(ExtractiveQAMetric, self).__init__() | |||||
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | ||||
@@ -16,6 +16,7 @@ from collections import Counter, namedtuple | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from typing import List | |||||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs']) | 'varargs']) | ||||
@@ -162,6 +163,30 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
return wrapper_ | return wrapper_ | ||||
def _save_model(model, model_name, save_dir, only_param=False): | |||||
""" 存储不含有显卡信息的state_dict或model | |||||
:param model: 需要保存的模型
:param model_name: 保存的文件名
:param save_dir: 保存的directory
:param only_param: 是否只保存模型的state_dict(权重);为False时保存整个模型对象
:return: | |||||
""" | |||||
model_path = os.path.join(save_dir, model_name) | |||||
if not os.path.isdir(save_dir): | |||||
os.makedirs(save_dir, exist_ok=True) | |||||
if isinstance(model, nn.DataParallel): | |||||
model = model.module | |||||
if only_param: | |||||
state_dict = model.state_dict() | |||||
for key in state_dict: | |||||
state_dict[key] = state_dict[key].cpu() | |||||
torch.save(state_dict, model_path) | |||||
else: | |||||
_model_device = _get_model_device(model) | |||||
model.cpu() | |||||
torch.save(model, model_path) | |||||
model.to(_model_device) | |||||
# def save_pickle(obj, pickle_path, file_name): | # def save_pickle(obj, pickle_path, file_name): | ||||
# """Save an object into a pickle file. | # """Save an object into a pickle file. | ||||
@@ -277,7 +302,6 @@ def _move_model_to_device(model, device): | |||||
return model | return model | ||||
def _get_model_device(model): | def _get_model_device(model): | ||||
""" | """ | ||||
传入一个nn.Module的模型,获取它所在的device | 传入一个nn.Module的模型,获取它所在的device | ||||
@@ -285,7 +309,7 @@ def _get_model_device(model): | |||||
:param model: nn.Module | :param model: nn.Module | ||||
:return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 | :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 | ||||
""" | """ | ||||
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding | |||||
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 | |||||
assert isinstance(model, nn.Module) | assert isinstance(model, nn.Module) | ||||
parameters = list(model.parameters()) | parameters = list(model.parameters()) | ||||
@@ -712,3 +736,52 @@ class _pseudo_tqdm: | |||||
def __exit__(self, exc_type, exc_val, exc_tb): | def __exit__(self, exc_type, exc_val, exc_tb): | ||||
del self | del self | ||||
def iob2(tags:List[str])->List[str]: | |||||
""" | |||||
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见 | |||||
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | |||||
:param tags: 需要转换的tags, 需要为大写的BIO标签。 | |||||
""" | |||||
for i, tag in enumerate(tags): | |||||
if tag == "O": | |||||
continue | |||||
split = tag.split("-") | |||||
if len(split) != 2 or split[0] not in ["I", "B"]: | |||||
raise TypeError("The encoding schema is not a valid IOB type.") | |||||
if split[0] == "B": | |||||
continue | |||||
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 | |||||
tags[i] = "B" + tag[1:] | |||||
elif tags[i - 1][1:] == tag[1:]: | |||||
continue | |||||
else: # conversion IOB1 to IOB2 | |||||
tags[i] = "B" + tag[1:] | |||||
return tags | |||||
def iob2bioes(tags:List[str])->List[str]: | |||||
""" | |||||
将iob的tag转换为bioes编码 | |||||
:param tags: List[str]. 编码需要是大写的。 | |||||
:return: | |||||
""" | |||||
new_tags = [] | |||||
for i, tag in enumerate(tags): | |||||
if tag == 'O': | |||||
new_tags.append(tag) | |||||
else: | |||||
split = tag.split('-')[0] | |||||
if split == 'B': | |||||
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': | |||||
new_tags.append(tag) | |||||
else: | |||||
new_tags.append(tag.replace('B-', 'S-')) | |||||
elif split == 'I': | |||||
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I': | |||||
new_tags.append(tag) | |||||
else: | |||||
new_tags.append(tag.replace('I-', 'E-')) | |||||
else: | |||||
raise TypeError("Invalid IOB format.") | |||||
return new_tags |
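两个函数的行为可以用下面的小例子说明(标签序列仅为示例;这里假设直接从 fastNLP.core.utils 导入):

```python
from fastNLP.core.utils import iob2, iob2bioes

print(iob2(["I-ORG", "O", "I-PER", "I-PER"]))
# ['B-ORG', 'O', 'B-PER', 'I-PER'],IOB1 被自动转换为 IOB2
print(iob2bioes(["B-PER", "I-PER", "O", "B-LOC"]))
# ['B-PER', 'E-PER', 'O', 'S-LOC']
```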
@@ -91,42 +91,84 @@ class Vocabulary(object): | |||||
self.idx2word = None | self.idx2word = None | ||||
self.rebuild = True | self.rebuild = True | ||||
# 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 | # 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 | ||||
self._no_create_word = defaultdict(int) | |||||
self._no_create_word = Counter() | |||||
@_check_build_status | @_check_build_status | ||||
def update(self, word_lst): | |||||
def update(self, word_lst, no_create_entry=False): | |||||
"""依次增加序列中词在词典中的出现频率 | """依次增加序列中词在词典中的出现频率 | ||||
:param list word_lst: a list of strings | :param list word_lst: a list of strings | ||||
""" | |||||
:param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
如果为True,则不会为这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自于train一般设置为False。以下两种情况: 如果新
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||||
则这个词将认为是需要创建单独的vector的。 | |||||
""" | |||||
self._add_no_create_entry(word_lst, no_create_entry) | |||||
self.word_count.update(word_lst) | self.word_count.update(word_lst) | ||||
@_check_build_status | @_check_build_status | ||||
def add(self, word): | |||||
def add(self, word, no_create_entry=False): | |||||
""" | """ | ||||
增加一个新词在词典中的出现频率 | 增加一个新词在词典中的出现频率 | ||||
:param str word: 新词 | :param str word: 新词 | ||||
""" | |||||
:param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
如果为True,则不会为这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自于train一般设置为False。以下两种情况: 如果新
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||||
则这个词将认为是需要创建单独的vector的。 | |||||
""" | |||||
self._add_no_create_entry(word, no_create_entry) | |||||
self.word_count[word] += 1 | self.word_count[word] += 1 | ||||
def _add_no_create_entry(self, word, no_create_entry): | |||||
""" | |||||
在新加入word时,检查_no_create_word的设置。 | |||||
:param str, List[str] word: | |||||
:param bool no_create_entry: | |||||
:return: | |||||
""" | |||||
if isinstance(word, str): | |||||
word = [word] | |||||
for w in word: | |||||
if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): | |||||
self._no_create_word[w] += 1 | |||||
elif not no_create_entry and w in self._no_create_word: | |||||
self._no_create_word.pop(w) | |||||
@_check_build_status | @_check_build_status | ||||
def add_word(self, word): | |||||
def add_word(self, word, no_create_entry=False): | |||||
""" | """ | ||||
增加一个新词在词典中的出现频率 | 增加一个新词在词典中的出现频率 | ||||
:param str word: 新词 | :param str word: 新词 | ||||
""" | |||||
self.add(word) | |||||
:param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
如果为True,则不会为这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自于train一般设置为False。以下两种情况: 如果新
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||||
则这个词将认为是需要创建单独的vector的。 | |||||
""" | |||||
self.add(word, no_create_entry=no_create_entry) | |||||
@_check_build_status | @_check_build_status | ||||
def add_word_lst(self, word_lst): | |||||
def add_word_lst(self, word_lst, no_create_entry=False): | |||||
""" | """ | ||||
依次增加序列中词在词典中的出现频率 | 依次增加序列中词在词典中的出现频率 | ||||
:param list[str] word_lst: 词的序列 | :param list[str] word_lst: 词的序列 | ||||
""" | |||||
self.update(word_lst) | |||||
:param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
如果为True,则不会为这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自于train一般设置为False。以下两种情况: 如果新
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||||
则这个词将认为是需要创建单独的vector的。 | |||||
""" | |||||
self.update(word_lst, no_create_entry=no_create_entry) | |||||
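no_create_entry 的典型用法示意如下(词语内容仅为示例;是否真正共享unk的表示由后续 TokenEmbedding 的加载逻辑决定):

```python
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(["i", "like", "apples"])            # 来自训练集的词:正常创建entry
vocab.add_word_lst(["bananas"], no_create_entry=True)  # 来自dev/test的词:标记为不单独创建entry
vocab.build_vocab()
print(vocab.to_index("bananas"))  # 仍会分配index,只是被记录在_no_create_word中
```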
def build_vocab(self): | def build_vocab(self): | ||||
""" | """ | ||||
@@ -136,10 +178,10 @@ class Vocabulary(object): | |||||
""" | """ | ||||
if self.word2idx is None: | if self.word2idx is None: | ||||
self.word2idx = {} | self.word2idx = {} | ||||
if self.padding is not None: | |||||
self.word2idx[self.padding] = len(self.word2idx) | |||||
if self.unknown is not None: | |||||
self.word2idx[self.unknown] = len(self.word2idx) | |||||
if self.padding is not None: | |||||
self.word2idx[self.padding] = len(self.word2idx) | |||||
if self.unknown is not None: | |||||
self.word2idx[self.unknown] = len(self.word2idx) | |||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | ||||
words = self.word_count.most_common(max_size) | words = self.word_count.most_common(max_size) | ||||
@@ -278,23 +320,17 @@ class Vocabulary(object): | |||||
for fn in field_name: | for fn in field_name: | ||||
field = ins[fn] | field = ins[fn] | ||||
if isinstance(field, str): | if isinstance(field, str): | ||||
if no_create_entry and field not in self.word_count: | |||||
self._no_create_word[field] += 1 | |||||
self.add_word(field) | |||||
self.add_word(field, no_create_entry=no_create_entry) | |||||
elif isinstance(field, (list, np.ndarray)): | elif isinstance(field, (list, np.ndarray)): | ||||
if not isinstance(field[0], (list, np.ndarray)): | if not isinstance(field[0], (list, np.ndarray)): | ||||
for word in field: | for word in field: | ||||
if no_create_entry and word not in self.word_count: | |||||
self._no_create_word[word] += 1 | |||||
self.add_word(word) | |||||
self.add_word(word, no_create_entry=no_create_entry) | |||||
else: | else: | ||||
if isinstance(field[0][0], (list, np.ndarray)): | if isinstance(field[0][0], (list, np.ndarray)): | ||||
raise RuntimeError("Only support field with 2 dimensions.") | raise RuntimeError("Only support field with 2 dimensions.") | ||||
for words in field: | for words in field: | ||||
for word in words: | for word in words: | ||||
if no_create_entry and word not in self.word_count: | |||||
self._no_create_word[word] += 1 | |||||
self.add_word(word) | |||||
self.add_word(word, no_create_entry=no_create_entry) | |||||
for idx, dataset in enumerate(datasets): | for idx, dataset in enumerate(datasets): | ||||
if isinstance(dataset, DataSet): | if isinstance(dataset, DataSet): | ||||
@@ -11,21 +11,32 @@ | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
'EmbedLoader', | 'EmbedLoader', | ||||
'DataBundle', | |||||
'DataSetLoader', | 'DataSetLoader', | ||||
'CSVLoader', | 'CSVLoader', | ||||
'JsonLoader', | 'JsonLoader', | ||||
'ConllLoader', | |||||
'SNLILoader', | |||||
'SSTLoader', | |||||
'PeopleDailyCorpusLoader', | |||||
'Conll2003Loader', | |||||
'ModelLoader', | 'ModelLoader', | ||||
'ModelSaver', | 'ModelSaver', | ||||
'ConllLoader', | |||||
'Conll2003Loader', | |||||
'MatchingLoader', | |||||
'PeopleDailyCorpusLoader', | |||||
'SNLILoader', | |||||
'SSTLoader', | |||||
'SST2Loader', | |||||
'MNLILoader', | |||||
'QNLILoader', | |||||
'QuoraLoader', | |||||
'RTELoader', | |||||
] | ] | ||||
from .embed_loader import EmbedLoader | from .embed_loader import EmbedLoader | ||||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \ | |||||
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader | |||||
from .base_loader import DataBundle, DataSetLoader | |||||
from .dataset_loader import CSVLoader, JsonLoader | |||||
from .model_io import ModelLoader, ModelSaver | from .model_io import ModelLoader, ModelSaver | ||||
from .data_loader import * |
@@ -1,6 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"BaseLoader", | "BaseLoader", | ||||
'DataInfo', | |||||
'DataBundle', | |||||
'DataSetLoader', | 'DataSetLoader', | ||||
] | ] | ||||
@@ -10,6 +10,7 @@ from typing import Union, Dict | |||||
import os | import os | ||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
class BaseLoader(object): | class BaseLoader(object): | ||||
""" | """ | ||||
各个 Loader 的基类,提供了 API 的参考。 | 各个 Loader 的基类,提供了 API 的参考。 | ||||
@@ -55,8 +56,6 @@ class BaseLoader(object): | |||||
return obj | return obj | ||||
def _download_from_url(url, path): | def _download_from_url(url, path): | ||||
try: | try: | ||||
from tqdm.auto import tqdm | from tqdm.auto import tqdm | ||||
@@ -110,18 +109,16 @@ def _uncompress(src, dst): | |||||
raise ValueError('unsupported file {}'.format(src)) | raise ValueError('unsupported file {}'.format(src)) | ||||
class DataInfo: | |||||
class DataBundle: | |||||
""" | """ | ||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | ||||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | ||||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` | |||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | ||||
""" | """ | ||||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): | |||||
def __init__(self, vocabs: dict = None, datasets: dict = None): | |||||
self.vocabs = vocabs or {} | self.vocabs = vocabs or {} | ||||
self.embeddings = embeddings or {} | |||||
self.datasets = datasets or {} | self.datasets = datasets or {} | ||||
def __repr__(self): | def __repr__(self): | ||||
@@ -133,6 +130,7 @@ class DataInfo: | |||||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | _str += '\t{} has {} entries.\n'.format(name, len(vocab)) | ||||
return _str | return _str | ||||
class DataSetLoader: | class DataSetLoader: | ||||
""" | """ | ||||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | ||||
@@ -203,21 +201,20 @@ class DataSetLoader: | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: | |||||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle: | |||||
""" | """ | ||||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | 对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | ||||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | 从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | ||||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | 如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | ||||
返回的 :class:`DataInfo` 对象有如下属性: | |||||
返回的 :class:`DataBundle` 对象有如下属性: | |||||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | - vocabs: 由从数据集中获取的词表组成的字典,每个词表 | ||||
- embeddings: (可选) 数据集对应的词嵌入 | |||||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | ||||
:param paths: 原始数据读取的路径 | :param paths: 原始数据读取的路径 | ||||
:param options: 根据不同的任务和数据集,设计自己的参数 | :param options: 根据不同的任务和数据集,设计自己的参数 | ||||
:return: 返回一个 DataInfo | |||||
:return: 返回一个 DataBundle | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError |
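DataBundle 本身只是对"多个DataSet + 对应词表"的轻量封装(embeddings 字段已在本次改动中移除),示意如下(字典内容仅为示例):

```python
from fastNLP import DataSet, Vocabulary
from fastNLP.io import DataBundle

bundle = DataBundle(vocabs={'words': Vocabulary()},
                    datasets={'train': DataSet(), 'dev': DataSet()})
print(list(bundle.datasets.keys()))  # ['train', 'dev']
print(list(bundle.vocabs.keys()))    # ['words']
```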
@@ -0,0 +1,35 @@ | |||||
""" | |||||
用于读取各类数据集的模块,具体包含以下 Loader:
""" | |||||
__all__ = [ | |||||
'ConllLoader', | |||||
'Conll2003Loader', | |||||
'IMDBLoader', | |||||
'MatchingLoader', | |||||
'MNLILoader', | |||||
'MTL16Loader', | |||||
'PeopleDailyCorpusLoader', | |||||
'QNLILoader', | |||||
'QuoraLoader', | |||||
'RTELoader', | |||||
'SSTLoader', | |||||
'SST2Loader', | |||||
'SNLILoader', | |||||
'YelpLoader', | |||||
] | |||||
from .conll import ConllLoader, Conll2003Loader | |||||
from .imdb import IMDBLoader | |||||
from .matching import MatchingLoader | |||||
from .mnli import MNLILoader | |||||
from .mtl import MTL16Loader | |||||
from .people_daily import PeopleDailyCorpusLoader | |||||
from .qnli import QNLILoader | |||||
from .quora import QuoraLoader | |||||
from .rte import RTELoader | |||||
from .snli import SNLILoader | |||||
from .sst import SSTLoader, SST2Loader | |||||
from .yelp import YelpLoader |
@@ -0,0 +1,73 @@ | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ..base_loader import DataSetLoader | |||||
from ..file_reader import _read_conll | |||||
class ConllLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` | |||||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 | |||||
该符号在conll 2003中被用作文档分隔符。
列号从0开始, 每列对应内容为:: | |||||
Column Type | |||||
0 Document ID | |||||
1 Part number | |||||
2 Word number | |||||
3 Word itself | |||||
4 Part-of-Speech | |||||
5 Parse bit | |||||
6 Predicate lemma | |||||
7 Predicate Frameset ID | |||||
8 Word sense | |||||
9 Speaker/Author | |||||
10 Named Entities | |||||
11:N Predicate Arguments | |||||
N Coreference | |||||
:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||||
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||||
""" | |||||
def __init__(self, headers, indexes=None, dropna=False): | |||||
super(ConllLoader, self).__init__() | |||||
if not isinstance(headers, (list, tuple)): | |||||
raise TypeError( | |||||
'invalid headers: {}, should be list of strings'.format(headers)) | |||||
self.headers = headers | |||||
self.dropna = dropna | |||||
if indexes is None: | |||||
self.indexes = list(range(len(self.headers))) | |||||
else: | |||||
if len(indexes) != len(headers): | |||||
raise ValueError | |||||
self.indexes = indexes | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||||
ds.append(Instance(**ins)) | |||||
return ds | |||||
class Conll2003Loader(ConllLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader` | |||||
读取Conll2003数据 | |||||
关于数据集的更多信息,参考: | |||||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||||
""" | |||||
def __init__(self): | |||||
headers = [ | |||||
'tokens', 'pos', 'chunks', 'ner', | |||||
] | |||||
super(Conll2003Loader, self).__init__(headers=headers) |
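一个示意性的读取方式如下(文件路径为假设;这里假设沿用 DataSetLoader 基类提供的 load(path) 接口):

```python
from fastNLP.io import Conll2003Loader

loader = Conll2003Loader()
# 假设本地已有CoNLL-2003格式的训练文件
dataset = loader.load('data/conll2003/train.txt')
print(dataset[0]['tokens'], dataset[0]['ner'])  # 每个Instance含 tokens/pos/chunks/ner 四个field
```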
@@ -0,0 +1,96 @@ | |||||
from typing import Union, Dict | |||||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||||
from ..base_loader import DataSetLoader, DataBundle | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.const import Const | |||||
from ..utils import get_tokenizer | |||||
class IMDBLoader(DataSetLoader): | |||||
""" | |||||
读取IMDB数据集,DataSet包含以下fields: | |||||
words: list(str), 需要分类的文本 | |||||
target: str, 文本的标签 | |||||
""" | |||||
def __init__(self): | |||||
super(IMDBLoader, self).__init__() | |||||
self.tokenizer = get_tokenizer() | |||||
def _load(self, path): | |||||
dataset = DataSet() | |||||
with open(path, 'r', encoding="utf-8") as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if not line: | |||||
continue | |||||
parts = line.split('\t') | |||||
target = parts[0] | |||||
words = self.tokenizer(parts[1].lower()) | |||||
dataset.append(Instance(words=words, target=target)) | |||||
if len(dataset) == 0: | |||||
raise RuntimeError(f"{path} has no valid data.") | |||||
return dataset | |||||
def process(self, | |||||
paths: Union[str, Dict[str, str]], | |||||
src_vocab_opt: VocabularyOption = None, | |||||
tgt_vocab_opt: VocabularyOption = None, | |||||
char_level_op=False): | |||||
datasets = {} | |||||
info = DataBundle() | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path) | |||||
datasets[name] = dataset | |||||
def wordtochar(words): | |||||
chars = [] | |||||
for word in words: | |||||
word = word.lower() | |||||
for char in word: | |||||
chars.append(char) | |||||
chars.append('') | |||||
chars.pop() | |||||
return chars | |||||
if char_level_op: | |||||
for dataset in datasets.values(): | |||||
dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') | |||||
datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) | |||||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
src_vocab.from_dataset(datasets['train'], field_name='words') | |||||
src_vocab.index_dataset(*datasets.values(), field_name='words') | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||||
tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||||
info.vocabs = { | |||||
Const.INPUT: src_vocab, | |||||
Const.TARGET: tgt_vocab | |||||
} | |||||
info.datasets = datasets | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info | |||||
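一个示意性的调用如下(文件路径为假设,文件内容需为每行 "label\ttext" 的格式;train 会在 process 内部再切分出 dev):

```python
from fastNLP.io.data_loader import IMDBLoader
from fastNLP.core.const import Const

data_bundle = IMDBLoader().process(
    paths={'train': 'data/imdb/train.txt', 'test': 'data/imdb/test.txt'})
print(data_bundle.datasets.keys())           # dict_keys(['train', 'test', 'dev'])
print(len(data_bundle.vocabs[Const.INPUT]))  # 词表大小
```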
@@ -0,0 +1,248 @@ | |||||
import os | |||||
from typing import Union, Dict, List | |||||
from ...core.const import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ..base_loader import DataBundle, DataSetLoader | |||||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||||
from ...modules.encoder._bert import BertTokenizer | |||||
class MatchingLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader` | |||||
读取Matching任务的数据集 | |||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
self.paths = paths | |||||
def _load(self, path): | |||||
""" | |||||
:param str path: 待读取数据集的路径名 | |||||
:return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 | |||||
的原始字符串文本,第三个为标签 | |||||
""" | |||||
raise NotImplementedError | |||||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||||
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None, | |||||
extra_split: List[str]=None, ) -> DataBundle: | |||||
""" | |||||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||||
对应的全路径文件名。 | |||||
:param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 | |||||
这个数据集的名字,如果不定义则默认为train。 | |||||
:param bool to_lower: 是否将文本自动转为小写。默认值为False。 | |||||
:param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : | |||||
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 | |||||
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len | |||||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||||
:param bool get_index: 是否需要根据词表将文本转为index | |||||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||||
:param str auto_pad_token: 自动pad的内容 | |||||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||||
与此同时其他field不会被设置为input。默认值为True。
:param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 | |||||
:param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。 | |||||
如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 | |||||
传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. | |||||
:param extra_split: 额外的分隔符,即除了空格之外的用于分词的字符。 | |||||
:return: | |||||
""" | |||||
if isinstance(set_input, str): | |||||
set_input = [set_input] | |||||
if isinstance(set_target, str): | |||||
set_target = [set_target] | |||||
if isinstance(set_input, bool): | |||||
auto_set_input = set_input | |||||
else: | |||||
auto_set_input = False | |||||
if isinstance(set_target, bool): | |||||
auto_set_target = set_target | |||||
else: | |||||
auto_set_target = False | |||||
if isinstance(paths, str): | |||||
if os.path.isdir(paths): | |||||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||||
else: | |||||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||||
else: | |||||
path = paths | |||||
data_info = DataBundle() | |||||
for data_name in path.keys(): | |||||
data_info.datasets[data_name] = self._load(path[data_name]) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if auto_set_input: | |||||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||||
if auto_set_target: | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.set_target(Const.TARGET) | |||||
if extra_split is not None: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||||
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||||
for s in extra_split: | |||||
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), | |||||
new_field_name=Const.INPUTS(0)) | |||||
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), | |||||
new_field_name=Const.INPUTS(0)) | |||||
_filt = lambda x: x | |||||
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))), | |||||
new_field_name=Const.INPUTS(0), is_input=auto_set_input) | |||||
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))), | |||||
new_field_name=Const.INPUTS(1), is_input=auto_set_input) | |||||
_filt = None | |||||
if to_lower: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||||
is_input=auto_set_input) | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||||
is_input=auto_set_input) | |||||
if bert_tokenizer is not None: | |||||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
# 检查是否存在 | |||||
elif os.path.isdir(bert_tokenizer): | |||||
model_dir = bert_tokenizer | |||||
else: | |||||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||||
lines = f.readlines() | |||||
lines = [line.strip() for line in lines] | |||||
words_vocab.add_word_lst(lines) | |||||
words_vocab.build_vocab() | |||||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if isinstance(concat, bool): | |||||
concat = 'default' if concat else None | |||||
if concat is not None: | |||||
if isinstance(concat, str): | |||||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||||
'default': ['', '<sep>', '', '']} | |||||
if concat.lower() in CONCAT_MAP: | |||||
concat = CONCAT_MAP[concat] | |||||
else: | |||||
concat = 4 * [concat] | |||||
assert len(concat) == 4, \ | |||||
f'Please choose a list with 4 symbols which are the beginning of the first sentence, ' \
f'the end of the first sentence, the beginning of the second sentence, and the end of ' \
f'the second sentence. Your input is {concat}'
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||||
is_input=auto_set_input) | |||||
if seq_len_type is not None: | |||||
if seq_len_type == 'seq_len': # | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'mask': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [1] * len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'bert': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if Const.INPUT not in data_set.get_field_names(): | |||||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||||
f'got {data_set.get_field_names()}') | |||||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||||
if auto_pad_length is not None: | |||||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||||
if cut_text is not None: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
data_set_list = [d for n, d in data_info.datasets.items()] | |||||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||||
if bert_tokenizer is None: | |||||
words_vocab = Vocabulary(padding=auto_pad_token) | |||||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=[n for n in data_set_list[0].get_field_names() | |||||
if (Const.INPUT in n)], | |||||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||||
if 'train' not in n]) | |||||
target_vocab = Vocabulary(padding=None, unknown=None) | |||||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=Const.TARGET) | |||||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||||
if get_index: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||||
is_input=auto_set_input, is_target=auto_set_target) | |||||
if auto_pad_length is not None: | |||||
if seq_len_type == 'seq_len': | |||||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||||
new_field_name=fields, is_input=auto_set_input) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if isinstance(set_input, list): | |||||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||||
if isinstance(set_target, list): | |||||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||||
return data_info |
@@ -0,0 +1,60 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class MNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader`
Reads the MNLI dataset. The loaded DataSet contains the following fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev_matched': 'dev_matched.tsv', | |||||
'dev_mismatched': 'dev_mismatched.tsv', | |||||
'test_matched': 'test_matched.tsv', | |||||
'test_mismatched': 'test_mismatched.tsv', | |||||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||||
# test_0.9_matched and test_0.9_mismatched belong to the MNLI 0.9 release (data source: Kaggle)
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t') | |||||
self.fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
if ds[0][Const.TARGET] == 'hidden': | |||||
ds.delete_field(Const.TARGET) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds |
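As a quick, hedged sketch of how the MNLILoader above can be exercised: only the `_load` path shown in this hunk is used, the `fastNLP.io.MNLILoader` import follows the alias in the docstring, and the local file location is hypothetical.

```python
from fastNLP.io import MNLILoader  # alias per the docstring above

loader = MNLILoader()
# Hypothetical location of the GLUE MNLI dev file; any of the tsv files from the default `paths` dict works.
ds = loader._load('/path/to/MNLI/dev_matched.tsv')

# Fields follow the docstring: tokenised premise / hypothesis plus the gold label.
# Instances whose gold label is '-' are dropped; on unlabeled splits the target field is removed entirely.
print(ds[0]['words1'], ds[0]['words2'], ds[0]['target'])
```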
@@ -0,0 +1,65 @@ | |||||
from typing import Union, Dict | |||||
from ..base_loader import DataBundle | |||||
from ..dataset_loader import CSVLoader | |||||
from ...core.vocabulary import Vocabulary, VocabularyOption | |||||
from ...core.const import Const | |||||
from ..utils import check_dataloader_paths | |||||
class MTL16Loader(CSVLoader): | |||||
""" | |||||
Reads the MTL16 dataset. The DataSet contains the following fields:
words: list(str), the text to classify
target: str, the label of the text
Data source: https://pan.baidu.com/s/1c2L6vdA
""" | |||||
def __init__(self): | |||||
super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t') | |||||
def _load(self, path): | |||||
dataset = super(MTL16Loader, self)._load(path) | |||||
dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT) | |||||
if len(dataset) == 0: | |||||
raise RuntimeError(f"{path} has no valid data.") | |||||
return dataset | |||||
def process(self, | |||||
paths: Union[str, Dict[str, str]], | |||||
src_vocab_opt: VocabularyOption = None, | |||||
tgt_vocab_opt: VocabularyOption = None,): | |||||
paths = check_dataloader_paths(paths) | |||||
datasets = {} | |||||
info = DataBundle() | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path) | |||||
datasets[name] = dataset | |||||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||||
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||||
info.vocabs = { | |||||
Const.INPUT: src_vocab, | |||||
Const.TARGET: tgt_vocab | |||||
} | |||||
info.datasets = datasets | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info |
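A minimal usage sketch for MTL16Loader.process() as defined above; the import path and the file locations are assumptions, while the returned DataBundle attributes (`datasets`, `vocabs`) follow the code.

```python
from fastNLP.io.data_loader import MTL16Loader  # import path assumed

loader = MTL16Loader()
# `paths` goes through check_dataloader_paths, so it may be a directory or an explicit dict (files here are hypothetical).
bundle = loader.process(paths={'train': 'apparel/train.tsv', 'dev': 'apparel/dev.tsv'})

train_ds = bundle.datasets['train']   # `words` and `target` are already indexed and set as input/target
src_vocab = bundle.vocabs['words']    # vocabulary built on the train split only
print(len(train_ds), len(src_vocab))
```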
@@ -0,0 +1,85 @@ | |||||
from ..base_loader import DataSetLoader | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.const import Const | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | |||||
""" | |||||
Alias: :class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader`
Reads the People's Daily corpus.
""" | |||||
def __init__(self, pos=True, ner=True): | |||||
super(PeopleDailyCorpusLoader, self).__init__() | |||||
self.pos = pos | |||||
self.ner = ner | |||||
def _load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
sents = f.readlines() | |||||
examples = [] | |||||
for sent in sents: | |||||
if len(sent) <= 2: | |||||
continue | |||||
inside_ne = False | |||||
sent_pos_tag = [] | |||||
sent_words = [] | |||||
sent_ner = [] | |||||
words = sent.strip().split()[1:] | |||||
for word in words: | |||||
if "[" in word and "]" in word: | |||||
ner_tag = "U" | |||||
print(word) | |||||
elif "[" in word: | |||||
inside_ne = True | |||||
ner_tag = "B" | |||||
word = word[1:] | |||||
elif "]" in word: | |||||
ner_tag = "L" | |||||
word = word[:word.index("]")] | |||||
if inside_ne is True: | |||||
inside_ne = False | |||||
else: | |||||
raise RuntimeError("only ] appears!") | |||||
else: | |||||
if inside_ne is True: | |||||
ner_tag = "I" | |||||
else: | |||||
ner_tag = "O" | |||||
tmp = word.split("/") | |||||
token, pos = tmp[0], tmp[1] | |||||
sent_ner.append(ner_tag) | |||||
sent_pos_tag.append(pos) | |||||
sent_words.append(token) | |||||
example = [sent_words] | |||||
if self.pos is True: | |||||
example.append(sent_pos_tag) | |||||
if self.ner is True: | |||||
example.append(sent_ner) | |||||
examples.append(example) | |||||
return self.convert(examples) | |||||
def convert(self, data): | |||||
""" | |||||
:param data: a built-in Python object
:return: an object of type :class:`~fastNLP.DataSet`
""" | |||||
data_set = DataSet() | |||||
for item in data: | |||||
sent_words = item[0] | |||||
if self.pos is True and self.ner is True: | |||||
instance = Instance( | |||||
words=sent_words, pos_tags=item[1], ner=item[2]) | |||||
elif self.pos is True: | |||||
instance = Instance(words=sent_words, pos_tags=item[1]) | |||||
elif self.ner is True: | |||||
instance = Instance(words=sent_words, ner=item[1]) | |||||
else: | |||||
instance = Instance(words=sent_words) | |||||
data_set.append(instance) | |||||
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) | |||||
return data_set |
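To make the bracket handling above easier to follow, here is a hedged sketch with a made-up input line in People's Daily format (a leading id, then `token/pos` items, with `[` ... `]` wrapping a multi-token named entity); the file path is hypothetical.

```python
from fastNLP.io import PeopleDailyCorpusLoader  # alias per the docstring above

# Example of one line the parser expects (hypothetical content; the first token is skipped):
# 19980101-01-001-002/m  中国/ns  [中央/n  政府/n]  昨日/t  表示/v
loader = PeopleDailyCorpusLoader(pos=True, ner=True)
ds = loader._load('people_daily_sample.txt')   # hypothetical file

# Each instance carries `words`, `pos_tags` and `ner` (B/I/L/U/O tags derived from the brackets), plus `seq_len`.
print(ds[0]['words'], ds[0]['pos_tags'], ds[0]['ner'])
```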
@@ -0,0 +1,45 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class QNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader`
Reads the QNLI dataset. The loaded DataSet contains the following fields::
words1: list(str), the first text (the question)
words2: list(str), the second text (the sentence)
target: str, the gold label
Data source:
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv'  # the test set has no labels
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'question': Const.INPUTS(0), | |||||
'sentence': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds |
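The QNLI loader above (and the RTE loader below) rely on the same two steps: rename the GLUE column names onto fastNLP's constant field names, then whitespace-split the input fields. A tiny self-contained illustration of the renaming step, with made-up content:

```python
from fastNLP import DataSet, Instance

ds = DataSet()
ds.append(Instance(question='What is fastNLP ?',
                   sentence='fastNLP is a lightweight NLP toolkit .',
                   label='entailment'))

fields = {'question': 'words1', 'sentence': 'words2', 'label': 'target'}  # same mapping as QNLILoader
for k, v in fields.items():
    if k in ds.get_field_names():
        ds.rename_field(k, v)

print(ds.get_field_names())   # ['words1', 'words2', 'target']
```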
@@ -0,0 +1,32 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class QuoraLoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader`
Reads the Quora question-pair dataset. The loaded DataSet contains the following fields::
words1: list(str), the first question
words2: list(str), the second question
target: str, the gold label
Data source:
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv', | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
return ds |
@@ -0,0 +1,45 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class RTELoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader`
Reads the RTE dataset. The loaded DataSet contains the following fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv'  # the test set has no labels
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'sentence1': Const.INPUTS(0), | |||||
'sentence2': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds |
@@ -0,0 +1,44 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import JsonLoader | |||||
class SNLILoader(MatchingLoader, JsonLoader): | |||||
""" | |||||
Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader`
Reads the SNLI dataset. The loaded DataSet contains the following fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
""" | |||||
def __init__(self, paths: dict=None): | |||||
fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
paths = paths if paths is not None else { | |||||
'train': 'snli_1.0_train.jsonl', | |||||
'dev': 'snli_1.0_dev.jsonl', | |||||
'test': 'snli_1.0_test.jsonl'} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
JsonLoader.__init__(self, fields=fields) | |||||
def _load(self, path): | |||||
ds = JsonLoader._load(self, path) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds |
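The SNLI and MNLI loaders tokenise the `*_binary_parse` columns simply by deleting parentheses and splitting on whitespace; a standalone illustration of that transformation:

```python
# Same parenthesis-stripping trick used in SNLILoader._load / MNLILoader._load above.
parentheses_table = str.maketrans({'(': None, ')': None})

parse = '( ( The cat ) ( sat ( on ( the mat ) ) ) )'
tokens = parse.translate(parentheses_table).strip().split()
print(tokens)   # ['The', 'cat', 'sat', 'on', 'the', 'mat']
```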
@@ -1,18 +1,19 @@ | |||||
from typing import Iterable | |||||
from typing import Union, Dict | |||||
from nltk import Tree | from nltk import Tree | ||||
from ..base_loader import DataInfo, DataSetLoader | |||||
from ..base_loader import DataBundle, DataSetLoader | |||||
from ..dataset_loader import CSVLoader | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.const import Const | |||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||||
from ..utils import check_dataloader_paths, get_tokenizer | |||||
class SSTLoader(DataSetLoader): | class SSTLoader(DataSetLoader): | ||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
DATA_DIR = 'sst/' | |||||
""" | """ | ||||
Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader`
Reads the SST dataset. The DataSet contains the following fields::
@@ -25,6 +26,9 @@ class SSTLoader(DataSetLoader): | |||||
:param fine_grained: whether to use the five-class SST-5 labels; if ``False``, the SST-2 labels are used. Default: ``False``
""" | """ | ||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
DATA_DIR = 'sst/' | |||||
def __init__(self, subtree=False, fine_grained=False): | def __init__(self, subtree=False, fine_grained=False): | ||||
self.subtree = subtree | self.subtree = subtree | ||||
@@ -34,6 +38,7 @@ class SSTLoader(DataSetLoader): | |||||
tag_v['0'] = tag_v['1'] | tag_v['0'] = tag_v['1'] | ||||
tag_v['4'] = tag_v['3'] | tag_v['4'] = tag_v['3'] | ||||
self.tag_v = tag_v | self.tag_v = tag_v | ||||
self.tokenizer = get_tokenizer() | |||||
def _load(self, path): | def _load(self, path): | ||||
""" | """ | ||||
@@ -52,29 +57,37 @@ class SSTLoader(DataSetLoader): | |||||
ds.append(Instance(words=words, target=tag)) | ds.append(Instance(words=words, target=tag)) | ||||
return ds | return ds | ||||
@staticmethod | |||||
def _get_one(data, subtree): | |||||
def _get_one(self, data, subtree): | |||||
tree = Tree.fromstring(data) | tree = Tree.fromstring(data) | ||||
if subtree: | if subtree: | ||||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||||
return [(tree.leaves(), tree.label())] | |||||
return [(self.tokenizer(' '.join(t.leaves())), t.label()) for t in tree.subtrees() ] | |||||
return [(self.tokenizer(' '.join(tree.leaves())), tree.label())] | |||||
def process(self, | def process(self, | ||||
paths, | |||||
train_ds: Iterable[str] = None, | |||||
paths, train_subtree=True, | |||||
src_vocab_op: VocabularyOption = None, | src_vocab_op: VocabularyOption = None, | ||||
tgt_vocab_op: VocabularyOption = None, | |||||
src_embed_op: EmbeddingOption = None): | |||||
tgt_vocab_op: VocabularyOption = None,): | |||||
paths = check_dataloader_paths(paths) | |||||
input_name, target_name = 'words', 'target' | input_name, target_name = 'words', 'target' | ||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | ||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | tgt_vocab = Vocabulary(unknown=None, padding=None) \ | ||||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | ||||
info = DataInfo(datasets=self.load(paths)) | |||||
_train_ds = [info.datasets[name] | |||||
for name in train_ds] if train_ds else info.datasets.values() | |||||
src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
info = DataBundle() | |||||
origin_subtree = self.subtree | |||||
self.subtree = train_subtree | |||||
info.datasets['train'] = self._load(paths['train']) | |||||
self.subtree = origin_subtree | |||||
for n, p in paths.items(): | |||||
if n != 'train': | |||||
info.datasets[n] = self._load(p) | |||||
src_vocab.from_dataset( | |||||
info.datasets['train'], | |||||
field_name=input_name, | |||||
no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train']) | |||||
tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name) | |||||
src_vocab.index_dataset( | src_vocab.index_dataset( | ||||
*info.datasets.values(), | *info.datasets.values(), | ||||
field_name=input_name, new_field_name=input_name) | field_name=input_name, new_field_name=input_name) | ||||
@@ -86,10 +99,77 @@ class SSTLoader(DataSetLoader): | |||||
target_name: tgt_vocab | target_name: tgt_vocab | ||||
} | } | ||||
if src_embed_op is not None: | |||||
src_embed_op.vocab = src_vocab | |||||
init_emb = EmbedLoader.load_with_vocab(**src_embed_op) | |||||
info.embeddings[input_name] = init_emb | |||||
return info | |||||
class SST2Loader(CSVLoader): | |||||
""" | |||||
Data source "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
""" | |||||
def __init__(self): | |||||
super(SST2Loader, self).__init__(sep='\t') | |||||
self.tokenizer = get_tokenizer() | |||||
self.field = {'sentence': Const.INPUT, 'label': Const.TARGET} | |||||
def _load(self, path: str) -> DataSet: | |||||
ds = super(SST2Loader, self)._load(path) | |||||
for k, v in self.field.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT) | |||||
print("all count:", len(ds)) | |||||
return ds | |||||
def process(self, | |||||
paths: Union[str, Dict[str, str]], | |||||
src_vocab_opt: VocabularyOption = None, | |||||
tgt_vocab_opt: VocabularyOption = None, | |||||
char_level_op=False): | |||||
paths = check_dataloader_paths(paths) | |||||
datasets = {} | |||||
info = DataBundle() | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path) | |||||
datasets[name] = dataset | |||||
def wordtochar(words): | |||||
chars = [] | |||||
for word in words: | |||||
word = word.lower() | |||||
for char in word: | |||||
chars.append(char) | |||||
chars.append('') | |||||
chars.pop() | |||||
return chars | |||||
input_name, target_name = Const.INPUT, Const.TARGET | |||||
info.vocabs={} | |||||
# split the words into characters
if char_level_op: | |||||
for dataset in datasets.values(): | |||||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||||
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||||
info.vocabs = { | |||||
Const.INPUT: src_vocab, | |||||
Const.TARGET: tgt_vocab | |||||
} | |||||
info.datasets = datasets | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info
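A hedged usage sketch for SST2Loader.process() above; the file locations are hypothetical, the module path is an assumption, and only keyword arguments visible in this hunk are used.

```python
from fastNLP.io.data_loader.sst import SST2Loader  # module path assumed

loader = SST2Loader()
bundle = loader.process(paths={'train': 'SST-2/train.tsv', 'dev': 'SST-2/dev.tsv'},
                        char_level_op=False)   # set True to additionally split words into characters (Const.CHAR_INPUT)

print(bundle.datasets['train'][0])             # `words` is already indexed against the train vocabulary
print(len(bundle.vocabs['words']), len(bundle.vocabs['target']))
```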
@@ -0,0 +1,127 @@ | |||||
import csv | |||||
from typing import Iterable | |||||
from ...core.const import Const | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ..base_loader import DataBundle, DataSetLoader | |||||
from typing import Union, Dict | |||||
from ..utils import check_dataloader_paths, get_tokenizer | |||||
class YelpLoader(DataSetLoader): | |||||
""" | |||||
Reads the Yelp_full / Yelp_polarity dataset. The DataSet contains the following fields:
words: list(str), the text to classify
target: str, the label of the text
chars: list(str), the un-indexed list of characters
Dataset: yelp_full / yelp_polarity
:param fine_grained: whether to keep the fine-grained five-class labels; if ``False``, very negative / very positive are merged into negative / positive. Default: ``False``
:param lower: whether to lowercase the text automatically. Default: ``False``
""" | |||||
def __init__(self, fine_grained=False, lower=False): | |||||
super(YelpLoader, self).__init__() | |||||
tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', | |||||
'4.0': 'positive', '5.0': 'very positive'} | |||||
if not fine_grained: | |||||
tag_v['1.0'] = tag_v['2.0'] | |||||
tag_v['5.0'] = tag_v['4.0'] | |||||
self.fine_grained = fine_grained | |||||
self.tag_v = tag_v | |||||
self.lower = lower | |||||
self.tokenizer = get_tokenizer() | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
csv_reader = csv.reader(open(path, encoding='utf-8')) | |||||
all_count = 0 | |||||
real_count = 0 | |||||
for row in csv_reader: | |||||
all_count += 1 | |||||
if len(row) == 2: | |||||
target = self.tag_v[row[0] + ".0"] | |||||
words = clean_str(row[1], self.tokenizer, self.lower) | |||||
if len(words) != 0: | |||||
ds.append(Instance(words=words, target=target)) | |||||
real_count += 1 | |||||
print("all count:", all_count) | |||||
print("real count:", real_count) | |||||
return ds | |||||
def process(self, paths: Union[str, Dict[str, str]], | |||||
train_ds: Iterable[str] = None, | |||||
src_vocab_op: VocabularyOption = None, | |||||
tgt_vocab_op: VocabularyOption = None, | |||||
char_level_op=False): | |||||
paths = check_dataloader_paths(paths) | |||||
info = DataBundle(datasets=self.load(paths)) | |||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||||
_train_ds = [info.datasets[name] | |||||
for name in train_ds] if train_ds else info.datasets.values() | |||||
def wordtochar(words): | |||||
chars = [] | |||||
for word in words: | |||||
word = word.lower() | |||||
for char in word: | |||||
chars.append(char) | |||||
chars.append('') | |||||
chars.pop() | |||||
return chars | |||||
input_name, target_name = Const.INPUT, Const.TARGET | |||||
info.vocabs = {} | |||||
# split the words into characters
if char_level_op: | |||||
for dataset in info.datasets.values(): | |||||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||||
else: | |||||
src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name) | |||||
info.vocabs[input_name] = src_vocab | |||||
tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
tgt_vocab.index_dataset( | |||||
*info.datasets.values(), | |||||
field_name=target_name, new_field_name=target_name) | |||||
info.vocabs[target_name] = tgt_vocab | |||||
info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False) | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info | |||||
def clean_str(sentence, tokenizer, char_lower=False): | |||||
""" | |||||
heavily borrowed from github | |||||
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb | |||||
:param sentence: is a str | |||||
:return: | |||||
""" | |||||
if char_lower: | |||||
sentence = sentence.lower() | |||||
import re | |||||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||||
words = tokenizer(sentence) | |||||
words_collection = [] | |||||
for word in words: | |||||
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']: | |||||
continue | |||||
tt = nonalpnum.split(word) | |||||
t = ''.join(tt) | |||||
if t != '': | |||||
words_collection.append(t) | |||||
return words_collection | |||||
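A small worked example of clean_str above, using plain whitespace splitting as the tokenizer; the import path is an assumption.

```python
from fastNLP.io.data_loader.yelp import clean_str  # module path assumed

toks = clean_str('The staff was -lrb- very -rrb- friendly...', str.split, char_lower=True)
print(toks)   # ['the', 'staff', 'was', 'very', 'friendly']  (PTB brackets and non-alphanumeric runs removed)
```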
@@ -15,202 +15,13 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的 | |||||
__all__ = [ | __all__ = [ | ||||
'CSVLoader', | 'CSVLoader', | ||||
'JsonLoader', | 'JsonLoader', | ||||
'ConllLoader', | |||||
'SNLILoader', | |||||
'SSTLoader', | |||||
'PeopleDailyCorpusLoader', | |||||
'Conll2003Loader', | |||||
] | ] | ||||
import os | |||||
from nltk import Tree | |||||
from typing import Union, Dict | |||||
from ..core.vocabulary import Vocabulary | |||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
from ..core.instance import Instance | from ..core.instance import Instance | ||||
from .file_reader import _read_csv, _read_json, _read_conll | |||||
from .base_loader import DataSetLoader, DataInfo | |||||
from .data_loader.sst import SSTLoader | |||||
from ..core.const import Const | |||||
from ..modules.encoder._bert import BertTokenizer | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader` | |||||
读取人民日报数据集 | |||||
""" | |||||
def __init__(self, pos=True, ner=True): | |||||
super(PeopleDailyCorpusLoader, self).__init__() | |||||
self.pos = pos | |||||
self.ner = ner | |||||
def _load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
sents = f.readlines() | |||||
examples = [] | |||||
for sent in sents: | |||||
if len(sent) <= 2: | |||||
continue | |||||
inside_ne = False | |||||
sent_pos_tag = [] | |||||
sent_words = [] | |||||
sent_ner = [] | |||||
words = sent.strip().split()[1:] | |||||
for word in words: | |||||
if "[" in word and "]" in word: | |||||
ner_tag = "U" | |||||
print(word) | |||||
elif "[" in word: | |||||
inside_ne = True | |||||
ner_tag = "B" | |||||
word = word[1:] | |||||
elif "]" in word: | |||||
ner_tag = "L" | |||||
word = word[:word.index("]")] | |||||
if inside_ne is True: | |||||
inside_ne = False | |||||
else: | |||||
raise RuntimeError("only ] appears!") | |||||
else: | |||||
if inside_ne is True: | |||||
ner_tag = "I" | |||||
else: | |||||
ner_tag = "O" | |||||
tmp = word.split("/") | |||||
token, pos = tmp[0], tmp[1] | |||||
sent_ner.append(ner_tag) | |||||
sent_pos_tag.append(pos) | |||||
sent_words.append(token) | |||||
example = [sent_words] | |||||
if self.pos is True: | |||||
example.append(sent_pos_tag) | |||||
if self.ner is True: | |||||
example.append(sent_ner) | |||||
examples.append(example) | |||||
return self.convert(examples) | |||||
def convert(self, data): | |||||
""" | |||||
:param data: python 内置对象 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
data_set = DataSet() | |||||
for item in data: | |||||
sent_words = item[0] | |||||
if self.pos is True and self.ner is True: | |||||
instance = Instance( | |||||
words=sent_words, pos_tags=item[1], ner=item[2]) | |||||
elif self.pos is True: | |||||
instance = Instance(words=sent_words, pos_tags=item[1]) | |||||
elif self.ner is True: | |||||
instance = Instance(words=sent_words, ner=item[1]) | |||||
else: | |||||
instance = Instance(words=sent_words) | |||||
data_set.append(instance) | |||||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") | |||||
return data_set | |||||
class ConllLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader` | |||||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 | |||||
该符号在conll 2003中被用为文档分割符。 | |||||
列号从0开始, 每列对应内容为:: | |||||
Column Type | |||||
0 Document ID | |||||
1 Part number | |||||
2 Word number | |||||
3 Word itself | |||||
4 Part-of-Speech | |||||
5 Parse bit | |||||
6 Predicate lemma | |||||
7 Predicate Frameset ID | |||||
8 Word sense | |||||
9 Speaker/Author | |||||
10 Named Entities | |||||
11:N Predicate Arguments | |||||
N Coreference | |||||
:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||||
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||||
""" | |||||
def __init__(self, headers, indexes=None, dropna=False): | |||||
super(ConllLoader, self).__init__() | |||||
if not isinstance(headers, (list, tuple)): | |||||
raise TypeError( | |||||
'invalid headers: {}, should be list of strings'.format(headers)) | |||||
self.headers = headers | |||||
self.dropna = dropna | |||||
if indexes is None: | |||||
self.indexes = list(range(len(self.headers))) | |||||
else: | |||||
if len(indexes) != len(headers): | |||||
raise ValueError | |||||
self.indexes = indexes | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||||
ds.append(Instance(**ins)) | |||||
return ds | |||||
class Conll2003Loader(ConllLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader` | |||||
读取Conll2003数据 | |||||
关于数据集的更多信息,参考: | |||||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||||
""" | |||||
def __init__(self): | |||||
headers = [ | |||||
'tokens', 'pos', 'chunks', 'ner', | |||||
] | |||||
super(Conll2003Loader, self).__init__(headers=headers) | |||||
def _cut_long_sentence(sent, max_sample_length=200): | |||||
""" | |||||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。 | |||||
所以截取的句子可能长于或者短于max_sample_length | |||||
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
from .file_reader import _read_csv, _read_json | |||||
from .base_loader import DataSetLoader | |||||
class JsonLoader(DataSetLoader): | class JsonLoader(DataSetLoader): | ||||
@@ -249,42 +60,6 @@ class JsonLoader(DataSetLoader): | |||||
return ds | return ds | ||||
class SNLILoader(JsonLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||||
""" | |||||
def __init__(self): | |||||
fields = { | |||||
'sentence1_parse': Const.INPUTS(0), | |||||
'sentence2_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
super(SNLILoader, self).__init__(fields=fields) | |||||
def _load(self, path): | |||||
ds = super(SNLILoader, self)._load(path) | |||||
def parse_tree(x): | |||||
t = Tree.fromstring(x) | |||||
return t.leaves() | |||||
ds.apply(lambda ins: parse_tree( | |||||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: parse_tree( | |||||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | |||||
class CSVLoader(DataSetLoader): | class CSVLoader(DataSetLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | 别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | ||||
@@ -311,6 +86,36 @@ class CSVLoader(DataSetLoader): | |||||
return ds | return ds | ||||
def _cut_long_sentence(sent, max_sample_length=200): | |||||
""" | |||||
Split a sentence longer than max_sample_length into several pieces; cuts only happen at whitespace,
so the resulting pieces may end up longer or shorter than max_sample_length.
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
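A worked example of the private helper _cut_long_sentence above, with a deliberately tiny max_sample_length to show that cuts only happen after whole tokens, so a piece may exceed the limit:

```python
from fastNLP.io.dataset_loader import _cut_long_sentence  # private helper defined above

print(_cut_long_sentence('aa bb cc dd', max_sample_length=4))
# -> ['aa bb cc', 'dd']   (the first piece exceeds 4 characters because splitting only happens after a whole token)
```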
def _add_seg_tag(data): | def _add_seg_tag(data): | ||||
""" | """ | ||||
@@ -17,6 +17,10 @@ PRETRAINED_BERT_MODEL_DIR = { | |||||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | 'en-large-uncased': 'bert-large-uncased-20939f45.zip', | ||||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | ||||
'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip', | |||||
'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip', | |||||
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip', | |||||
'cn': 'bert-base-chinese-29d0a84a.zip', | 'cn': 'bert-base-chinese-29d0a84a.zip', | ||||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | 'cn-base': 'bert-base-chinese-29d0a84a.zip', | ||||
@@ -68,6 +72,7 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | |||||
"unable to parse {} as a URL or as a local path".format(url_or_filename) | "unable to parse {} as a URL or as a local path".format(url_or_filename) | ||||
) | ) | ||||
def get_filepath(filepath): | def get_filepath(filepath): | ||||
""" | """ | ||||
If filepath contains only a single file, return the full path of that file directly
@@ -82,6 +87,7 @@ def get_filepath(filepath): | |||||
return filepath | return filepath | ||||
return filepath | return filepath | ||||
def get_defalt_path(): | def get_defalt_path(): | ||||
""" | """ | ||||
Get the default fastNLP storage path. If FASTNLP_CACHE_PATH is set as an environment variable, its value is used, so that not every user has to download the files again.
@@ -98,6 +104,7 @@ def get_defalt_path(): | |||||
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) | fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) | ||||
return fastnlp_cache_dir | return fastnlp_cache_dir | ||||
def _get_base_url(name): | def _get_base_url(name): | ||||
# the returned URL must end with /
if 'FASTNLP_BASE_URL' in os.environ: | if 'FASTNLP_BASE_URL' in os.environ: | ||||
@@ -105,6 +112,7 @@ def _get_base_url(name): | |||||
return fastnlp_base_url | return fastnlp_base_url | ||||
raise RuntimeError("There function is not available right now.") | raise RuntimeError("There function is not available right now.") | ||||
def split_filename_suffix(filepath): | def split_filename_suffix(filepath): | ||||
""" | """ | ||||
Given a filepath, return the corresponding name and suffix
@@ -116,6 +124,7 @@ def split_filename_suffix(filepath): | |||||
return filename[:-7], '.tar.gz' | return filename[:-7], '.tar.gz' | ||||
return os.path.splitext(filename) | return os.path.splitext(filename) | ||||
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | def get_from_cache(url: str, cache_dir: Path = None) -> Path: | ||||
""" | """ | ||||
Look for the resource identified by url in cache_dir; if it is not found, download it from url and store it under cache_dir, with the cached name inferred from the url.
@@ -226,6 +235,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||||
return get_filepath(cache_path) | return get_filepath(cache_path) | ||||
def unzip_file(file: Path, to: Path): | def unzip_file(file: Path, to: Path): | ||||
# unpack and write out in CoNLL column-like format | # unpack and write out in CoNLL column-like format | ||||
from zipfile import ZipFile | from zipfile import ZipFile | ||||
@@ -234,13 +244,15 @@ def unzip_file(file: Path, to: Path): | |||||
# Extract all the contents of zip file in current directory | # Extract all the contents of zip file in current directory | ||||
zipObj.extractall(to) | zipObj.extractall(to) | ||||
def untar_gz_file(file:Path, to:Path): | def untar_gz_file(file:Path, to:Path): | ||||
import tarfile | import tarfile | ||||
with tarfile.open(file, 'r:gz') as tar: | with tarfile.open(file, 'r:gz') as tar: | ||||
tar.extractall(to) | tar.extractall(to) | ||||
def match_file(dir_name:str, cache_dir:str)->str: | |||||
def match_file(dir_name: str, cache_dir: str) -> str: | |||||
""" | """ | ||||
A file under cache_dir matches if it is (1) exactly identical to dir_name, or (2) identical to dir_name except for the suffix.
If two matches are found an error is raised; if exactly one is found its file name is returned; if none is found an empty string is returned
@@ -261,6 +273,7 @@ def match_file(dir_name:str, cache_dir:str)->str: | |||||
else: | else: | ||||
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") | raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
cache_dir = Path('caches') | cache_dir = Path('caches') | ||||
cache_dir = None | cache_dir = None | ||||
@@ -0,0 +1,69 @@ | |||||
import os | |||||
from typing import Union, Dict | |||||
def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
""" | |||||
Check the validity of the file paths passed to a data loader. For valid input, a dict containing at least the key 'train' is returned, similar to
{
'train': '/some/path/to/',  # always present; the vocabulary should be built on this split, the remaining files only need to be processed and indexed.
'test': 'xxx'  # may or may not be present
...
}
If paths is invalid, the corresponding error is raised directly.
:param paths: the path(s). Can be a single file path (the file is then treated as the train file); a directory, in which case files whose names
contain train, dev or test are looked up in that directory; or a dict whose keys are user-defined file names and whose values are the corresponding paths.
:return: | |||||
""" | |||||
if isinstance(paths, str): | |||||
if os.path.isfile(paths): | |||||
return {'train': paths} | |||||
elif os.path.isdir(paths): | |||||
filenames = os.listdir(paths) | |||||
files = {} | |||||
for filename in filenames: | |||||
path_pair = None | |||||
if 'train' in filename: | |||||
path_pair = ('train', filename) | |||||
if 'dev' in filename: | |||||
if path_pair: | |||||
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) | |||||
path_pair = ('dev', filename) | |||||
if 'test' in filename: | |||||
if path_pair: | |||||
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) | |||||
path_pair = ('test', filename) | |||||
if path_pair: | |||||
files[path_pair[0]] = os.path.join(paths, path_pair[1]) | |||||
return files | |||||
else: | |||||
raise FileNotFoundError(f"{paths} is not a valid file path.") | |||||
elif isinstance(paths, dict): | |||||
if paths: | |||||
if 'train' not in paths: | |||||
raise KeyError("You have to include `train` in your dict.") | |||||
for key, value in paths.items(): | |||||
if isinstance(key, str) and isinstance(value, str): | |||||
if not os.path.isfile(value): | |||||
raise TypeError(f"{value} is not a valid file.") | |||||
else: | |||||
raise TypeError("All keys and values in paths should be str.") | |||||
return paths | |||||
else: | |||||
raise ValueError("Empty paths is not allowed.") | |||||
else: | |||||
raise TypeError(f"paths only supports str and dict. not {type(paths)}.") | |||||
def get_tokenizer(): | |||||
try: | |||||
import spacy | |||||
spacy.prefer_gpu() | |||||
en = spacy.load('en') | |||||
print('use spacy tokenizer') | |||||
return lambda x: [w.text for w in en.tokenizer(x)] | |||||
except Exception as e: | |||||
print('use raw tokenizer') | |||||
return lambda x: x.split() |
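A short sketch of the three input forms accepted by check_dataloader_paths above; all paths are hypothetical and must actually exist on disk for the calls to succeed.

```python
from fastNLP.io.utils import check_dataloader_paths  # module path as imported elsewhere in this PR

# 1) a single file: it becomes the train split
check_dataloader_paths('data/train.tsv')             # -> {'train': 'data/train.tsv'}

# 2) a directory: files whose names contain train / dev / test are picked up
check_dataloader_paths('data/rte/')                  # -> {'train': ..., 'dev': ..., 'test': ...}

# 3) an explicit dict: must contain 'train', and every value must point to an existing file
check_dataloader_paths({'train': 'data/train.tsv', 'dev': 'data/dev.tsv'})
```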
@@ -8,35 +8,7 @@ from torch import nn | |||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..modules.encoder import BertModel | from ..modules.encoder import BertModel | ||||
class BertConfig: | |||||
def __init__( | |||||
self, | |||||
vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02 | |||||
): | |||||
self.vocab_size = vocab_size | |||||
self.hidden_size = hidden_size | |||||
self.num_hidden_layers = num_hidden_layers | |||||
self.num_attention_heads = num_attention_heads | |||||
self.intermediate_size = intermediate_size | |||||
self.hidden_act = hidden_act | |||||
self.hidden_dropout_prob = hidden_dropout_prob | |||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||||
self.max_position_embeddings = max_position_embeddings | |||||
self.type_vocab_size = type_vocab_size | |||||
self.initializer_range = initializer_range | |||||
from ..modules.encoder._bert import BertConfig | |||||
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
@@ -84,11 +56,17 @@ class BertForSequenceClassification(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
@classmethod | |||||
def from_pretrained(cls, num_labels, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
pooled_output = self.dropout(pooled_output) | pooled_output = self.dropout(pooled_output) | ||||
@@ -151,11 +129,17 @@ class BertForMultipleChoice(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, 1) | self.classifier = nn.Linear(config.hidden_size, 1) | ||||
@classmethod | |||||
def from_pretrained(cls, num_choices, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | ||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | ||||
@@ -224,11 +208,17 @@ class BertForTokenClassification(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
@classmethod | |||||
def from_pretrained(cls, num_labels, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
sequence_output = self.dropout(sequence_output) | sequence_output = self.dropout(sequence_output) | ||||
@@ -302,12 +292,18 @@ class BertForQuestionAnswering(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | ||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob) | # self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.qa_outputs = nn.Linear(config.hidden_size, 2) | self.qa_outputs = nn.Linear(config.hidden_size, 2) | ||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | ||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
logits = self.qa_outputs(sequence_output) | logits = self.qa_outputs(sequence_output) | ||||
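The from_pretrained classmethods added in this file all follow the same pattern: build a BertConfig from the model directory and construct the task head around it. A hedged sketch for the sequence-classification variant; the module path and the local model directory are assumptions.

```python
from fastNLP.models.bert import BertForSequenceClassification  # module path assumed

# The directory is expected to hold the pretrained BERT weights plus vocab.txt (hypothetical location).
model = BertForSequenceClassification.from_pretrained(
    num_labels=2,
    pretrained_model_dir='/path/to/bert-base-uncased')
```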
@@ -4,149 +4,209 @@ __all__ = [ | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | |||||
from .base_model import BaseModel | |||||
from ..core.const import Const | |||||
from ..modules import decoder as Decoder | |||||
from ..modules import encoder as Encoder | |||||
from ..modules import aggregator as Aggregator | |||||
from ..core.utils import seq_len_to_mask | |||||
from torch.nn import CrossEntropyLoss | |||||
my_inf = 10e12 | |||||
from fastNLP.models import BaseModel | |||||
from fastNLP.modules.encoder.embedding import TokenEmbedding | |||||
from fastNLP.modules.encoder.lstm import LSTM | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.core.utils import seq_len_to_mask | |||||
class ESIM(BaseModel): | class ESIM(BaseModel): | ||||
""" | |||||
别名::class:`fastNLP.models.ESIM` :class:`fastNLP.models.snli.ESIM` | |||||
ESIM模型的一个PyTorch实现。 | |||||
ESIM模型的论文: Enhanced LSTM for Natural Language Inference (arXiv: 1609.06038) | |||||
"""ESIM model的一个PyTorch实现 | |||||
论文参见: https://arxiv.org/pdf/1609.06038.pdf | |||||
:param int vocab_size: 词表大小 | |||||
:param int embed_dim: 词嵌入维度 | |||||
:param int hidden_size: LSTM隐层大小 | |||||
:param float dropout: dropout大小,默认为0 | |||||
:param int num_classes: 标签数目,默认为3 | |||||
:param numpy.array init_embedding: 初始词嵌入矩阵,形状为(vocab_size, embed_dim),默认为None,即随机初始化词嵌入矩阵 | |||||
:param fastNLP.TokenEmbedding init_embedding: the TokenEmbedding used to initialize the model
:param int hidden_size: hidden size; defaults to the dimension of the Embedding
:param int num_labels: number of target label classes, default 3
:param float dropout_rate: dropout rate, default 0.3
:param float dropout_embed: dropout rate applied to the Embedding, default 0.1
""" | """ | ||||
def __init__(self, vocab_size, embed_dim, hidden_size, dropout=0.0, num_classes=3, init_embedding=None): | |||||
def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, | |||||
dropout_embed=0.1): | |||||
super(ESIM, self).__init__() | super(ESIM, self).__init__() | ||||
self.vocab_size = vocab_size | |||||
self.embed_dim = embed_dim | |||||
self.hidden_size = hidden_size | |||||
self.dropout = dropout | |||||
self.n_labels = num_classes | |||||
self.drop = nn.Dropout(self.dropout) | |||||
self.embedding = Encoder.Embedding( | |||||
(self.vocab_size, self.embed_dim), dropout=self.dropout, | |||||
) | |||||
self.embedding_layer = nn.Linear(self.embed_dim, self.hidden_size) | |||||
self.encoder = Encoder.LSTM( | |||||
input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True, | |||||
batch_first=True, bidirectional=True | |||||
) | |||||
self.bi_attention = Aggregator.BiAttention() | |||||
self.mean_pooling = Aggregator.AvgPoolWithMask() | |||||
self.max_pooling = Aggregator.MaxPoolWithMask() | |||||
self.inference_layer = nn.Linear(self.hidden_size * 4, self.hidden_size) | |||||
self.decoder = Encoder.LSTM( | |||||
input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True, | |||||
batch_first=True, bidirectional=True | |||||
) | |||||
self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) | |||||
def forward(self, words1, words2, seq_len1=None, seq_len2=None, target=None): | |||||
""" Forward function | |||||
:param torch.Tensor words1: [batch size(B), premise seq len(PL)] premise的token表示 | |||||
:param torch.Tensor words2: [B, hypothesis seq len(HL)] hypothesis的token表示 | |||||
:param torch.LongTensor seq_len1: [B] premise的长度 | |||||
:param torch.LongTensor seq_len2: [B] hypothesis的长度 | |||||
:param torch.LongTensor target: [B] 真实目标值 | |||||
:return: dict prediction: [B, n_labels(N)] 预测结果 | |||||
self.embedding = init_embedding | |||||
self.dropout_embed = EmbedDropout(p=dropout_embed) | |||||
if hidden_size is None: | |||||
hidden_size = self.embedding.embed_size | |||||
self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) | |||||
# self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True) | |||||
self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate), | |||||
nn.Linear(8 * hidden_size, hidden_size), | |||||
nn.ReLU()) | |||||
nn.init.xavier_uniform_(self.interfere[1].weight.data) | |||||
self.bi_attention = SoftmaxAttention() | |||||
self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) | |||||
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,) | |||||
self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), | |||||
nn.Linear(8 * hidden_size, hidden_size), | |||||
nn.Tanh(), | |||||
nn.Dropout(p=dropout_rate), | |||||
nn.Linear(hidden_size, num_labels)) | |||||
self.dropout_rnn = nn.Dropout(p=dropout_rate) | |||||
nn.init.xavier_uniform_(self.classifier[1].weight.data) | |||||
nn.init.xavier_uniform_(self.classifier[4].weight.data) | |||||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | |||||
""" | """ | ||||
premise0 = self.embedding_layer(self.embedding(words1)) | |||||
hypothesis0 = self.embedding_layer(self.embedding(words2)) | |||||
if seq_len1 is not None: | |||||
seq_len1 = seq_len_to_mask(seq_len1) | |||||
else: | |||||
seq_len1 = torch.ones(premise0.size(0), premise0.size(1)) | |||||
seq_len1 = (seq_len1.long()).to(device=premise0.device) | |||||
if seq_len2 is not None: | |||||
seq_len2 = seq_len_to_mask(seq_len2) | |||||
else: | |||||
seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1)) | |||||
seq_len2 = (seq_len2.long()).to(device=hypothesis0.device) | |||||
_BP, _PSL, _HP = premise0.size() | |||||
_BH, _HSL, _HH = hypothesis0.size() | |||||
_BPL, _PLL = seq_len1.size() | |||||
_HPL, _HLL = seq_len2.size() | |||||
assert _BP == _BH and _BPL == _HPL and _BP == _BPL | |||||
assert _HP == _HH | |||||
assert _PSL == _PLL and _HSL == _HLL | |||||
B, PL, H = premise0.size() | |||||
B, HL, H = hypothesis0.size() | |||||
a0 = self.encoder(self.drop(premise0)) # a0: [B, PL, H * 2] | |||||
b0 = self.encoder(self.drop(hypothesis0)) # b0: [B, HL, H * 2] | |||||
a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] | |||||
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] | |||||
ai, bi = self.bi_attention(a, b, seq_len1, seq_len2) | |||||
ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] | |||||
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] | |||||
f_ma = self.inference_layer(ma) | |||||
f_mb = self.inference_layer(mb) | |||||
vat = self.decoder(self.drop(f_ma)) | |||||
vbt = self.decoder(self.drop(f_mb)) | |||||
va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] | |||||
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] | |||||
va_ave = self.mean_pooling(va, seq_len1, dim=1) # va_ave: [B, H] | |||||
va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1) # va_max: [B, H] | |||||
vb_ave = self.mean_pooling(vb, seq_len2, dim=1) # vb_ave: [B, H] | |||||
vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1) # vb_max: [B, H] | |||||
v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] | |||||
prediction = torch.tanh(self.output(v)) # prediction: [B, N] | |||||
if target is not None: | |||||
func = nn.CrossEntropyLoss() | |||||
loss = func(prediction, target) | |||||
return {Const.OUTPUT: prediction, Const.LOSS: loss} | |||||
return {Const.OUTPUT: prediction} | |||||
def predict(self, words1, words2, seq_len1=None, seq_len2=None, target=None): | |||||
""" Predict function | |||||
:param torch.Tensor words1: [batch size(B), premise seq len(PL)] premise的token表示 | |||||
:param torch.Tensor words2: [B, hypothesis seq len(HL)] hypothesis的token表示 | |||||
:param torch.LongTensor seq_len1: [B] premise的长度 | |||||
:param torch.LongTensor seq_len2: [B] hypothesis的长度 | |||||
:param torch.LongTensor target: [B] 真实目标值 | |||||
:return: dict prediction: [B, n_labels(N)] 预测结果 | |||||
:param words1: [batch, seq_len] | |||||
:param words2: [batch, seq_len] | |||||
:param seq_len1: [batch] | |||||
:param seq_len2: [batch] | |||||
:param target: | |||||
:return: | |||||
""" | """ | ||||
prediction = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT] | |||||
return {Const.OUTPUT: torch.argmax(prediction, dim=-1)} | |||||
mask1 = seq_len_to_mask(seq_len1, words1.size(1)) | |||||
mask2 = seq_len_to_mask(seq_len2, words2.size(1)) | |||||
a0 = self.embedding(words1) # B * len * emb_dim | |||||
b0 = self.embedding(words2) | |||||
a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) | |||||
a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] | |||||
b = self.rnn(b0, mask2.byte()) | |||||
# a = self.dropout_rnn(self.rnn(a0, seq_len1)[0]) # a: [B, PL, 2 * H] | |||||
# b = self.dropout_rnn(self.rnn(b0, seq_len2)[0]) | |||||
ai, bi = self.bi_attention(a, mask1, b, mask2) | |||||
a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H] | |||||
b_ = torch.cat((b, bi, b - bi, b * bi), dim=2) | |||||
a_f = self.interfere(a_) | |||||
b_f = self.interfere(b_) | |||||
a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] | |||||
b_h = self.rnn_high(b_f, mask2.byte()) | |||||
# a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0]) # ma: [B, PL, 2 * H] | |||||
# b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0]) | |||||
a_avg = self.mean_pooling(a_h, mask1, dim=1) | |||||
a_max, _ = self.max_pooling(a_h, mask1, dim=1) | |||||
b_avg = self.mean_pooling(b_h, mask2, dim=1) | |||||
b_max, _ = self.max_pooling(b_h, mask2, dim=1) | |||||
out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | |||||
logits = torch.tanh(self.classifier(out)) | |||||
if target is not None: | |||||
loss_fct = CrossEntropyLoss() | |||||
loss = loss_fct(logits, target) | |||||
return {Const.LOSS: loss, Const.OUTPUT: logits} | |||||
else: | |||||
return {Const.OUTPUT: logits} | |||||
def predict(self, **kwargs): | |||||
pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) | |||||
return {Const.OUTPUT: pred} | |||||
# input [batch_size, len , hidden] | |||||
# mask [batch_size, len] (111...00) | |||||
@staticmethod | |||||
def mean_pooling(input, mask, dim=1): | |||||
masks = mask.view(mask.size(0), mask.size(1), -1).float() | |||||
return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1) | |||||
@staticmethod | |||||
def max_pooling(input, mask, dim=1): | |||||
my_inf = 10e12 | |||||
masks = mask.view(mask.size(0), mask.size(1), -1) | |||||
masks = masks.expand(-1, -1, input.size(2)).float() | |||||
return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim) | |||||
class EmbedDropout(nn.Dropout): | |||||
def forward(self, sequences_batch): | |||||
ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1]) | |||||
dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False) | |||||
return dropout_mask.unsqueeze(1) * sequences_batch | |||||
class BiRNN(nn.Module): | |||||
def __init__(self, input_size, hidden_size, dropout_rate=0.3): | |||||
super(BiRNN, self).__init__() | |||||
self.dropout_rate = dropout_rate | |||||
self.rnn = nn.LSTM(input_size, hidden_size, | |||||
num_layers=1, | |||||
bidirectional=True, | |||||
batch_first=True) | |||||
def forward(self, x, x_mask): | |||||
# Sort x | |||||
lengths = x_mask.data.eq(1).long().sum(1) | |||||
_, idx_sort = torch.sort(lengths, dim=0, descending=True) | |||||
_, idx_unsort = torch.sort(idx_sort, dim=0) | |||||
lengths = list(lengths[idx_sort]) | |||||
x = x.index_select(0, idx_sort) | |||||
# Pack it up | |||||
rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) | |||||
# Apply dropout to input | |||||
if self.dropout_rate > 0: | |||||
dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training) | |||||
rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes) | |||||
output = self.rnn(rnn_input)[0] | |||||
# Unpack everything | |||||
output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0] | |||||
output = output.index_select(0, idx_unsort) | |||||
if output.size(1) != x_mask.size(1): | |||||
padding = torch.zeros(output.size(0), | |||||
x_mask.size(1) - output.size(1), | |||||
output.size(2)).type(output.data.type()) | |||||
output = torch.cat([output, padding], 1) | |||||
return output | |||||
def masked_softmax(tensor, mask): | |||||
tensor_shape = tensor.size() | |||||
reshaped_tensor = tensor.view(-1, tensor_shape[-1]) | |||||
# Reshape the mask so it matches the size of the input tensor. | |||||
while mask.dim() < tensor.dim(): | |||||
mask = mask.unsqueeze(1) | |||||
mask = mask.expand_as(tensor).contiguous().float() | |||||
reshaped_mask = mask.view(-1, mask.size()[-1]) | |||||
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) | |||||
result = result * reshaped_mask | |||||
# 1e-13 is added to avoid divisions by zero. | |||||
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) | |||||
return result.view(*tensor_shape) | |||||
def weighted_sum(tensor, weights, mask): | |||||
w_sum = weights.bmm(tensor) | |||||
while mask.dim() < w_sum.dim(): | |||||
mask = mask.unsqueeze(1) | |||||
mask = mask.transpose(-1, -2) | |||||
mask = mask.expand_as(w_sum).contiguous().float() | |||||
return w_sum * mask | |||||
class SoftmaxAttention(nn.Module): | |||||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | |||||
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) | |||||
.contiguous()) | |||||
prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask) | |||||
hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2) | |||||
.contiguous(), | |||||
premise_mask) | |||||
attended_premises = weighted_sum(hypothesis_batch, | |||||
prem_hyp_attn, | |||||
premise_mask) | |||||
attended_hypotheses = weighted_sum(premise_batch, | |||||
hyp_prem_attn, | |||||
hypothesis_mask) | |||||
return attended_premises, attended_hypotheses |
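The masked_softmax / weighted_sum / SoftmaxAttention trio above implements ESIM-style bi-directional attention: each premise token is re-expressed as a weighted sum of hypothesis tokens (and vice versa), with padded positions excluded from the softmax. A minimal, self-contained sketch of the same idea in plain PyTorch follows; it uses the common masked_fill variant rather than the renormalization trick above, and all tensor names are illustrative only.

```python
import torch
import torch.nn.functional as F

# Toy batch: 1 example, 3 premise tokens, 4 hypothesis tokens (the last one is padding).
premise = torch.randn(1, 3, 8)                    # [B, PL, H]
hypothesis = torch.randn(1, 4, 8)                 # [B, HL, H]
hyp_mask = torch.tensor([[1., 1., 1., 0.]])       # 1 = real token, 0 = padding

# Similarity between every premise/hypothesis token pair: [B, PL, HL]
sim = premise.bmm(hypothesis.transpose(1, 2))

# Masked softmax over the hypothesis axis: padded columns receive ~zero weight.
weights = F.softmax(sim.masked_fill(hyp_mask[:, None, :] == 0, -1e9), dim=-1)

# Each premise token becomes a weighted sum of hypothesis tokens: [B, PL, H]
attended_premise = weights.bmm(hypothesis)
print(attended_premise.shape)                     # torch.Size([1, 3, 8])
```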
@@ -47,7 +47,7 @@ class StarTransEnc(nn.Module): | |||||
self.embedding = get_embeddings(init_embed) | self.embedding = get_embeddings(init_embed) | ||||
emb_dim = self.embedding.embedding_dim | emb_dim = self.embedding.embedding_dim | ||||
self.emb_fc = nn.Linear(emb_dim, hidden_size) | self.emb_fc = nn.Linear(emb_dim, hidden_size) | ||||
self.emb_drop = nn.Dropout(emb_dropout) | |||||
# self.emb_drop = nn.Dropout(emb_dropout) | |||||
self.encoder = StarTransformer(hidden_size=hidden_size, | self.encoder = StarTransformer(hidden_size=hidden_size, | ||||
num_layers=num_layers, | num_layers=num_layers, | ||||
num_head=num_head, | num_head=num_head, | ||||
@@ -65,7 +65,7 @@ class StarTransEnc(nn.Module): | |||||
[batch, hidden] 全局 relay 节点, 详见论文 | [batch, hidden] 全局 relay 节点, 详见论文 | ||||
""" | """ | ||||
x = self.embedding(x) | x = self.embedding(x) | ||||
x = self.emb_fc(self.emb_drop(x)) | |||||
x = self.emb_fc(x) | |||||
nodes, relay = self.encoder(x, mask) | nodes, relay = self.encoder(x, mask) | ||||
return nodes, relay | return nodes, relay | ||||
@@ -205,7 +205,7 @@ class STSeqCls(nn.Module): | |||||
max_len=max_len, | max_len=max_len, | ||||
emb_dropout=emb_dropout, | emb_dropout=emb_dropout, | ||||
dropout=dropout) | dropout=dropout) | ||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | |||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout) | |||||
def forward(self, words, seq_len): | def forward(self, words, seq_len): | ||||
""" | """ | ||||
@@ -1,11 +1,11 @@ | |||||
""" | """ | ||||
大部分用于的 NLP 任务神经网络都可以看做由编码 :mod:`~fastNLP.modules.encoder` 、 | 大部分用于的 NLP 任务神经网络都可以看做由编码 :mod:`~fastNLP.modules.encoder` 、 | ||||
聚合 :mod:`~fastNLP.modules.aggregator` 、解码 :mod:`~fastNLP.modules.decoder` 三种模块组成。 | |||||
解码 :mod:`~fastNLP.modules.decoder` 两种模块组成。 | |||||
.. image:: figures/text_classification.png | .. image:: figures/text_classification.png | ||||
:mod:`~fastNLP.modules` 中实现了 fastNLP 提供的诸多模块组件,可以帮助用户快速搭建自己所需的网络。 | :mod:`~fastNLP.modules` 中实现了 fastNLP 提供的诸多模块组件,可以帮助用户快速搭建自己所需的网络。 | ||||
三种模块的功能和常见组件如下: | |||||
两种模块的功能和常见组件如下: | |||||
+-----------------------+-----------------------+-----------------------+ | +-----------------------+-----------------------+-----------------------+ | ||||
| module type | functionality | example | | | module type | functionality | example | | ||||
@@ -13,9 +13,6 @@ | |||||
| encoder | 将输入编码为具有具 | embedding, RNN, CNN, | | | encoder | 将输入编码为具有具 | embedding, RNN, CNN, | | ||||
| | 有表示能力的向量 | transformer | | | | 有表示能力的向量 | transformer | | ||||
+-----------------------+-----------------------+-----------------------+ | +-----------------------+-----------------------+-----------------------+ | ||||
| aggregator | 从多个向量中聚合信息 | self-attention, | | |||||
| | | max-pooling | | |||||
+-----------------------+-----------------------+-----------------------+ | |||||
| decoder | 将具有某种表示意义的 | MLP, CRF | | | decoder | 将具有某种表示意义的 | MLP, CRF | | ||||
| | 向量解码为需要的输出 | | | | | 向量解码为需要的输出 | | | ||||
| | 形式 | | | | | 形式 | | | ||||
@@ -46,10 +43,8 @@ __all__ = [ | |||||
"allowed_transitions", | "allowed_transitions", | ||||
] | ] | ||||
from . import aggregator | |||||
from . import decoder | from . import decoder | ||||
from . import encoder | from . import encoder | ||||
from .aggregator import * | |||||
from .decoder import * | from .decoder import * | ||||
from .dropout import TimestepDropout | from .dropout import TimestepDropout | ||||
from .encoder import * | from .encoder import * | ||||
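With the aggregator package removed, the building blocks are organized around the encoder/decoder split described in the table above. The sketch below shows that pattern in plain PyTorch (it deliberately avoids fastNLP-specific classes, so names and shapes are illustrative rather than the library's API): an encoder turns token ids into contextual vectors, and a decoder maps them to the desired output.

```python
import torch
import torch.nn as nn

# Encoder: embedding + BiLSTM produce contextual representations.
embed = nn.Embedding(1000, 50)
encoder = nn.LSTM(50, 64, batch_first=True, bidirectional=True)
# Decoder: an MLP maps the representation to label logits.
decoder = nn.Sequential(nn.Linear(128, 32), nn.ReLU(), nn.Linear(32, 5))

words = torch.randint(0, 1000, (2, 7))            # [batch, seq_len] token ids
hidden, _ = encoder(embed(words))                 # [batch, seq_len, 2*64]
logits = decoder(hidden[:, -1])                   # classify from the last step: [batch, 5]
```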
@@ -1,14 +0,0 @@ | |||||
__all__ = [ | |||||
"MaxPool", | |||||
"MaxPoolWithMask", | |||||
"AvgPool", | |||||
"MultiHeadAttention", | |||||
] | |||||
from .pooling import MaxPool | |||||
from .pooling import MaxPoolWithMask | |||||
from .pooling import AvgPool | |||||
from .pooling import AvgPoolWithMask | |||||
from .attention import MultiHeadAttention |
@@ -15,7 +15,8 @@ class MLP(nn.Module): | |||||
多层感知器 | 多层感知器 | ||||
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层的hidden数目。MLP的层数为 len(size_layer) - 1 | :param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层的hidden数目。MLP的层数为 len(size_layer) - 1 | ||||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和sigmoid,默认值为relu | |||||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 | |||||
sigmoid,默认值为relu | |||||
:param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 | :param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 | ||||
:param str initial_method: 参数初始化方式 | :param str initial_method: 参数初始化方式 | ||||
:param float dropout: dropout概率,默认值为0 | :param float dropout: dropout概率,默认值为0 | ||||
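A hedged usage sketch based on the parameters documented above; the import path and defaults are assumptions, so treat it as illustrative rather than the canonical API.

```python
import torch
from fastNLP.modules import MLP   # import path assumed

# 4 linear layers are implied by 5 sizes: 64 -> 32 -> 32 -> 16 -> 4,
# with relu on the hidden layers and no activation on the output layer.
mlp = MLP([64, 32, 32, 16, 4], activation='relu', dropout=0.1)
x = torch.randn(8, 64)
logits = mlp(x)                   # expected shape: [8, 4]
```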
@@ -22,7 +22,14 @@ __all__ = [ | |||||
"VarRNN", | "VarRNN", | ||||
"VarLSTM", | "VarLSTM", | ||||
"VarGRU" | |||||
"VarGRU", | |||||
"MaxPool", | |||||
"MaxPoolWithMask", | |||||
"AvgPool", | |||||
"AvgPoolWithMask", | |||||
"MultiHeadAttention", | |||||
] | ] | ||||
from ._bert import BertModel | from ._bert import BertModel | ||||
from .bert import BertWordPieceEncoder | from .bert import BertWordPieceEncoder | ||||
@@ -34,3 +41,6 @@ from .lstm import LSTM | |||||
from .star_transformer import StarTransformer | from .star_transformer import StarTransformer | ||||
from .transformer import TransformerEncoder | from .transformer import TransformerEncoder | ||||
from .variational_rnn import VarRNN, VarLSTM, VarGRU | from .variational_rnn import VarRNN, VarLSTM, VarGRU | ||||
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask | |||||
from .attention import MultiHeadAttention |
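After this reorganization the former aggregator components are re-exported from the encoder package (and hence from fastNLP.modules), so downstream imports would change roughly as sketched below; the call signature of MaxPoolWithMask is an assumption inferred from the pooling module, not verified here.

```python
import torch
from fastNLP.modules import MaxPoolWithMask   # previously under fastNLP.modules.aggregator

x = torch.randn(2, 5, 16)                     # [batch, seq_len, hidden]
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])        # 1 = real token, 0 = padding
pooled = MaxPoolWithMask()(x, mask)           # assumed signature: (tensor, mask) -> [batch, hidden]
```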
@@ -26,6 +26,7 @@ import sys | |||||
CONFIG_FILE = 'bert_config.json' | CONFIG_FILE = 'bert_config.json' | ||||
class BertConfig(object): | class BertConfig(object): | ||||
"""Configuration class to store the configuration of a `BertModel`. | """Configuration class to store the configuration of a `BertModel`. | ||||
""" | """ | ||||
@@ -339,13 +340,19 @@ class BertModel(nn.Module): | |||||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | 如果你想使用预训练好的权重矩阵,请在以下网址下载. | ||||
sources:: | sources:: | ||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", | |||||
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin", | |||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", | |||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", | |||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin" | |||||
用预训练权重矩阵来建立BERT模型:: | 用预训练权重矩阵来建立BERT模型:: | ||||
@@ -562,6 +569,7 @@ class WordpieceTokenizer(object): | |||||
output_tokens.extend(sub_tokens) | output_tokens.extend(sub_tokens) | ||||
return output_tokens | return output_tokens | ||||
def load_vocab(vocab_file): | def load_vocab(vocab_file): | ||||
"""Loads a vocabulary file into a dictionary.""" | """Loads a vocabulary file into a dictionary.""" | ||||
vocab = collections.OrderedDict() | vocab = collections.OrderedDict() | ||||
@@ -692,6 +700,7 @@ class BasicTokenizer(object): | |||||
output.append(char) | output.append(char) | ||||
return "".join(output) | return "".join(output) | ||||
def _is_whitespace(char): | def _is_whitespace(char): | ||||
"""Checks whether `chars` is a whitespace character.""" | """Checks whether `chars` is a whitespace character.""" | ||||
# \t, \n, and \r are technically control characters but we treat them | # \t, \n, and \r are technically control characters but we treat them | ||||
@@ -1,22 +1,21 @@ | |||||
""" | """ | ||||
这个页面的代码大量参考了https://github.com/HIT-SCIR/ELMoForManyLangs/tree/master/elmoformanylangs | |||||
这个页面的代码大量参考了 AllenNLP | |||||
""" | """ | ||||
from typing import Optional, Tuple, List, Callable | from typing import Optional, Tuple, List, Callable | ||||
import os | import os | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence | from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence | ||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
import json | import json | ||||
import pickle | |||||
from ..utils import get_dropout_mask | from ..utils import get_dropout_mask | ||||
import codecs | import codecs | ||||
from torch import autograd | |||||
class LstmCellWithProjection(torch.nn.Module): | class LstmCellWithProjection(torch.nn.Module): | ||||
""" | """ | ||||
@@ -58,6 +57,7 @@ class LstmCellWithProjection(torch.nn.Module): | |||||
respectively. The first dimension is 1 in order to match the Pytorch | respectively. The first dimension is 1 in order to match the Pytorch | ||||
API for returning stacked LSTM states. | API for returning stacked LSTM states. | ||||
""" | """ | ||||
def __init__(self, | def __init__(self, | ||||
input_size: int, | input_size: int, | ||||
hidden_size: int, | hidden_size: int, | ||||
@@ -129,13 +129,13 @@ class LstmCellWithProjection(torch.nn.Module): | |||||
# We have to use this '.data.new().fill_' pattern to create tensors with the correct | # We have to use this '.data.new().fill_' pattern to create tensors with the correct | ||||
# type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors. | # type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors. | ||||
output_accumulator = inputs.data.new(batch_size, | output_accumulator = inputs.data.new(batch_size, | ||||
total_timesteps, | |||||
self.hidden_size).fill_(0) | |||||
total_timesteps, | |||||
self.hidden_size).fill_(0) | |||||
if initial_state is None: | if initial_state is None: | ||||
full_batch_previous_memory = inputs.data.new(batch_size, | full_batch_previous_memory = inputs.data.new(batch_size, | ||||
self.cell_size).fill_(0) | |||||
self.cell_size).fill_(0) | |||||
full_batch_previous_state = inputs.data.new(batch_size, | full_batch_previous_state = inputs.data.new(batch_size, | ||||
self.hidden_size).fill_(0) | |||||
self.hidden_size).fill_(0) | |||||
else: | else: | ||||
full_batch_previous_state = initial_state[0].squeeze(0) | full_batch_previous_state = initial_state[0].squeeze(0) | ||||
full_batch_previous_memory = initial_state[1].squeeze(0) | full_batch_previous_memory = initial_state[1].squeeze(0) | ||||
@@ -169,7 +169,7 @@ class LstmCellWithProjection(torch.nn.Module): | |||||
# Second conditional: Does the next shortest sequence beyond the current batch | # Second conditional: Does the next shortest sequence beyond the current batch | ||||
# index require computation use this timestep? | # index require computation use this timestep? | ||||
while current_length_index < (len(batch_lengths) - 1) and \ | while current_length_index < (len(batch_lengths) - 1) and \ | ||||
batch_lengths[current_length_index + 1] > index: | |||||
batch_lengths[current_length_index + 1] > index: | |||||
current_length_index += 1 | current_length_index += 1 | ||||
# Actually get the slices of the batch which we | # Actually get the slices of the batch which we | ||||
@@ -243,23 +243,23 @@ class LstmbiLm(nn.Module): | |||||
def __init__(self, config): | def __init__(self, config): | ||||
super(LstmbiLm, self).__init__() | super(LstmbiLm, self).__init__() | ||||
self.config = config | self.config = config | ||||
self.encoder = nn.LSTM(self.config['encoder']['projection_dim'], | |||||
self.config['encoder']['dim'], | |||||
num_layers=self.config['encoder']['n_layers'], | |||||
self.encoder = nn.LSTM(self.config['lstm']['projection_dim'], | |||||
self.config['lstm']['dim'], | |||||
num_layers=self.config['lstm']['n_layers'], | |||||
bidirectional=True, | bidirectional=True, | ||||
batch_first=True, | batch_first=True, | ||||
dropout=self.config['dropout']) | dropout=self.config['dropout']) | ||||
self.projection = nn.Linear(self.config['encoder']['dim'], self.config['encoder']['projection_dim'], bias=True) | |||||
self.projection = nn.Linear(self.config['lstm']['dim'], self.config['lstm']['projection_dim'], bias=True) | |||||
def forward(self, inputs, seq_len): | def forward(self, inputs, seq_len): | ||||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | ||||
inputs = inputs[sort_idx] | inputs = inputs[sort_idx] | ||||
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first) | inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first) | ||||
output, hx = self.encoder(inputs, None) # -> [N,L,C] | output, hx = self.encoder(inputs, None) # -> [N,L,C] | ||||
output, _ = nn.util.rnn.pad_packed_sequence(output, batch_first=self.batch_first) | |||||
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | ||||
output = output[unsort_idx] | output = output[unsort_idx] | ||||
forward, backward = output.split(self.config['encoder']['dim'], 2) | |||||
forward, backward = output.split(self.config['lstm']['dim'], 2) | |||||
return torch.cat([self.projection(forward), self.projection(backward)], dim=2) | return torch.cat([self.projection(forward), self.projection(backward)], dim=2) | ||||
@@ -267,13 +267,13 @@ class ElmobiLm(torch.nn.Module): | |||||
def __init__(self, config): | def __init__(self, config): | ||||
super(ElmobiLm, self).__init__() | super(ElmobiLm, self).__init__() | ||||
self.config = config | self.config = config | ||||
input_size = config['encoder']['projection_dim'] | |||||
hidden_size = config['encoder']['projection_dim'] | |||||
cell_size = config['encoder']['dim'] | |||||
num_layers = config['encoder']['n_layers'] | |||||
memory_cell_clip_value = config['encoder']['cell_clip'] | |||||
state_projection_clip_value = config['encoder']['proj_clip'] | |||||
recurrent_dropout_probability = config['dropout'] | |||||
input_size = config['lstm']['projection_dim'] | |||||
hidden_size = config['lstm']['projection_dim'] | |||||
cell_size = config['lstm']['dim'] | |||||
num_layers = config['lstm']['n_layers'] | |||||
memory_cell_clip_value = config['lstm']['cell_clip'] | |||||
state_projection_clip_value = config['lstm']['proj_clip'] | |||||
recurrent_dropout_probability = 0.0 | |||||
self.input_size = input_size | self.input_size = input_size | ||||
self.hidden_size = hidden_size | self.hidden_size = hidden_size | ||||
@@ -316,13 +316,13 @@ class ElmobiLm(torch.nn.Module): | |||||
:param seq_len: batch_size | :param seq_len: batch_size | ||||
:return: torch.FloatTensor. num_layers x batch_size x max_len x hidden_size | :return: torch.FloatTensor. num_layers x batch_size x max_len x hidden_size | ||||
""" | """ | ||||
max_len = inputs.size(1) | |||||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | ||||
inputs = inputs[sort_idx] | inputs = inputs[sort_idx] | ||||
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) | inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) | ||||
output, _ = self._lstm_forward(inputs, None) | output, _ = self._lstm_forward(inputs, None) | ||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | ||||
output = output[:, unsort_idx] | output = output[:, unsort_idx] | ||||
return output | return output | ||||
def _lstm_forward(self, | def _lstm_forward(self, | ||||
@@ -399,7 +399,7 @@ class ElmobiLm(torch.nn.Module): | |||||
torch.cat([forward_state[1], backward_state[1]], -1))) | torch.cat([forward_state[1], backward_state[1]], -1))) | ||||
stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs) | stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs) | ||||
# Stack the hidden state and memory for each layer into 2 tensors of shape | |||||
# Stack the hidden state and memory for each layer into 2 tensors of shape | |||||
# (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size) | # (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size) | ||||
# respectively. | # respectively. | ||||
final_hidden_states, final_memory_states = zip(*final_states) | final_hidden_states, final_memory_states = zip(*final_states) | ||||
@@ -409,63 +409,30 @@ class ElmobiLm(torch.nn.Module): | |||||
return stacked_sequence_outputs, final_state_tuple | return stacked_sequence_outputs, final_state_tuple | ||||
class LstmTokenEmbedder(nn.Module): | |||||
def __init__(self, config, word_emb_layer, char_emb_layer): | |||||
super(LstmTokenEmbedder, self).__init__() | |||||
self.config = config | |||||
self.word_emb_layer = word_emb_layer | |||||
self.char_emb_layer = char_emb_layer | |||||
self.output_dim = config['encoder']['projection_dim'] | |||||
emb_dim = 0 | |||||
if word_emb_layer is not None: | |||||
emb_dim += word_emb_layer.n_d | |||||
if char_emb_layer is not None: | |||||
emb_dim += char_emb_layer.n_d * 2 | |||||
self.char_lstm = nn.LSTM(char_emb_layer.n_d, char_emb_layer.n_d, num_layers=1, bidirectional=True, | |||||
batch_first=True, dropout=config['dropout']) | |||||
self.projection = nn.Linear(emb_dim, self.output_dim, bias=True) | |||||
def forward(self, words, chars): | |||||
embs = [] | |||||
if self.word_emb_layer is not None: | |||||
if hasattr(self, 'words_to_words'): | |||||
words = self.words_to_words[words] | |||||
word_emb = self.word_emb_layer(words) | |||||
embs.append(word_emb) | |||||
if self.char_emb_layer is not None: | |||||
batch_size, seq_len, _ = chars.shape | |||||
chars = chars.view(batch_size * seq_len, -1) | |||||
chars_emb = self.char_emb_layer(chars) | |||||
# TODO 这里应该要考虑seq_len的问题 | |||||
_, (chars_outputs, __) = self.char_lstm(chars_emb) | |||||
chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['char_dim'] * 2) | |||||
embs.append(chars_outputs) | |||||
token_embedding = torch.cat(embs, dim=2) | |||||
return self.projection(token_embedding) | |||||
class ConvTokenEmbedder(nn.Module): | class ConvTokenEmbedder(nn.Module): | ||||
def __init__(self, config, word_emb_layer, char_emb_layer): | |||||
def __init__(self, config, weight_file, word_emb_layer, char_emb_layer): | |||||
super(ConvTokenEmbedder, self).__init__() | super(ConvTokenEmbedder, self).__init__() | ||||
self.config = config | |||||
self.weight_file = weight_file | |||||
self.word_emb_layer = word_emb_layer | self.word_emb_layer = word_emb_layer | ||||
self.char_emb_layer = char_emb_layer | self.char_emb_layer = char_emb_layer | ||||
self.output_dim = config['encoder']['projection_dim'] | |||||
self.emb_dim = 0 | |||||
if word_emb_layer is not None: | |||||
self.emb_dim += word_emb_layer.weight.size(1) | |||||
self.output_dim = config['lstm']['projection_dim'] | |||||
self._options = config | |||||
char_cnn_options = self._options['char_cnn'] | |||||
if char_cnn_options['activation'] == 'tanh': | |||||
self.activation = torch.tanh | |||||
elif char_cnn_options['activation'] == 'relu': | |||||
self.activation = torch.nn.functional.relu | |||||
else: | |||||
raise Exception("Unknown activation") | |||||
if char_emb_layer is not None: | if char_emb_layer is not None: | ||||
self.convolutions = [] | |||||
cnn_config = config['token_embedder'] | |||||
self.char_conv = [] | |||||
cnn_config = config['char_cnn'] | |||||
filters = cnn_config['filters'] | filters = cnn_config['filters'] | ||||
char_embed_dim = cnn_config['char_dim'] | |||||
char_embed_dim = cnn_config['embedding']['dim'] | |||||
convolutions = [] | |||||
for i, (width, num) in enumerate(filters): | for i, (width, num) in enumerate(filters): | ||||
conv = torch.nn.Conv1d( | conv = torch.nn.Conv1d( | ||||
@@ -474,55 +441,56 @@ class ConvTokenEmbedder(nn.Module): | |||||
kernel_size=width, | kernel_size=width, | ||||
bias=True | bias=True | ||||
) | ) | ||||
self.convolutions.append(conv) | |||||
convolutions.append(conv) | |||||
self.add_module('char_conv_{}'.format(i), conv) | |||||
self.convolutions = nn.ModuleList(self.convolutions) | |||||
self._convolutions = convolutions | |||||
self.n_filters = sum(f[1] for f in filters) | |||||
self.n_highway = cnn_config['n_highway'] | |||||
n_filters = sum(f[1] for f in filters) | |||||
n_highway = cnn_config['n_highway'] | |||||
self.highways = Highway(self.n_filters, self.n_highway, activation=torch.nn.functional.relu) | |||||
self.emb_dim += self.n_filters | |||||
self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu) | |||||
self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True) | |||||
self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) | |||||
def forward(self, words, chars): | def forward(self, words, chars): | ||||
embs = [] | |||||
if self.word_emb_layer is not None: | |||||
if hasattr(self, 'words_to_words'): | |||||
words = self.words_to_words[words] | |||||
word_emb = self.word_emb_layer(words) | |||||
embs.append(word_emb) | |||||
if self.char_emb_layer is not None: | |||||
batch_size, seq_len, _ = chars.size() | |||||
chars = chars.view(batch_size * seq_len, -1) | |||||
character_embedding = self.char_emb_layer(chars) | |||||
character_embedding = torch.transpose(character_embedding, 1, 2) | |||||
cnn_config = self.config['token_embedder'] | |||||
if cnn_config['activation'] == 'tanh': | |||||
activation = torch.nn.functional.tanh | |||||
elif cnn_config['activation'] == 'relu': | |||||
activation = torch.nn.functional.relu | |||||
else: | |||||
raise Exception("Unknown activation") | |||||
convs = [] | |||||
for i in range(len(self.convolutions)): | |||||
convolved = self.convolutions[i](character_embedding) | |||||
# (batch_size * sequence_length, n_filters for this width) | |||||
convolved, _ = torch.max(convolved, dim=-1) | |||||
convolved = activation(convolved) | |||||
convs.append(convolved) | |||||
char_emb = torch.cat(convs, dim=-1) | |||||
char_emb = self.highways(char_emb) | |||||
embs.append(char_emb.view(batch_size, -1, self.n_filters)) | |||||
token_embedding = torch.cat(embs, dim=2) | |||||
return self.projection(token_embedding) | |||||
""" | |||||
:param words: | |||||
:param chars: Tensor Shape ``(batch_size, sequence_length, 50)``: | |||||
:return Tensor Shape ``(batch_size, sequence_length + 2, embedding_dim)`` : | |||||
""" | |||||
# the character id embedding | |||||
# (batch_size * sequence_length, max_chars_per_token, embed_dim) | |||||
# character_embedding = torch.nn.functional.embedding( | |||||
# chars.view(-1, max_chars_per_token), | |||||
# self._char_embedding_weights | |||||
# ) | |||||
batch_size, sequence_length, max_char_len = chars.size() | |||||
character_embedding = self.char_emb_layer(chars).reshape(batch_size * sequence_length, max_char_len, -1) | |||||
# run convolutions | |||||
# (batch_size * sequence_length, embed_dim, max_chars_per_token) | |||||
character_embedding = torch.transpose(character_embedding, 1, 2) | |||||
convs = [] | |||||
for i in range(len(self._convolutions)): | |||||
conv = getattr(self, 'char_conv_{}'.format(i)) | |||||
convolved = conv(character_embedding) | |||||
# (batch_size * sequence_length, n_filters for this width) | |||||
convolved, _ = torch.max(convolved, dim=-1) | |||||
convolved = self.activation(convolved) | |||||
convs.append(convolved) | |||||
# (batch_size * sequence_length, n_filters) | |||||
token_embedding = torch.cat(convs, dim=-1) | |||||
# apply the highway layers (batch_size * sequence_length, n_filters) | |||||
token_embedding = self._highways(token_embedding) | |||||
# final projection (batch_size * sequence_length, embedding_dim) | |||||
token_embedding = self._projection(token_embedding) | |||||
# reshape to (batch_size, sequence_length+2, embedding_dim) | |||||
return token_embedding.view(batch_size, sequence_length, -1) | |||||
class Highway(torch.nn.Module): | class Highway(torch.nn.Module): | ||||
@@ -543,6 +511,7 @@ class Highway(torch.nn.Module): | |||||
activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``torch.nn.functional.relu``) | activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``torch.nn.functional.relu``) | ||||
The non-linearity to use in the highway layers. | The non-linearity to use in the highway layers. | ||||
""" | """ | ||||
def __init__(self, | def __init__(self, | ||||
input_dim: int, | input_dim: int, | ||||
num_layers: int = 1, | num_layers: int = 1, | ||||
@@ -573,6 +542,7 @@ class Highway(torch.nn.Module): | |||||
current_input = gate * linear_part + (1 - gate) * nonlinear_part | current_input = gate * linear_part + (1 - gate) * nonlinear_part | ||||
return current_input | return current_input | ||||
class _ElmoModel(nn.Module): | class _ElmoModel(nn.Module): | ||||
""" | """ | ||||
该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 | 该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 | ||||
@@ -582,10 +552,30 @@ class _ElmoModel(nn.Module): | |||||
(4) 设计一个保存token的embedding,允许缓存word的表示。 | (4) 设计一个保存token的embedding,允许缓存word的表示。 | ||||
""" | """ | ||||
def __init__(self, model_dir:str, vocab:Vocabulary=None, cache_word_reprs:bool=False): | |||||
super(_ElmoModel, self).__init__() | |||||
config = json.load(open(os.path.join(model_dir, 'structure_config.json'), 'r')) | |||||
def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False): | |||||
super(_ElmoModel, self).__init__() | |||||
self.model_dir = model_dir | |||||
dir = os.walk(self.model_dir) | |||||
config_file = None | |||||
weight_file = None | |||||
config_count = 0 | |||||
weight_count = 0 | |||||
for path, dir_list, file_list in dir: | |||||
for file_name in file_list: | |||||
if file_name.__contains__(".json"): | |||||
config_file = file_name | |||||
config_count += 1 | |||||
elif file_name.__contains__(".pkl"): | |||||
weight_file = file_name | |||||
weight_count += 1 | |||||
if config_count > 1 or weight_count > 1: | |||||
raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.") | |||||
elif config_count == 0 or weight_count == 0: | |||||
raise Exception(f"No config file or weight file found in {model_dir}") | |||||
config = json.load(open(os.path.join(model_dir, config_file), 'r')) | |||||
self.weight_file = os.path.join(model_dir, weight_file) | |||||
self.config = config | self.config = config | ||||
OOV_TAG = '<oov>' | OOV_TAG = '<oov>' | ||||
@@ -595,152 +585,103 @@ class _ElmoModel(nn.Module): | |||||
BOW_TAG = '<bow>' | BOW_TAG = '<bow>' | ||||
EOW_TAG = '<eow>' | EOW_TAG = '<eow>' | ||||
# 将加载embedding放到这里 | |||||
token_embedder_states = torch.load(os.path.join(model_dir, 'token_embedder.pkl'), map_location='cpu') | |||||
# For the model trained with word form word encoder. | |||||
if config['token_embedder']['word_dim'] > 0: | |||||
word_lexicon = {} | |||||
with codecs.open(os.path.join(model_dir, 'word.dic'), 'r', encoding='utf-8') as fpi: | |||||
for line in fpi: | |||||
tokens = line.strip().split('\t') | |||||
if len(tokens) == 1: | |||||
tokens.insert(0, '\u3000') | |||||
token, i = tokens | |||||
word_lexicon[token] = int(i) | |||||
# 做一些sanity check | |||||
for special_word in [PAD_TAG, OOV_TAG, BOS_TAG, EOS_TAG]: | |||||
assert special_word in word_lexicon, f"{special_word} not found in word.dic." | |||||
# 根据vocab调整word_embedding | |||||
pre_word_embedding = token_embedder_states.pop('word_emb_layer.embedding.weight') | |||||
word_emb_layer = nn.Embedding(len(vocab)+2, config['token_embedder']['word_dim']) #多增加两个是为了<bos>与<eos> | |||||
found_word_count = 0 | |||||
for word, index in vocab: | |||||
if index == vocab.unknown_idx: # 因为fastNLP的unknow是<unk> 而在这里是<oov>所以ugly强制适配一下 | |||||
index_in_pre = word_lexicon[OOV_TAG] | |||||
found_word_count += 1 | |||||
elif index == vocab.padding_idx: # 需要pad对齐 | |||||
index_in_pre = word_lexicon[PAD_TAG] | |||||
found_word_count += 1 | |||||
elif word in word_lexicon: | |||||
index_in_pre = word_lexicon[word] | |||||
found_word_count += 1 | |||||
else: | |||||
index_in_pre = word_lexicon[OOV_TAG] | |||||
word_emb_layer.weight.data[index] = pre_word_embedding[index_in_pre] | |||||
print(f"{found_word_count} out of {len(vocab)} words were found in pretrained elmo embedding.") | |||||
word_emb_layer.weight.data[-1] = pre_word_embedding[word_lexicon[EOS_TAG]] | |||||
word_emb_layer.weight.data[-2] = pre_word_embedding[word_lexicon[BOS_TAG]] | |||||
self.word_vocab = vocab | |||||
else: | |||||
word_emb_layer = None | |||||
# For the model trained with character-based word encoder. | # For the model trained with character-based word encoder. | ||||
if config['token_embedder']['char_dim'] > 0: | |||||
char_lexicon = {} | |||||
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: | |||||
for line in fpi: | |||||
tokens = line.strip().split('\t') | |||||
if len(tokens) == 1: | |||||
tokens.insert(0, '\u3000') | |||||
token, i = tokens | |||||
char_lexicon[token] = int(i) | |||||
# 做一些sanity check | |||||
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: | |||||
assert special_word in char_lexicon, f"{special_word} not found in char.dic." | |||||
# 从vocab中构建char_vocab | |||||
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) | |||||
# 需要保证<bow>与<eow>在里面 | |||||
char_vocab.add_word(BOW_TAG) | |||||
char_vocab.add_word(EOW_TAG) | |||||
for word, index in vocab: | |||||
char_vocab.add_word_lst(list(word)) | |||||
# 保证<eos>, <bos>也在 | |||||
char_vocab.add_word_lst(list(BOS_TAG)) | |||||
char_vocab.add_word_lst(list(EOS_TAG)) | |||||
# 根据char_lexicon调整 | |||||
char_emb_layer = nn.Embedding(len(char_vocab), int(config['token_embedder']['char_dim'])) | |||||
pre_char_embedding = token_embedder_states.pop('char_emb_layer.embedding.weight') | |||||
found_char_count = 0 | |||||
for char, index in char_vocab: # 调整character embedding | |||||
if char in char_lexicon: | |||||
index_in_pre = char_lexicon.get(char) | |||||
found_char_count += 1 | |||||
else: | |||||
index_in_pre = char_lexicon[OOV_TAG] | |||||
char_emb_layer.weight.data[index] = pre_char_embedding[index_in_pre] | |||||
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||||
# 生成words到chars的映射 | |||||
if config['token_embedder']['name'].lower() == 'cnn': | |||||
max_chars = config['token_embedder']['max_characters_per_token'] | |||||
elif config['token_embedder']['name'].lower() == 'lstm': | |||||
max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # 需要补充两个<bow>与<eow> | |||||
char_lexicon = {} | |||||
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: | |||||
for line in fpi: | |||||
tokens = line.strip().split('\t') | |||||
if len(tokens) == 1: | |||||
tokens.insert(0, '\u3000') | |||||
token, i = tokens | |||||
char_lexicon[token] = int(i) | |||||
# 做一些sanity check | |||||
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: | |||||
assert special_word in char_lexicon, f"{special_word} not found in char.dic." | |||||
# 从vocab中构建char_vocab | |||||
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) | |||||
# 需要保证<bow>与<eow>在里面 | |||||
char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG]) | |||||
for word, index in vocab: | |||||
char_vocab.add_word_lst(list(word)) | |||||
self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx | |||||
# 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示) | |||||
char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']), | |||||
padding_idx=len(char_vocab)) | |||||
# 读入预训练权重 这里的elmo_model 包含char_cnn和 lstm 的 state_dict | |||||
elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu') | |||||
char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight'] | |||||
found_char_count = 0 | |||||
for char, index in char_vocab: # 调整character embedding | |||||
if char in char_lexicon: | |||||
index_in_pre = char_lexicon.get(char) | |||||
found_char_count += 1 | |||||
else: | else: | ||||
raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name'])) | |||||
# 增加<bos>, <eos>所以加2. | |||||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars), | |||||
fill_value=char_vocab.to_index(PAD_TAG), dtype=torch.long), | |||||
requires_grad=False) | |||||
for word, index in vocab: | |||||
if len(word)+2>max_chars: | |||||
word = word[:max_chars-2] | |||||
if index==vocab.padding_idx: # 如果是pad的话,需要和给定的对齐 | |||||
word = PAD_TAG | |||||
elif index==vocab.unknown_idx: | |||||
word = OOV_TAG | |||||
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] | |||||
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) | |||||
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) | |||||
for index, word in enumerate([BOS_TAG, EOS_TAG]): # 加上<eos>, <bos> | |||||
if len(word)+2>max_chars: | |||||
word = word[:max_chars-2] | |||||
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] | |||||
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) | |||||
self.words_to_chars_embedding[index+len(vocab)] = torch.LongTensor(char_ids) | |||||
self.char_vocab = char_vocab | |||||
else: | |||||
char_emb_layer = None | |||||
if config['token_embedder']['name'].lower() == 'cnn': | |||||
self.token_embedder = ConvTokenEmbedder( | |||||
config, word_emb_layer, char_emb_layer) | |||||
elif config['token_embedder']['name'].lower() == 'lstm': | |||||
self.token_embedder = LstmTokenEmbedder( | |||||
config, word_emb_layer, char_emb_layer) | |||||
self.token_embedder.load_state_dict(token_embedder_states, strict=False) | |||||
if config['token_embedder']['word_dim'] > 0 and vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk | |||||
words_to_words = nn.Parameter(torch.arange(len(vocab)+2).long(), requires_grad=False) | |||||
for word, idx in vocab: | |||||
if vocab._is_word_no_create_entry(word): | |||||
words_to_words[idx] = vocab.unknown_idx | |||||
setattr(self.token_embedder, 'words_to_words', words_to_words) | |||||
self.output_dim = config['encoder']['projection_dim'] | |||||
if config['encoder']['name'].lower() == 'elmo': | |||||
self.encoder = ElmobiLm(config) | |||||
elif config['encoder']['name'].lower() == 'lstm': | |||||
self.encoder = LstmbiLm(config) | |||||
self.encoder.load_state_dict(torch.load(os.path.join(model_dir, 'encoder.pkl'), | |||||
map_location='cpu')) | |||||
self.bos_index = len(vocab) | |||||
self.eos_index = len(vocab) + 1 | |||||
self._pad_index = vocab.padding_idx | |||||
index_in_pre = char_lexicon[OOV_TAG] | |||||
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre] | |||||
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||||
# 生成words到chars的映射 | |||||
max_chars = config['char_cnn']['max_characters_per_token'] | |||||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars), | |||||
fill_value=len(char_vocab), | |||||
dtype=torch.long), | |||||
requires_grad=False) | |||||
for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]: | |||||
if len(word) + 2 > max_chars: | |||||
word = word[:max_chars - 2] | |||||
if index == self._pad_index: | |||||
continue | |||||
elif word == BOS_TAG or word == EOS_TAG: | |||||
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [ | |||||
char_vocab.to_index(EOW_TAG)] | |||||
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) | |||||
else: | |||||
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [ | |||||
char_vocab.to_index(EOW_TAG)] | |||||
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) | |||||
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) | |||||
self.char_vocab = char_vocab | |||||
self.token_embedder = ConvTokenEmbedder( | |||||
config, self.weight_file, None, char_emb_layer) | |||||
elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight | |||||
self.token_embedder.load_state_dict(elmo_model["char_cnn"]) | |||||
self.output_dim = config['lstm']['projection_dim'] | |||||
# lstm encoder | |||||
self.encoder = ElmobiLm(config) | |||||
self.encoder.load_state_dict(elmo_model["lstm"]) | |||||
if cache_word_reprs: | if cache_word_reprs: | ||||
if config['token_embedder']['char_dim']>0: # 只有在使用了chars的情况下有用 | |||||
if config['char_cnn']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用 | |||||
print("Start to generate cache word representations.") | print("Start to generate cache word representations.") | ||||
batch_size = 320 | batch_size = 320 | ||||
num_batches = self.words_to_chars_embedding.size(0)//batch_size + \ | |||||
int(self.words_to_chars_embedding.size(0)%batch_size!=0) | |||||
self.cached_word_embedding = nn.Embedding(self.words_to_chars_embedding.size(0), | |||||
config['encoder']['projection_dim']) | |||||
# bos eos | |||||
word_size = self.words_to_chars_embedding.size(0) | |||||
num_batches = word_size // batch_size + \ | |||||
int(word_size % batch_size != 0) | |||||
self.cached_word_embedding = nn.Embedding(word_size, | |||||
config['lstm']['projection_dim']) | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
for i in range(num_batches): | for i in range(num_batches): | ||||
words = torch.arange(i*batch_size, min((i+1)*batch_size, self.words_to_chars_embedding.size(0))).long() | |||||
words = torch.arange(i * batch_size, | |||||
min((i + 1) * batch_size, word_size)).long() | |||||
chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars | chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars | ||||
word_reprs = self.token_embedder(words.unsqueeze(1), chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] | |||||
word_reprs = self.token_embedder(words.unsqueeze(1), | |||||
chars).detach() # batch_size x 1 x config['lstm']['projection_dim'] | |||||
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) | self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) | ||||
print("Finish generating cached word representations. Going to delete the character encoder.") | print("Finish generating cached word representations. Going to delete the character encoder.") | ||||
del self.token_embedder, self.words_to_chars_embedding | del self.token_embedder, self.words_to_chars_embedding | ||||
else: | else: | ||||
@@ -758,8 +699,10 @@ class _ElmoModel(nn.Module): | |||||
seq_len = words.ne(self._pad_index).sum(dim=-1) | seq_len = words.ne(self._pad_index).sum(dim=-1) | ||||
expanded_words[:, 1:-1] = words | expanded_words[:, 1:-1] = words | ||||
expanded_words[:, 0].fill_(self.bos_index) | expanded_words[:, 0].fill_(self.bos_index) | ||||
expanded_words[torch.arange(batch_size).to(words), seq_len+1] = self.eos_index | |||||
expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index | |||||
seq_len = seq_len + 2 | seq_len = seq_len + 2 | ||||
zero_tensor = expanded_words.new_zeros(expanded_words.shape) | |||||
mask = (expanded_words == zero_tensor).unsqueeze(-1) | |||||
if hasattr(self, 'cached_word_embedding'): | if hasattr(self, 'cached_word_embedding'): | ||||
token_embedding = self.cached_word_embedding(expanded_words) | token_embedding = self.cached_word_embedding(expanded_words) | ||||
else: | else: | ||||
@@ -767,22 +710,19 @@ class _ElmoModel(nn.Module): | |||||
chars = self.words_to_chars_embedding[expanded_words] | chars = self.words_to_chars_embedding[expanded_words] | ||||
else: | else: | ||||
chars = None | chars = None | ||||
token_embedding = self.token_embedder(expanded_words, chars) | |||||
if self.config['encoder']['name'] == 'elmo': | |||||
encoder_output = self.encoder(token_embedding, seq_len) | |||||
if encoder_output.size(2) < max_len+2: | |||||
dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size, | |||||
max_len + 2 - encoder_output.size(2), encoder_output.size(-1)) | |||||
encoder_output = torch.cat([encoder_output, dummy_tensor], 2) | |||||
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size | |||||
token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) | |||||
encoder_output = torch.cat([token_embedding, encoder_output], dim=0) | |||||
elif self.config['encoder']['name'] == 'lstm': | |||||
encoder_output = self.encoder(token_embedding, seq_len) | |||||
else: | |||||
raise ValueError('Unknown encoder: {0}'.format(self.config['encoder']['name'])) | |||||
token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim | |||||
encoder_output = self.encoder(token_embedding, seq_len) | |||||
if encoder_output.size(2) < max_len + 2: | |||||
num_layers, _, output_len, hidden_size = encoder_output.size() | |||||
dummy_tensor = encoder_output.new_zeros(num_layers, batch_size, | |||||
max_len + 2 - output_len, hidden_size) | |||||
encoder_output = torch.cat((encoder_output, dummy_tensor), 2) | |||||
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size | |||||
token_embedding = token_embedding.masked_fill(mask, 0) | |||||
token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3]) | |||||
encoder_output = torch.cat((token_embedding, encoder_output), dim=0) | |||||
# 删除<eos>, <bos>. 这里没有精确地删除,但应该也不会影响最后的结果了。 | # 删除<eos>, <bos>. 这里没有精确地删除,但应该也不会影响最后的结果了。 | ||||
encoder_output = encoder_output[:, :, 1:-1] | encoder_output = encoder_output[:, :, 1:-1] | ||||
return encoder_output | return encoder_output |
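_ElmoModel.forward returns a stack of layer representations: layer 0 is the (duplicated) token embedding and the remaining layers come from the bidirectional LSTM. A downstream embedding typically mixes these layers with learned scalar weights, roughly as in the sketch below, which is the standard ELMo scalar mix under assumed shapes, not fastNLP's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed output of the model above: [num_layers, batch, seq_len, hidden].
num_layers, batch, seq_len, hidden = 3, 2, 7, 1024
encoder_output = torch.randn(num_layers, batch, seq_len, hidden)

scalar_weights = nn.Parameter(torch.zeros(num_layers))   # one learnable weight per layer
gamma = nn.Parameter(torch.ones(1))                       # global scale, as in the ELMo paper
normed = F.softmax(scalar_weights, dim=0)                 # weights sum to 1 across layers
mixed = gamma * (normed.view(-1, 1, 1, 1) * encoder_output).sum(dim=0)   # [batch, seq_len, hidden]
```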
@@ -8,9 +8,9 @@ import torch | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch import nn | from torch import nn | ||||
from ..dropout import TimestepDropout | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from ..utils import initial_parameter | |||||
from fastNLP.modules.utils import initial_parameter | |||||
class DotAttention(nn.Module): | class DotAttention(nn.Module): | ||||
@@ -19,7 +19,7 @@ class DotAttention(nn.Module): | |||||
补上文档 | 补上文档 | ||||
""" | """ | ||||
def __init__(self, key_size, value_size, dropout=0): | |||||
def __init__(self, key_size, value_size, dropout=0.0): | |||||
super(DotAttention, self).__init__() | super(DotAttention, self).__init__() | ||||
self.key_size = key_size | self.key_size = key_size | ||||
self.value_size = value_size | self.value_size = value_size | ||||
@@ -37,7 +37,7 @@ class DotAttention(nn.Module): | |||||
""" | """ | ||||
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | ||||
if mask_out is not None: | if mask_out is not None: | ||||
output.masked_fill_(mask_out, -1e8) | |||||
output.masked_fill_(mask_out, -1e18) | |||||
output = self.softmax(output) | output = self.softmax(output) | ||||
output = self.drop(output) | output = self.drop(output) | ||||
return torch.matmul(output, V) | return torch.matmul(output, V) | ||||
@@ -45,8 +45,7 @@ class DotAttention(nn.Module): | |||||
class MultiHeadAttention(nn.Module): | class MultiHeadAttention(nn.Module): | ||||
""" | """ | ||||
别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.aggregator.attention.MultiHeadAttention` | |||||
别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.encoder.attention.MultiHeadAttention` | |||||
:param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | :param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | ||||
:param key_size: int, 每个head的维度大小。 | :param key_size: int, 每个head的维度大小。 | ||||
@@ -67,9 +66,8 @@ class MultiHeadAttention(nn.Module): | |||||
self.k_in = nn.Linear(input_size, in_size) | self.k_in = nn.Linear(input_size, in_size) | ||||
self.v_in = nn.Linear(input_size, in_size) | self.v_in = nn.Linear(input_size, in_size) | ||||
# follow the paper, do not apply dropout within dot-product | # follow the paper, do not apply dropout within dot-product | ||||
self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0) | |||||
self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) | |||||
self.out = nn.Linear(value_size * num_head, input_size) | self.out = nn.Linear(value_size * num_head, input_size) | ||||
self.drop = TimestepDropout(dropout) | |||||
self.reset_parameters() | self.reset_parameters() | ||||
def reset_parameters(self): | def reset_parameters(self): | ||||
@@ -105,7 +103,7 @@ class MultiHeadAttention(nn.Module): | |||||
# concat all heads, do output linear | # concat all heads, do output linear | ||||
atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) | atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) | ||||
output = self.drop(self.out(atte)) | |||||
output = self.out(atte) | |||||
return output | return output | ||||
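For reference, the computation DotAttention performs after this change (scaling, masking with a large negative value, softmax, then dropout on the attention weights) can be sketched in a few lines of plain PyTorch; this illustrates the mechanism only and is not the fastNLP class itself.

```python
import math
import torch
import torch.nn.functional as F

def dot_attention(Q, K, V, mask_out=None, p_drop=0.1, training=True):
    # Scaled dot-product scores: [batch, q_len, k_len]
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))
    if mask_out is not None:
        scores = scores.masked_fill(mask_out, -1e18)      # masked positions get ~zero weight
    weights = F.dropout(F.softmax(scores, dim=-1), p=p_drop, training=training)
    return torch.matmul(weights, V)

Q = K = V = torch.randn(2, 5, 16)
out = dot_attention(Q, K, V)                              # [2, 5, 16]
```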
@@ -2,35 +2,22 @@ | |||||
import os | import os | ||||
from torch import nn | from torch import nn | ||||
import torch | import torch | ||||
from ...io.file_utils import _get_base_url, cached_path | |||||
from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||||
from ._bert import _WordPieceBertModel, BertModel | from ._bert import _WordPieceBertModel, BertModel | ||||
class BertWordPieceEncoder(nn.Module): | class BertWordPieceEncoder(nn.Module): | ||||
""" | """ | ||||
读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | 读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | ||||
:param fastNLP.Vocabulary vocab: 词表 | |||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` | :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` | ||||
:param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | ||||
:param bool requires_grad: 是否需要gradient。 | :param bool requires_grad: 是否需要gradient。 | ||||
""" | """ | ||||
def __init__(self, model_dir_or_name:str='en-base-uncased', layers:str='-1', | |||||
requires_grad:bool=False): | |||||
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||||
requires_grad: bool=False): | |||||
super().__init__() | super().__init__() | ||||
PRETRAIN_URL = _get_base_url('bert') | PRETRAIN_URL = _get_base_url('bert') | ||||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||||
} | |||||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | ||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | ||||
@@ -89,4 +76,4 @@ class BertWordPieceEncoder(nn.Module): | |||||
outputs = self.model(word_pieces, token_type_ids) | outputs = self.model(word_pieces, token_type_ids) | ||||
outputs = torch.cat([*outputs], dim=-1) | outputs = torch.cat([*outputs], dim=-1) | ||||
return outputs | |||||
return outputs |
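As a quick illustration of how the encoder above is meant to be used, the sketch below builds a toy DataSet and runs the indexing step. This is a hedged sketch: it assumes `BertWordPieceEncoder` is importable from this module, that the `en-base-uncased` weights can be downloaded, and that `index_dataset` accepts a DataSet as the docstring suggests; none of this is a verified API.

```python
# Hedged usage sketch based on the docstring above; import context and the exact
# index_dataset signature are assumptions, not a verified fastNLP API.
from fastNLP import DataSet

ds = DataSet({'words': [['this', 'is', 'an', 'example', '.']]})

encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased',  # weights downloaded/cached on first use
                               layers='-1,-2',                       # concatenate the last two layers
                               requires_grad=False)                  # keep BERT frozen
encoder.index_dataset(ds)  # per the docstring, adds a 'word_pieces' column to the dataset
```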
@@ -35,15 +35,15 @@ class Embedding(nn.Module): | |||||
Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度""" | Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度""" | ||||
def __init__(self, init_embed, dropout=0.0, dropout_word=0, unk_index=None): | |||||
def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None): | |||||
""" | """ | ||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int), | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int), | ||||
第一个int为vocab_size, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding; | 第一个int为vocab_size, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding; | ||||
也可以传入TokenEmbedding对象 | |||||
:param float word_dropout: 按照一定概率随机将word设置为unk_index,这样可以使得unk这个token得到足够的训练, 且会对网络有 | |||||
一定的regularize的作用。 | |||||
:param float dropout: 对Embedding的输出的dropout。 | :param float dropout: 对Embedding的输出的dropout。 | ||||
:param float dropout_word: 按照一定比例随机将word设置为unk的idx,这样可以使得unk这个token得到足够的训练 | |||||
:param int unk_index: drop word时替换为的index,如果init_embed为TokenEmbedding不需要传入该值。 | |||||
:param int unk_index: drop word时替换为的index。fastNLP的Vocabulary的unk_index默认为1。 | |||||
""" | """ | ||||
super(Embedding, self).__init__() | super(Embedding, self).__init__() | ||||
@@ -52,21 +52,21 @@ class Embedding(nn.Module): | |||||
self.dropout = nn.Dropout(dropout) | self.dropout = nn.Dropout(dropout) | ||||
if not isinstance(self.embed, TokenEmbedding): | if not isinstance(self.embed, TokenEmbedding): | ||||
self._embed_size = self.embed.weight.size(1) | self._embed_size = self.embed.weight.size(1) | ||||
if dropout_word>0 and not isinstance(unk_index, int): | |||||
if word_dropout>0 and not isinstance(unk_index, int): | |||||
raise ValueError("When drop word is set, you need to pass in the unk_index.") | raise ValueError("When drop word is set, you need to pass in the unk_index.") | ||||
else: | else: | ||||
self._embed_size = self.embed.embed_size | self._embed_size = self.embed.embed_size | ||||
unk_index = self.embed.get_word_vocab().unknown_idx | unk_index = self.embed.get_word_vocab().unknown_idx | ||||
self.unk_index = unk_index | self.unk_index = unk_index | ||||
self.dropout_word = dropout_word | |||||
self.word_dropout = word_dropout | |||||
def forward(self, x): | def forward(self, x): | ||||
""" | """ | ||||
:param torch.LongTensor x: [batch, seq_len] | :param torch.LongTensor x: [batch, seq_len] | ||||
:return: torch.Tensor : [batch, seq_len, embed_dim] | :return: torch.Tensor : [batch, seq_len, embed_dim] | ||||
""" | """ | ||||
if self.dropout_word>0 and self.training: | |||||
mask = torch.ones_like(x).float() * self.dropout_word | |||||
if self.word_dropout>0 and self.training: | |||||
mask = torch.ones_like(x).float() * self.word_dropout | |||||
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1 | mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1 | ||||
x = x.masked_fill(mask, self.unk_index) | x = x.masked_fill(mask, self.unk_index) | ||||
x = self.embed(x) | x = self.embed(x) | ||||
@@ -117,11 +117,38 @@ class Embedding(nn.Module): | |||||
class TokenEmbedding(nn.Module): | class TokenEmbedding(nn.Module): | ||||
def __init__(self, vocab): | |||||
def __init__(self, vocab, word_dropout=0.0, dropout=0.0): | |||||
super(TokenEmbedding, self).__init__() | super(TokenEmbedding, self).__init__() | ||||
assert vocab.padding_idx is not None, "You vocabulary must have padding." | |||||
assert vocab.padding is not None, "Vocabulary must have a padding entry." | |||||
self._word_vocab = vocab | self._word_vocab = vocab | ||||
self._word_pad_index = vocab.padding_idx | self._word_pad_index = vocab.padding_idx | ||||
if word_dropout>0: | |||||
assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word." | |||||
self.word_dropout = word_dropout | |||||
self._word_unk_index = vocab.unknown_idx | |||||
self.dropout_layer = nn.Dropout(dropout) | |||||
def drop_word(self, words): | |||||
""" | |||||
按照设定随机将words设置为unknown_index。 | |||||
:param torch.LongTensor words: batch_size x max_len | |||||
:return: | |||||
""" | |||||
if self.word_dropout > 0 and self.training: | |||||
mask = torch.ones_like(words).float() * self.word_dropout | |||||
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1 | |||||
words = words.masked_fill(mask, self._word_unk_index) | |||||
return words | |||||
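The `drop_word` trick above is plain PyTorch: sample a Bernoulli mask with probability `word_dropout` and overwrite those positions with the unknown index. A minimal, self-contained sketch with made-up ids:

```python
import torch

torch.manual_seed(0)
word_dropout, unk_index = 0.3, 1
words = torch.randint(2, 100, (2, 6))        # batch_size x max_len word ids
mask = torch.ones_like(words).float() * word_dropout
mask = torch.bernoulli(mask).byte()          # each position is 1 with probability word_dropout
print(words.masked_fill(mask, unk_index))    # masked positions now hold the unk id
```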
def dropout(self, words): | |||||
""" | |||||
对embedding后的word表示进行drop。 | |||||
:param torch.FloatTensor words: batch_size x max_len x embed_size | |||||
:return: | |||||
""" | |||||
return self.dropout_layer(words) | |||||
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
@@ -147,8 +174,16 @@ class TokenEmbedding(nn.Module): | |||||
def embed_size(self) -> int: | def embed_size(self) -> int: | ||||
return self._embed_size | return self._embed_size | ||||
@property | |||||
def embedding_dim(self) -> int: | |||||
return self._embed_size | |||||
@property | @property | ||||
def num_embedding(self) -> int: | def num_embedding(self) -> int: | ||||
""" | |||||
这个值可能会大于实际的embedding矩阵的大小。 | |||||
:return: | |||||
""" | |||||
return len(self._word_vocab) | return len(self._word_vocab) | ||||
def get_word_vocab(self): | def get_word_vocab(self): | ||||
@@ -163,6 +198,9 @@ class TokenEmbedding(nn.Module): | |||||
def size(self): | def size(self): | ||||
return torch.Size(self.num_embedding, self._embed_size) | return torch.Size(self.num_embedding, self._embed_size) | ||||
@abstractmethod | |||||
def forward(self, *input): | |||||
raise NotImplementedError | |||||
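Because `forward` is abstract, every concrete embedding follows the same pattern: apply `drop_word` to the input ids, look the ids up, then apply `self.dropout` to the output. A hypothetical minimal subclass (class name and sizes invented purely for illustration) could look like the following, assuming `TokenEmbedding` and a fastNLP `Vocabulary` with padding and unknown entries are in scope as above.

```python
from torch import nn

class RandomInitEmbedding(TokenEmbedding):
    """Toy TokenEmbedding subclass with randomly initialized vectors (illustration only)."""
    def __init__(self, vocab, embed_dim=50, word_dropout=0.1, dropout=0.1):
        super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
        self._embed_size = embed_dim  # read back by the embed_size / embedding_dim properties
        self.embedding = nn.Embedding(len(vocab), embed_dim, padding_idx=vocab.padding_idx)

    def forward(self, words):
        words = self.drop_word(words)   # randomly replace ids with unk during training
        out = self.embedding(words)
        return self.dropout(out)        # output-level dropout
```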
class StaticEmbedding(TokenEmbedding): | class StaticEmbedding(TokenEmbedding): | ||||
""" | """ | ||||
@@ -179,15 +217,17 @@ class StaticEmbedding(TokenEmbedding): | |||||
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding | :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding | ||||
的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d, | 的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d, | ||||
`en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 | `en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 | ||||
:param requires_grad: 是否需要gradient. 默认为True | |||||
:param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 | |||||
:param normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||||
:param bool requires_grad: 是否需要gradient. 默认为True | |||||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 | |||||
:param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 | |||||
为大写的词语开辟一个vector表示,则将lower设置为False。 | |||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||||
""" | """ | ||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None, | def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None, | ||||
normalize=False): | |||||
super(StaticEmbedding, self).__init__(vocab) | |||||
# 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, | |||||
lower=False, dropout=0, word_dropout=0, normalize=False): | |||||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
# 得到cache_path | # 得到cache_path | ||||
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | ||||
@@ -202,8 +242,40 @@ class StaticEmbedding(TokenEmbedding): | |||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
# 读取embedding | # 读取embedding | ||||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, | |||||
normalize=normalize) | |||||
if lower: | |||||
lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown) | |||||
for word, index in vocab: | |||||
if not vocab._is_word_no_create_entry(word): | |||||
lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | |||||
for word in vocab._no_create_word.keys(): # 不需要创建entry的 | |||||
if word in vocab: | |||||
lowered_word = word.lower() | |||||
if lowered_word not in lowered_vocab.word_count: | |||||
lowered_vocab.add_word(lowered_word) | |||||
lowered_vocab._no_create_word[lowered_word] += 1 | |||||
print(f"All word in vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered " | |||||
f"words.") | |||||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method, | |||||
normalize=normalize) | |||||
# 需要适配一下 | |||||
if not hasattr(self, 'words_to_words'): | |||||
self.words_to_words = torch.arange(len(lowered_vocab, )).long() | |||||
if lowered_vocab.unknown: | |||||
unknown_idx = lowered_vocab.unknown_idx | |||||
else: | |||||
unknown_idx = embedding.size(0) - 1 # 否则最后一个为unknown | |||||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||||
requires_grad=False) | |||||
for word, index in vocab: | |||||
if word not in lowered_vocab: | |||||
word = word.lower() | |||||
if lowered_vocab._is_word_no_create_entry(word): # 如果不需要创建entry,已经默认unknown了 | |||||
continue | |||||
words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)] | |||||
self.words_to_words = words_to_words | |||||
else: | |||||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, | |||||
normalize=normalize) | |||||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | ||||
padding_idx=vocab.padding_idx, | padding_idx=vocab.padding_idx, | ||||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | max_norm=None, norm_type=2, scale_grad_by_freq=False, | ||||
@@ -301,7 +373,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
if vocab._no_create_word_length>0: | if vocab._no_create_word_length>0: | ||||
if vocab.unknown is None: # 创建一个专门的unknown | if vocab.unknown is None: # 创建一个专门的unknown | ||||
unknown_idx = len(matrix) | unknown_idx = len(matrix) | ||||
vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous() | |||||
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() | |||||
else: | else: | ||||
unknown_idx = vocab.unknown_idx | unknown_idx = vocab.unknown_idx | ||||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | ||||
@@ -330,12 +402,15 @@ class StaticEmbedding(TokenEmbedding): | |||||
""" | """ | ||||
if hasattr(self, 'words_to_words'): | if hasattr(self, 'words_to_words'): | ||||
words = self.words_to_words[words] | words = self.words_to_words[words] | ||||
return self.embedding(words) | |||||
words = self.drop_word(words) | |||||
words = self.embedding(words) | |||||
words = self.dropout(words) | |||||
return words | |||||
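A usage sketch for the new `lower` / `word_dropout` / `dropout` options. The GloVe name comes from the docstring above; the vectors are downloaded and cached on first use, and the `Vocabulary` helper calls are assumed to be the usual fastNLP ones, so treat the exact names as assumptions.

```python
import torch
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("The quick brown Fox jumps over the lazy dog .".split())

embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50',
                        lower=True,         # 'The' and 'the' share the same GloVe row
                        word_dropout=0.01,  # occasionally replace a word with unk
                        dropout=0.3)        # dropout on the output vectors

words = torch.LongTensor([[vocab.to_index(w) for w in "The quick brown Fox".split()]])
print(embed(words).shape)  # expected: torch.Size([1, 4, 50])
```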
class ContextualEmbedding(TokenEmbedding): | class ContextualEmbedding(TokenEmbedding): | ||||
def __init__(self, vocab: Vocabulary): | |||||
super(ContextualEmbedding, self).__init__(vocab) | |||||
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): | |||||
super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True): | def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True): | ||||
""" | """ | ||||
@@ -438,19 +513,17 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称, | :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称, | ||||
目前支持的ELMo包括{`en` : 英文版本的ELMo, `cn` : 中文版本的ELMo,}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载 | 目前支持的ELMo包括{`en` : 英文版本的ELMo, `cn` : 中文版本的ELMo,}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载 | ||||
:param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 | :param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 | ||||
按照这个顺序concat起来。默认为'2'。 | |||||
:param requires_grad: bool, 该层是否需要gradient. 默认为False | |||||
按照这个顺序concat起来。默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致, | |||||
初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。) | |||||
:param requires_grad: bool, 该层是否需要gradient, 默认为False. | |||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
:param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding, | :param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding, | ||||
并删除character encoder,之后将直接使用cache的embedding。默认为False。 | 并删除character encoder,之后将直接使用cache的embedding。默认为False。 | ||||
""" | """ | ||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', | |||||
layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False): | |||||
super(ElmoEmbedding, self).__init__(vocab) | |||||
layers = list(map(int, layers.split(','))) | |||||
assert len(layers) > 0, "Must choose one output" | |||||
for layer in layers: | |||||
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||||
self.layers = layers | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', layers: str='2', requires_grad: bool=False, | |||||
word_dropout=0.0, dropout=0.0, cache_word_reprs: bool=False): | |||||
super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | ||||
@@ -464,8 +537,49 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
else: | else: | ||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | ||||
if layers=='mix': | |||||
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers']+1), | |||||
requires_grad=requires_grad) | |||||
self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad) | |||||
self._get_outputs = self._get_mixed_outputs | |||||
self._embed_size = self.model.config['lstm']['projection_dim'] * 2 | |||||
else: | |||||
layers = list(map(int, layers.split(','))) | |||||
assert len(layers) > 0, "Must choose one output" | |||||
for layer in layers: | |||||
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||||
self.layers = layers | |||||
self._get_outputs = self._get_layer_outputs | |||||
self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 | |||||
self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2 | |||||
def _get_mixed_outputs(self, outputs): | |||||
# outputs: num_layers x batch_size x max_len x hidden_size | |||||
# return: batch_size x max_len x hidden_size | |||||
weights = F.softmax(self.layer_weights+1/len(outputs), dim=0).to(outputs) | |||||
outputs = torch.einsum('l,lbij->bij', weights, outputs) | |||||
return self.gamma.to(outputs)*outputs | |||||
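The einsum above is a learned scalar mix over layers: softmax the per-layer weights, take the weighted sum of the layer outputs, then scale by `gamma`. With all-zero initial weights the softmax is uniform, which is exactly mean-pooling over layers. A standalone sketch with dummy tensors:

```python
import torch
import torch.nn.functional as F

num_layers, batch, max_len, hidden = 3, 2, 5, 8
outputs = torch.randn(num_layers, batch, max_len, hidden)

layer_weights = torch.zeros(num_layers)                      # zero init, as above
gamma = torch.ones(1)
weights = F.softmax(layer_weights + 1 / num_layers, dim=0)   # uniform, i.e. mean over layers
mixed = torch.einsum('l,lbij->bij', weights, outputs)        # batch x max_len x hidden
print((gamma * mixed).shape)                                 # torch.Size([2, 5, 8])
```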
def set_mix_weights_requires_grad(self, flag=True): | |||||
""" | |||||
当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用 | |||||
该方法没有用。 | |||||
:param bool flag: 混合不同层表示的结果是否可以训练。 | |||||
:return: | |||||
""" | |||||
if hasattr(self, 'layer_weights'): | |||||
self.layer_weights.requires_grad = flag | |||||
self.gamma.requires_grad = flag | |||||
def _get_layer_outputs(self, outputs): | |||||
if len(self.layers) == 1: | |||||
outputs = outputs[self.layers[0]] | |||||
else: | |||||
outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1) | |||||
return outputs | |||||
def forward(self, words: torch.LongTensor): | def forward(self, words: torch.LongTensor): | ||||
""" | """ | ||||
@@ -476,19 +590,18 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
:param words: batch_size x max_len | :param words: batch_size x max_len | ||||
:return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers)) | :return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers)) | ||||
""" | """ | ||||
words = self.drop_word(words) | |||||
outputs = self._get_sent_reprs(words) | outputs = self._get_sent_reprs(words) | ||||
if outputs is not None: | if outputs is not None: | ||||
return outputs | |||||
return self.dropout(outputs) | |||||
outputs = self.model(words) | outputs = self.model(words) | ||||
if len(self.layers) == 1: | |||||
outputs = outputs[self.layers[0]] | |||||
else: | |||||
outputs = torch.cat([*outputs[self.layers]], dim=-1) | |||||
return outputs | |||||
outputs = self._get_outputs(outputs) | |||||
return self.dropout(outputs) | |||||
def _delete_model_weights(self): | def _delete_model_weights(self): | ||||
del self.layers, self.model | |||||
for name in ['layers', 'model', 'layer_weights', 'gamma']: | |||||
if hasattr(self, name): | |||||
delattr(self, name) | |||||
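A sketch of the new `'mix'` mode together with `set_mix_weights_requires_grad`. The `'en'` model name comes from the docstring and the weights are downloaded on first use; the `Vocabulary` helpers are assumed to be the usual fastNLP ones.

```python
import torch
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("the movie was surprisingly good .".split())

embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='mix',
                      requires_grad=False, word_dropout=0.0, dropout=0.5)
embed.set_mix_weights_requires_grad(True)  # train only the mixing weights and gamma

words = torch.LongTensor([[vocab.to_index(w) for w in "the movie was good".split()]])
print(embed(words).shape)  # batch x max_len x (2 * projection_dim)
```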
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
@@ -529,13 +642,16 @@ class BertEmbedding(ContextualEmbedding): | |||||
:param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | ||||
:param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces | :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces | ||||
中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。 | 中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。 | ||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
:param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 | :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 | ||||
会使得word embedding的结果比输入的结果长两个token。在使用 :class::StackEmbedding 可能会遇到问题。 | 会使得word embedding的结果比输入的结果长两个token。在使用 :class::StackEmbedding 可能会遇到问题。 | ||||
:param bool requires_grad: 是否需要gradient。 | :param bool requires_grad: 是否需要gradient。 | ||||
""" | """ | ||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', | def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', | ||||
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | |||||
super(BertEmbedding, self).__init__(vocab) | |||||
pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False, | |||||
include_cls_sep: bool=False): | |||||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | ||||
@@ -566,13 +682,14 @@ class BertEmbedding(ContextualEmbedding): | |||||
:param torch.LongTensor words: [batch_size, max_len] | :param torch.LongTensor words: [batch_size, max_len] | ||||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | ||||
""" | """ | ||||
words = self.drop_word(words) | |||||
outputs = self._get_sent_reprs(words) | outputs = self._get_sent_reprs(words) | ||||
if outputs is not None: | if outputs is not None: | ||||
return outputs | |||||
return self.dropout(outputs) | |||||
outputs = self.model(words) | outputs = self.model(words) | ||||
outputs = torch.cat([*outputs], dim=-1) | outputs = torch.cat([*outputs], dim=-1) | ||||
return outputs | |||||
return self.dropout(outputs) | |||||
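For completeness, a similar hedged sketch for BertEmbedding with the new arguments; the model name is the docstring default, the weights are downloaded on first use, and the `Vocabulary` calls are assumptions in the same sense as above.

```python
import torch
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("fastNLP makes embeddings easy .".split())

embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1',
                      pool_method='first',  # take the first word piece of each word
                      word_dropout=0.01, dropout=0.5, include_cls_sep=False)

words = torch.LongTensor([[vocab.to_index(w) for w in "fastNLP makes embeddings easy".split()]])
print(embed(words).shape)  # batch x max_len x (768 * number of selected layers)
```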
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
@@ -614,8 +731,8 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
""" | """ | ||||
别名::class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding` | 别名::class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding` | ||||
使用CNN生成character embedding。CNN的结果为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool | |||||
-> fc. 不同的kernel大小的filter结果是concat起来的。 | |||||
使用CNN生成character embedding。CNN的结果为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | |||||
不同的kernel大小的filter结果是concat起来的。 | |||||
Example:: | Example:: | ||||
@@ -625,23 +742,24 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
:param vocab: 词表 | :param vocab: 词表 | ||||
:param embed_size: 该word embedding的大小,默认值为50. | :param embed_size: 该word embedding的大小,默认值为50. | ||||
:param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50. | :param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50. | ||||
:param dropout: 以多大的概率drop | |||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param float dropout: 以多大的概率drop | |||||
:param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20]. | :param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20]. | ||||
:param kernel_sizes: kernel的大小. 默认值为[5, 3, 1]. | :param kernel_sizes: kernel的大小. 默认值为[5, 3, 1]. | ||||
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. | :param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. | ||||
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. | :param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. | ||||
:param min_char_freq: character的最少出现次数。默认值为2. | :param min_char_freq: character的最少出现次数。默认值为2. | ||||
""" | """ | ||||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, | |||||
filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max', | |||||
activation='relu', min_char_freq: int=2): | |||||
super(CNNCharEmbedding, self).__init__(vocab) | |||||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, | |||||
dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), | |||||
pool_method: str='max', activation='relu', min_char_freq: int=2): | |||||
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
for kernel in kernel_sizes: | for kernel in kernel_sizes: | ||||
assert kernel % 2 == 1, "Only odd kernel is allowed." | assert kernel % 2 == 1, "Only odd kernel is allowed." | ||||
assert pool_method in ('max', 'avg') | assert pool_method in ('max', 'avg') | ||||
self.dropout = nn.Dropout(dropout, inplace=True) | |||||
self.dropout = nn.Dropout(dropout) | |||||
self.pool_method = pool_method | self.pool_method = pool_method | ||||
# activation function | # activation function | ||||
if isinstance(activation, str): | if isinstance(activation, str): | ||||
@@ -691,6 +809,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
:param words: [batch_size, max_len] | :param words: [batch_size, max_len] | ||||
:return: [batch_size, max_len, embed_size] | :return: [batch_size, max_len, embed_size] | ||||
""" | """ | ||||
words = self.drop_word(words) | |||||
batch_size, max_len = words.size() | batch_size, max_len = words.size() | ||||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | ||||
word_lengths = self.word_lengths[words] # batch_size x max_len | word_lengths = self.word_lengths[words] # batch_size x max_len | ||||
@@ -699,7 +818,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
# 为1的地方为mask | # 为1的地方为mask | ||||
chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 | chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 | ||||
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | ||||
self.dropout(chars) | |||||
chars = self.dropout(chars) | |||||
reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1) | reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1) | ||||
reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M | reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M | ||||
conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) | conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) | ||||
@@ -713,7 +832,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | ||||
chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | ||||
chars = self.fc(chars) | chars = self.fc(chars) | ||||
return chars | |||||
return self.dropout(chars) | |||||
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
@@ -760,6 +879,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
:param vocab: 词表 | :param vocab: 词表 | ||||
:param embed_size: embedding的大小。默认值为50. | :param embed_size: embedding的大小。默认值为50. | ||||
:param char_emb_size: character的embedding的大小。默认值为50. | :param char_emb_size: character的embedding的大小。默认值为50. | ||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param dropout: 以多大概率drop | :param dropout: 以多大概率drop | ||||
:param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50. | :param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50. | ||||
:param pool_method: 支持'max', 'avg' | :param pool_method: 支持'max', 'avg' | ||||
@@ -767,15 +887,16 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
:param min_char_freq: character的最小出现次数。默认值为2. | :param min_char_freq: character的最小出现次数。默认值为2. | ||||
:param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。 | :param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。 | ||||
""" | """ | ||||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, hidden_size=50, | |||||
pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True): | |||||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, | |||||
dropout:float=0.5, hidden_size=50, pool_method: str='max', activation='relu', min_char_freq: int=2, | |||||
bidirectional=True): | |||||
super(LSTMCharEmbedding, self).__init__(vocab) | super(LSTMCharEmbedding, self).__init__(vocab) | ||||
assert hidden_size % 2 == 0, "Only an even hidden_size is allowed." | assert hidden_size % 2 == 0, "Only an even hidden_size is allowed." | ||||
assert pool_method in ('max', 'avg') | assert pool_method in ('max', 'avg') | ||||
self.pool_method = pool_method | self.pool_method = pool_method | ||||
self.dropout = nn.Dropout(dropout, inplace=True) | |||||
self.dropout = nn.Dropout(dropout) | |||||
# activation function | # activation function | ||||
if isinstance(activation, str): | if isinstance(activation, str): | ||||
if activation.lower() == 'relu': | if activation.lower() == 'relu': | ||||
@@ -824,6 +945,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
:param words: [batch_size, max_len] | :param words: [batch_size, max_len] | ||||
:return: [batch_size, max_len, embed_size] | :return: [batch_size, max_len, embed_size] | ||||
""" | """ | ||||
words = self.drop_word(words) | |||||
batch_size, max_len = words.size() | batch_size, max_len = words.size() | ||||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | ||||
word_lengths = self.word_lengths[words] # batch_size x max_len | word_lengths = self.word_lengths[words] # batch_size x max_len | ||||
@@ -848,7 +970,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
chars = self.fc(chars) | chars = self.fc(chars) | ||||
return chars | |||||
return self.dropout(chars) | |||||
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
@@ -887,17 +1009,21 @@ class StackEmbedding(TokenEmbedding): | |||||
:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | ||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 | |||||
被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。 | |||||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
""" | """ | ||||
def __init__(self, embeds: List[TokenEmbedding]): | |||||
def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): | |||||
vocabs = [] | vocabs = [] | ||||
for embed in embeds: | for embed in embeds: | ||||
vocabs.append(embed.get_word_vocab()) | |||||
if hasattr(embed, 'get_word_vocab'): | |||||
vocabs.append(embed.get_word_vocab()) | |||||
_vocab = vocabs[0] | _vocab = vocabs[0] | ||||
for vocab in vocabs[1:]: | for vocab in vocabs[1:]: | ||||
assert vocab == _vocab, "All embeddings should use the same word vocabulary." | |||||
assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary." | |||||
super(StackEmbedding, self).__init__(_vocab) | |||||
super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout) | |||||
assert isinstance(embeds, list) | assert isinstance(embeds, list) | ||||
for embed in embeds: | for embed in embeds: | ||||
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." | assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." | ||||
@@ -949,7 +1075,9 @@ class StackEmbedding(TokenEmbedding): | |||||
:return: 返回的shape和当前这个stack embedding中embedding的组成有关 | :return: 返回的shape和当前这个stack embedding中embedding的组成有关 | ||||
""" | """ | ||||
outputs = [] | outputs = [] | ||||
words = self.drop_word(words) | |||||
for embed in self.embeds: | for embed in self.embeds: | ||||
outputs.append(embed(words)) | outputs.append(embed(words)) | ||||
return torch.cat(outputs, dim=-1) | |||||
outputs = self.dropout(torch.cat(outputs, dim=-1)) | |||||
return outputs | |||||
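A sketch of stacking a pretrained word embedding with the character CNN shown earlier over one shared vocabulary. Per the docstring above, dropout is set only on the stack and not on its members; the model name and the `Vocabulary` helpers are assumptions taken from the docstrings.

```python
import torch
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("a tiny example sentence .".split())

word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')
char_embed = CNNCharEmbedding(vocab, embed_size=30, char_emb_size=25, dropout=0)
embed = StackEmbedding([word_embed, char_embed], word_dropout=0.01, dropout=0.3)

words = torch.LongTensor([[vocab.to_index(w) for w in "a tiny example".split()]])
print(embed(words).shape)  # batch x max_len x (50 + 30)
```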
@@ -1,7 +1,8 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"MaxPool", | "MaxPool", | ||||
"MaxPoolWithMask", | "MaxPoolWithMask", | ||||
"AvgPool" | |||||
"AvgPool", | |||||
"AvgPoolWithMask" | |||||
] | ] | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -9,7 +10,7 @@ import torch.nn as nn | |||||
class MaxPool(nn.Module): | class MaxPool(nn.Module): | ||||
""" | """ | ||||
别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.aggregator.pooling.MaxPool` | |||||
别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.encoder.pooling.MaxPool` | |||||
Max-pooling模块。 | Max-pooling模块。 | ||||
@@ -58,7 +59,7 @@ class MaxPool(nn.Module): | |||||
class MaxPoolWithMask(nn.Module): | class MaxPoolWithMask(nn.Module): | ||||
""" | """ | ||||
别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.aggregator.pooling.MaxPoolWithMask` | |||||
别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.encoder.pooling.MaxPoolWithMask` | |||||
带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 | 带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 | ||||
""" | """ | ||||
@@ -98,7 +99,7 @@ class KMaxPool(nn.Module): | |||||
class AvgPool(nn.Module): | class AvgPool(nn.Module): | ||||
""" | """ | ||||
别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.aggregator.pooling.AvgPool` | |||||
别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.encoder.pooling.AvgPool` | |||||
给定形如[batch_size, max_len, hidden_size]的输入,在max_len这一维进行avg pooling. 输出为[batch_size, hidden_size] | 给定形如[batch_size, max_len, hidden_size]的输入,在max_len这一维进行avg pooling. 输出为[batch_size, hidden_size] | ||||
""" | """ | ||||
@@ -125,7 +126,7 @@ class AvgPool(nn.Module): | |||||
class AvgPoolWithMask(nn.Module): | class AvgPoolWithMask(nn.Module): | ||||
""" | """ | ||||
别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.aggregator.pooling.AvgPoolWithMask` | |||||
别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.encoder.pooling.AvgPoolWithMask` | |||||
给定形如[batch_size, max_len, hidden_size]的输入,在max_len这一维进行avg pooling. 输出为[batch_size, hidden_size], pooling | 给定形如[batch_size, max_len, hidden_size]的输入,在max_len这一维进行avg pooling. 输出为[batch_size, hidden_size], pooling | ||||
的时候只会考虑mask为1的位置 | 的时候只会考虑mask为1的位置 |
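What the masked variants compute is ordinary pooling restricted to positions where the mask is 1. A minimal PyTorch sketch of masked average pooling, independent of the classes above:

```python
import torch

batch, max_len, hidden = 2, 5, 4
x = torch.randn(batch, max_len, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])  # 1 marks real tokens, 0 marks padding

lengths = mask.sum(dim=1, keepdim=True).float()                  # batch x 1
pooled = (x * mask.unsqueeze(-1).float()).sum(dim=1) / lengths   # batch x hidden
print(pooled.shape)                                              # torch.Size([2, 4])
```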
@@ -34,12 +34,14 @@ class StarTransformer(nn.Module): | |||||
super(StarTransformer, self).__init__() | super(StarTransformer, self).__init__() | ||||
self.iters = num_layers | self.iters = num_layers | ||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | |||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)]) | |||||
# self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) | |||||
self.emb_drop = nn.Dropout(dropout) | |||||
self.ring_att = nn.ModuleList( | self.ring_att = nn.ModuleList( | ||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||||
for _ in range(self.iters)]) | for _ in range(self.iters)]) | ||||
self.star_att = nn.ModuleList( | self.star_att = nn.ModuleList( | ||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||||
for _ in range(self.iters)]) | for _ in range(self.iters)]) | ||||
if max_len is not None: | if max_len is not None: | ||||
@@ -66,18 +68,19 @@ class StarTransformer(nn.Module): | |||||
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | ||||
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | ||||
if self.pos_emb: | |||||
if self.pos_emb and False: | |||||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | ||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | ||||
embs = embs + P | embs = embs + P | ||||
embs = norm_func(self.emb_drop, embs) | |||||
nodes = embs | nodes = embs | ||||
relay = embs.mean(2, keepdim=True) | relay = embs.mean(2, keepdim=True) | ||||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ||||
r_embs = embs.view(B, H, 1, L) | r_embs = embs.view(B, H, 1, L) | ||||
for i in range(self.iters): | for i in range(self.iters): | ||||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ||||
nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||||
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||||
#nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) | |||||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | ||||
nodes = nodes.masked_fill_(ex_mask, 0) | nodes = nodes.masked_fill_(ex_mask, 0) | ||||
@@ -3,7 +3,7 @@ __all__ = [ | |||||
] | ] | ||||
from torch import nn | from torch import nn | ||||
from ..aggregator.attention import MultiHeadAttention | |||||
from fastNLP.modules.encoder.attention import MultiHeadAttention | |||||
from ..dropout import TimestepDropout | from ..dropout import TimestepDropout | ||||
@@ -8,7 +8,8 @@ import os | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from .utils import load_url | from .utils import load_url | ||||
from .processor import ModelProcessor | from .processor import ModelProcessor | ||||
from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader | |||||
from fastNLP.io.dataset_loader import _cut_long_sentence | |||||
from fastNLP.io.data_loader import ConllLoader | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from ..api.pipeline import Pipeline | from ..api.pipeline import Pipeline | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
@@ -1,110 +0,0 @@ | |||||
# Byte-compiled / optimized / DLL files | |||||
__pycache__/ | |||||
*.py[cod] | |||||
*$py.class | |||||
# C extensions | |||||
*.so | |||||
# Distribution / packaging | |||||
.Python | |||||
build/ | |||||
develop-eggs/ | |||||
dist/ | |||||
downloads/ | |||||
eggs/ | |||||
.eggs/ | |||||
lib/ | |||||
lib64/ | |||||
parts/ | |||||
sdist/ | |||||
var/ | |||||
wheels/ | |||||
*.egg-info/ | |||||
.installed.cfg | |||||
*.egg | |||||
MANIFEST | |||||
# PyInstaller | |||||
# Usually these files are written by a python script from a template | |||||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | |||||
*.manifest | |||||
*.spec | |||||
# Installer logs | |||||
pip-log.txt | |||||
pip-delete-this-directory.txt | |||||
# Unit test / coverage reports | |||||
htmlcov/ | |||||
.tox/ | |||||
.coverage | |||||
.coverage.* | |||||
.cache | |||||
nosetests.xml | |||||
coverage.xml | |||||
*.cover | |||||
.hypothesis/ | |||||
.pytest_cache/ | |||||
# Translations | |||||
*.mo | |||||
*.pot | |||||
# Django stuff: | |||||
*.log | |||||
local_settings.py | |||||
db.sqlite3 | |||||
# Flask stuff: | |||||
instance/ | |||||
.webassets-cache | |||||
# Scrapy stuff: | |||||
.scrapy | |||||
# Sphinx documentation | |||||
docs/_build/ | |||||
# PyBuilder | |||||
target/ | |||||
# Jupyter Notebook | |||||
.ipynb_checkpoints | |||||
# pyenv | |||||
.python-version | |||||
# celery beat schedule file | |||||
celerybeat-schedule | |||||
# SageMath parsed files | |||||
*.sage.py | |||||
# Environments | |||||
.env | |||||
.venv | |||||
env/ | |||||
venv/ | |||||
ENV/ | |||||
env.bak/ | |||||
venv.bak/ | |||||
# Spyder project settings | |||||
.spyderproject | |||||
.spyproject | |||||
# Rope project settings | |||||
.ropeproject | |||||
# mkdocs documentation | |||||
/site | |||||
# mypy | |||||
.mypy_cache | |||||
#custom | |||||
GoogleNews-vectors-negative300.bin/ | |||||
GoogleNews-vectors-negative300.bin.gz | |||||
models/ | |||||
*.swp |
@@ -1,77 +0,0 @@ | |||||
## Introduction | |||||
This is a PyTorch implementation of the [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882) paper. | |||||
* MR dataset, non-static model (word2vec trained by Mikolov et al. (2013) on 100 billion words of Google News) | |||||
* It can be run on both CPU and GPU | |||||
* The best accuracy is 82.61%, which is better than 81.5% in the paper | |||||
(by Jingyuan Liu @ Fudan University; Email: fdjingyuan@outlook.com. Discussion is welcome!) | |||||
## Requirements | |||||
* python 3.6 | |||||
* pytorch > 0.1 | |||||
* numpy | |||||
* gensim | |||||
## Run | |||||
STEP 1 | |||||
install packages such as gensim (the other required packages are listed above) | |||||
``` | |||||
pip install gensim | |||||
``` | |||||
STEP 2 | |||||
download the MR dataset and word2vec resources | |||||
* MR dataset: you can download the dataset from (https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz) | |||||
* word2vec: you can download the file from (https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit) | |||||
Since this file is more than 1.5 GB, it is not included in the repository. If you download the file, please remember to modify the path in the function `word_embeddings(path='./GoogleNews-vectors-negative300.bin/')`. | |||||
STEP 3 | |||||
train the model | |||||
``` | |||||
python train.py | |||||
``` | |||||
you will see information like the following printed on the screen: | |||||
``` | |||||
Epoch [1/20], Iter [100/192] Loss: 0.7008 | |||||
Test Accuracy: 71.869159 % | |||||
Epoch [2/20], Iter [100/192] Loss: 0.5957 | |||||
Test Accuracy: 75.700935 % | |||||
Epoch [3/20], Iter [100/192] Loss: 0.4934 | |||||
Test Accuracy: 78.130841 % | |||||
...... | |||||
Epoch [20/20], Iter [100/192] Loss: 0.0364 | |||||
Test Accuracy: 81.495327 % | |||||
Best Accuracy: 82.616822 % | |||||
Best Model: models/cnn.pkl | |||||
``` | |||||
## Hyperparameters | |||||
According to the paper and experiment, I set: | |||||
|Epoch|Kernel Size|dropout|learning rate|batch size| | |||||
|---|---|---|---|---| | |||||
|20|\(h,300,100\)|0.5|0.0001|50| | |||||
h = [3,4,5] | |||||
If the accuracy does not improve, the learning rate is multiplied by 0.8. | |||||
## Result | |||||
I only tried one dataset: MR. (The paper also evaluates other datasets such as SST-1, SST-2, TREC, CR, and MPQA.) | |||||
There are four models in the paper: CNN-rand, CNN-static, CNN-non-static, CNN-multichannel. | |||||
I have tried CNN-non-static: a model with pre-trained vectors from word2vec. | |||||
All words, including the unknown ones that are randomly initialized, and the pretrained vectors are fine-tuned for each task | |||||
(this variant achieves nearly the best performance and is the most difficult to implement among the four models). | |||||
|Dataset|Class Size|Best Result|Kim's Paper Result| | |||||
|---|---|---|---| | |||||
|MR|2|82.617%(CNN-non-static)|81.5%(CNN-nonstatic)| | |||||
## Reference | |||||
* [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882) | |||||
* https://github.com/Shawn1993/cnn-text-classification-pytorch | |||||
* https://github.com/junwang4/CNN-sentence-classification-pytorch-2017/blob/master/utils.py | |||||
@@ -1,136 +0,0 @@ | |||||
import codecs | |||||
import random | |||||
import re | |||||
import gensim | |||||
import numpy as np | |||||
from gensim import corpora | |||||
from torch.utils.data import Dataset | |||||
def clean_str(string): | |||||
""" | |||||
Tokenization/string cleaning for all datasets except for SST. | |||||
Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py | |||||
""" | |||||
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) | |||||
string = re.sub(r"\'s", " \'s", string) | |||||
string = re.sub(r"\'ve", " \'ve", string) | |||||
string = re.sub(r"n\'t", " n\'t", string) | |||||
string = re.sub(r"\'re", " \'re", string) | |||||
string = re.sub(r"\'d", " \'d", string) | |||||
string = re.sub(r"\'ll", " \'ll", string) | |||||
string = re.sub(r",", " , ", string) | |||||
string = re.sub(r"!", " ! ", string) | |||||
string = re.sub(r"\(", " \( ", string) | |||||
string = re.sub(r"\)", " \) ", string) | |||||
string = re.sub(r"\?", " \? ", string) | |||||
string = re.sub(r"\s{2,}", " ", string) | |||||
return string.strip() | |||||
def pad_sentences(sentence, padding_word=" <PAD/>"): | |||||
sequence_length = 64 | |||||
sent = sentence.split() | |||||
padded_sentence = sentence + padding_word * (sequence_length - len(sent)) | |||||
return padded_sentence | |||||
# data loader | |||||
class MRDataset(Dataset): | |||||
def __init__(self): | |||||
# load positive and negative sentences from files | |||||
with codecs.open("./rt-polaritydata/rt-polarity.pos", encoding='ISO-8859-1') as f: | |||||
positive_examples = list(f.readlines()) | |||||
with codecs.open("./rt-polaritydata/rt-polarity.neg", encoding='ISO-8859-1') as f: | |||||
negative_examples = list(f.readlines()) | |||||
# s.strip(): remove "\n"; clean_str(); pad_sentences() | |||||
positive_examples = [pad_sentences(clean_str(s.strip())) for s in positive_examples] | |||||
negative_examples = [pad_sentences(clean_str(s.strip())) for s in negative_examples] | |||||
self.examples = positive_examples + negative_examples | |||||
self.sentences_texts = [sample.split() for sample in self.examples] | |||||
# word dictionary | |||||
dictionary = corpora.Dictionary(self.sentences_texts) | |||||
self.word2id_dict = dictionary.token2id # transform to dict, like {"human":0, "a":1,...} | |||||
# set labels: positive is 1; negative is 0 | |||||
positive_labels = [1 for _ in positive_examples] | |||||
negative_labels = [0 for _ in negative_examples] | |||||
self.lables = positive_labels + negative_labels | |||||
examples_lables = list(zip(self.examples, self.lables)) | |||||
random.shuffle(examples_lables) | |||||
self.MRDataset_frame = examples_lables | |||||
# transform word to id | |||||
self.MRDataset_wordid = \ | |||||
[( | |||||
np.array([self.word2id_dict[word] for word in sent[0].split()], dtype=np.int64), | |||||
sent[1] | |||||
) for sent in self.MRDataset_frame] | |||||
def word_embeddings(self, path="./GoogleNews-vectors-negative300.bin/GoogleNews-vectors-negative300.bin"): | |||||
# load the pretrained Google News word2vec vectors | |||||
model = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True) | |||||
print('Please wait ... (it could take a while to load the file : {})'.format(path)) | |||||
word_dict = self.word2id_dict | |||||
embedding_weights = np.random.uniform(-0.25, 0.25, (len(self.word2id_dict), 300)) | |||||
for word in word_dict: | |||||
word_id = word_dict[word] | |||||
if word in model.wv.vocab: | |||||
embedding_weights[word_id, :] = model[word] | |||||
return embedding_weights | |||||
def __len__(self): | |||||
return len(self.MRDataset_frame) | |||||
def __getitem__(self, idx): | |||||
sample = self.MRDataset_wordid[idx] | |||||
return sample | |||||
def getsent(self, idx): | |||||
sample = self.MRDataset_wordid[idx][0] | |||||
return sample | |||||
def getlabel(self, idx): | |||||
label = self.MRDataset_wordid[idx][1] | |||||
return label | |||||
def word2id(self): | |||||
return self.word2id_dict | |||||
def id2word(self): | |||||
id2word_dict = dict([val, key] for key, val in self.word2id_dict.items()) | |||||
return id2word_dict | |||||
class train_set(Dataset): | |||||
def __init__(self, samples): | |||||
self.train_frame = samples | |||||
def __len__(self): | |||||
return len(self.train_frame) | |||||
def __getitem__(self, idx): | |||||
return self.train_frame[idx] | |||||
class test_set(Dataset): | |||||
def __init__(self, samples): | |||||
self.test_frame = samples | |||||
def __len__(self): | |||||
return len(self.test_frame) | |||||
def __getitem__(self, idx): | |||||
return self.test_frame[idx] |
@@ -1,42 +0,0 @@ | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
class CNN_text(nn.Module): | |||||
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, num_classes=2, dropout=0.5, | |||||
L2_constrain=3, | |||||
pretrained_embeddings=None): | |||||
super(CNN_text, self).__init__() | |||||
self.embedding = nn.Embedding(embed_num, embed_dim) | |||||
self.dropout = nn.Dropout(dropout) | |||||
if pretrained_embeddings is not None: | |||||
self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings)) | |||||
# the network structure | |||||
# Conv2d: input- N,C,H,W output- (50,100,62,1) | |||||
self.conv1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in kernel_h]) | |||||
self.fc1 = nn.Linear(len(kernel_h) * kernel_num, num_classes) | |||||
def max_pooling(self, x): | |||||
x = F.relu(self.conv1(x)).squeeze(3) # N,C,L - (50,100,62) | |||||
x = F.max_pool1d(x, x.size(2)).squeeze(2) | |||||
# x.size(2)=62 squeeze: (50,100,1) -> (50,100) | |||||
return x | |||||
def forward(self, x): | |||||
x = self.embedding(x) # output: (N,H,W) = (50,64,300) | |||||
x = x.unsqueeze(1) # (N,C,H,W) | |||||
x = [F.relu(conv(x)).squeeze(3) for conv in self.conv1] # [N, C, H(50,100,62),(50,100,61),(50,100,60)] | |||||
x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [N,C(50,100),(50,100),(50,100)] | |||||
x = torch.cat(x, 1) | |||||
x = self.dropout(x) | |||||
x = self.fc1(x) | |||||
return x | |||||
if __name__ == '__main__': | |||||
model = CNN_text(kernel_h=[1, 2, 3, 4], embed_num=3, embed_dim=2) | |||||
x = torch.LongTensor([[1, 2, 1, 2, 0]]) | |||||
print(model(x)) |
@@ -1,92 +0,0 @@ | |||||
import os | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import Variable | |||||
from . import dataset as dst | |||||
from .model import CNN_text | |||||
# Hyper Parameters | |||||
batch_size = 50 | |||||
learning_rate = 0.0001 | |||||
num_epochs = 20 | |||||
cuda = True | |||||
# split Dataset | |||||
dataset = dst.MRDataset() | |||||
length = len(dataset) | |||||
train_dataset = dataset[:int(0.9 * length)] | |||||
test_dataset = dataset[int(0.9 * length):] | |||||
train_dataset = dst.train_set(train_dataset) | |||||
test_dataset = dst.test_set(test_dataset) | |||||
# Data Loader | |||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, | |||||
batch_size=batch_size, | |||||
shuffle=True) | |||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, | |||||
batch_size=batch_size, | |||||
shuffle=False) | |||||
# cnn | |||||
cnn = CNN_text(embed_num=len(dataset.word2id()), pretrained_embeddings=dataset.word_embeddings()) | |||||
if cuda: | |||||
cnn.cuda() | |||||
# Loss and Optimizer | |||||
criterion = nn.CrossEntropyLoss() | |||||
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) | |||||
# train and test | |||||
best_acc = None | |||||
for epoch in range(num_epochs): | |||||
# Train the Model | |||||
cnn.train() | |||||
for i, (sents, labels) in enumerate(train_loader): | |||||
sents = Variable(sents) | |||||
labels = Variable(labels) | |||||
if cuda: | |||||
sents = sents.cuda() | |||||
labels = labels.cuda() | |||||
optimizer.zero_grad() | |||||
outputs = cnn(sents) | |||||
loss = criterion(outputs, labels) | |||||
loss.backward() | |||||
optimizer.step() | |||||
if (i + 1) % 100 == 0: | |||||
print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f' | |||||
% (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, loss.data[0])) | |||||
# Test the Model | |||||
cnn.eval() | |||||
correct = 0 | |||||
total = 0 | |||||
for sents, labels in test_loader: | |||||
sents = Variable(sents) | |||||
if cuda: | |||||
sents = sents.cuda() | |||||
labels = labels.cuda() | |||||
outputs = cnn(sents) | |||||
_, predicted = torch.max(outputs.data, 1) | |||||
total += labels.size(0) | |||||
correct += (predicted == labels).sum() | |||||
acc = 100. * correct / total | |||||
print('Test Accuracy: %f %%' % (acc)) | |||||
if best_acc is None or acc > best_acc: | |||||
best_acc = acc | |||||
if os.path.exists("models") is False: | |||||
os.makedirs("models") | |||||
torch.save(cnn.state_dict(), 'models/cnn.pkl') | |||||
else: | |||||
learning_rate = learning_rate * 0.8 | |||||
print("Best Accuracy: %f %%" % best_acc) | |||||
print("Best Model: models/cnn.pkl") |
@@ -1,21 +0,0 @@ | |||||
MIT License | |||||
Copyright (c) 2017 | |||||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||||
of this software and associated documentation files (the "Software"), to deal | |||||
in the Software without restriction, including without limitation the rights | |||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||||
copies of the Software, and to permit persons to whom the Software is | |||||
furnished to do so, subject to the following conditions: | |||||
The above copyright notice and this permission notice shall be included in all | |||||
copies or substantial portions of the Software. | |||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||||
SOFTWARE. |
@@ -1,40 +0,0 @@ | |||||
# PyTorch-Character-Aware-Neural-Language-Model | |||||
This is the PyTorch implementation of character-aware neural language model proposed in this [paper](https://arxiv.org/abs/1508.06615) by Yoon Kim. | |||||
## Requirements | |||||
The code is run and tested with **Python 3.5.2** and **PyTorch 0.3.1**. | |||||
## HyperParameters | |||||
| HyperParam | value | | |||||
| ------ | :-------| | |||||
| LSTM batch size | 20 | | |||||
| LSTM sequence length | 35 | | |||||
| LSTM hidden units | 300 | | |||||
| epochs | 35 | | |||||
| initial learning rate | 1.0 | | |||||
| character embedding dimension | 15 | | |||||
## Demo | |||||
Train the model with split train/valid/test data. | |||||
`python train.py` | |||||
The trained model will be saved in `cache/net.pkl`. | |||||
Test the model. | |||||
`python test.py` | |||||
Best result on test set: | |||||
PPl=127.2163 | |||||
cross entropy loss=4.8459 | |||||
## Acknowledgement | |||||
This implementation borrowed ideas from | |||||
https://github.com/jarfo/kchar | |||||
https://github.com/cronos123/Character-Aware-Neural-Language-Models | |||||
@@ -1,9 +0,0 @@ | |||||
PICKLE = "./save/" | |||||
def train(): | |||||
pass | |||||
if __name__ == "__main__": | |||||
train() |
@@ -1,145 +0,0 @@ | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
class Highway(nn.Module): | |||||
"""Highway network""" | |||||
def __init__(self, input_size): | |||||
super(Highway, self).__init__() | |||||
self.fc1 = nn.Linear(input_size, input_size, bias=True) | |||||
self.fc2 = nn.Linear(input_size, input_size, bias=True) | |||||
def forward(self, x): | |||||
t = F.sigmoid(self.fc1(x)) | |||||
return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x) | |||||
class charLM(nn.Module): | |||||
"""CNN + highway network + LSTM | |||||
# Input: | |||||
4D tensor with shape [batch_size, in_channel, height, width] | |||||
# Output: | |||||
2D Tensor with shape [batch_size, vocab_size] | |||||
# Arguments: | |||||
char_emb_dim: the size of each character embedding | |||||
word_emb_dim: the size of each word embedding | |||||
vocab_size: num of unique words | |||||
num_char: num of characters | |||||
use_gpu: True or False | |||||
""" | |||||
def __init__(self, char_emb_dim, word_emb_dim, | |||||
vocab_size, num_char, use_gpu): | |||||
super(charLM, self).__init__() | |||||
self.char_emb_dim = char_emb_dim | |||||
self.word_emb_dim = word_emb_dim | |||||
self.vocab_size = vocab_size | |||||
# char attention layer | |||||
self.char_embed = nn.Embedding(num_char, char_emb_dim) | |||||
# convolutions of filters with different sizes | |||||
self.convolutions = [] | |||||
# list of tuples: (the number of filter, width) | |||||
self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)] | |||||
for out_channel, filter_width in self.filter_num_width: | |||||
self.convolutions.append( | |||||
nn.Conv2d( | |||||
1, # in_channel | |||||
out_channel, # out_channel | |||||
kernel_size=(char_emb_dim, filter_width), # (height, width) | |||||
bias=True | |||||
) | |||||
) | |||||
self.highway_input_dim = sum([x for x, y in self.filter_num_width]) | |||||
self.batch_norm = nn.BatchNorm1d(self.highway_input_dim, affine=False) | |||||
# highway net | |||||
self.highway1 = Highway(self.highway_input_dim) | |||||
self.highway2 = Highway(self.highway_input_dim) | |||||
# LSTM | |||||
self.lstm_num_layers = 2 | |||||
self.lstm = nn.LSTM(input_size=self.highway_input_dim, | |||||
hidden_size=self.word_emb_dim, | |||||
num_layers=self.lstm_num_layers, | |||||
bias=True, | |||||
dropout=0.5, | |||||
batch_first=True) | |||||
# output layer | |||||
self.dropout = nn.Dropout(p=0.5) | |||||
self.linear = nn.Linear(self.word_emb_dim, self.vocab_size) | |||||
if use_gpu is True: | |||||
for x in range(len(self.convolutions)): | |||||
self.convolutions[x] = self.convolutions[x].cuda() | |||||
self.highway1 = self.highway1.cuda() | |||||
self.highway2 = self.highway2.cuda() | |||||
self.lstm = self.lstm.cuda() | |||||
self.dropout = self.dropout.cuda() | |||||
self.char_embed = self.char_embed.cuda() | |||||
self.linear = self.linear.cuda() | |||||
self.batch_norm = self.batch_norm.cuda() | |||||
def forward(self, x, hidden): | |||||
# Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2] | |||||
# Return: Variable of Tensor with shape [num_words, len(word_dict)] | |||||
lstm_batch_size = x.size()[0] | |||||
lstm_seq_len = x.size()[1] | |||||
x = x.contiguous().view(-1, x.size()[2]) | |||||
# [num_seq*seq_len, max_word_len+2] | |||||
x = self.char_embed(x) | |||||
# [num_seq*seq_len, max_word_len+2, char_emb_dim] | |||||
x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3) | |||||
# [num_seq*seq_len, 1, max_word_len+2, char_emb_dim] | |||||
x = self.conv_layers(x) | |||||
# [num_seq*seq_len, total_num_filters] | |||||
x = self.batch_norm(x) | |||||
# [num_seq*seq_len, total_num_filters] | |||||
x = self.highway1(x) | |||||
x = self.highway2(x) | |||||
# [num_seq*seq_len, total_num_filters] | |||||
x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1) | |||||
# [num_seq, seq_len, total_num_filters] | |||||
x, hidden = self.lstm(x, hidden) | |||||
# [seq_len, num_seq, hidden_size] | |||||
x = self.dropout(x) | |||||
# [seq_len, num_seq, hidden_size] | |||||
x = x.contiguous().view(lstm_batch_size * lstm_seq_len, -1) | |||||
# [num_seq*seq_len, hidden_size] | |||||
x = self.linear(x) | |||||
# [num_seq*seq_len, vocab_size] | |||||
return x, hidden | |||||
def conv_layers(self, x): | |||||
chosen_list = list() | |||||
for conv in self.convolutions: | |||||
feature_map = F.tanh(conv(x)) | |||||
# (batch_size, out_channel, 1, max_word_len-width+1) | |||||
chosen = torch.max(feature_map, 3)[0] | |||||
# (batch_size, out_channel, 1) | |||||
chosen = chosen.squeeze() | |||||
# (batch_size, out_channel) | |||||
chosen_list.append(chosen) | |||||
# (batch_size, total_num_filers) | |||||
return torch.cat(chosen_list, 1) |
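A minimal forward-pass sketch of `charLM` on CPU. All sizes below are made-up toy values for illustration (the real settings live in `train.py`), and the `model.py` filename is an assumption, since file names are not shown in this diff:

```python
import torch

from model import charLM  # assumes the file above is saved as model.py

# toy sizes: 2 sequences of 3 words, each word padded to 8 chars (+ BOW/EOW markers)
num_seq, seq_len, max_word_len = 2, 3, 8
char_emb_dim, word_emb_dim, vocab_size, num_char = 15, 32, 100, 50

net = charLM(char_emb_dim, word_emb_dim, vocab_size, num_char, use_gpu=False)

x = torch.randint(0, num_char, (num_seq, seq_len, max_word_len + 2)).long()
hidden = (torch.zeros(2, num_seq, word_emb_dim),  # 2 LSTM layers
          torch.zeros(2, num_seq, word_emb_dim))

out, hidden = net(x, hidden)
print(out.size())  # [num_seq * seq_len, vocab_size] -> torch.Size([6, 100])
```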
@@ -1,117 +0,0 @@ | |||||
import os
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

from utilities import *


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


def test(net, data, opt):
    net.eval()

    test_input = torch.from_numpy(data.test_input)
    test_label = torch.from_numpy(data.test_label)

    num_seq = test_input.size()[0] // opt.lstm_seq_len
    test_input = test_input[:num_seq * opt.lstm_seq_len, :]
    # [num_seq, seq_len, max_word_len+2]
    test_input = test_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

    criterion = nn.CrossEntropyLoss()

    loss_list = []
    num_hits = 0
    total = 0
    iterations = test_input.size()[0] // opt.lstm_batch_size

    # batch_generator comes from utilities.py (not included in this diff);
    # a minimal sketch of it follows after this file
    test_generator = batch_generator(test_input, opt.lstm_batch_size)
    label_generator = batch_generator(test_label, opt.lstm_batch_size * opt.lstm_seq_len)

    hidden = (to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)),
              to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)))

    add_loss = 0.0
    for t in range(iterations):
        batch_input = test_generator.__next__()
        batch_label = label_generator.__next__()

        net.zero_grad()
        hidden = [state.detach() for state in hidden]
        test_output, hidden = net(to_var(batch_input), hidden)

        test_loss = criterion(test_output, to_var(batch_label)).data
        loss_list.append(test_loss)
        add_loss += test_loss

    print("Test Loss={0:.4f}".format(float(add_loss) / iterations))
    print("Test PPL={0:.4f}".format(float(np.exp(add_loss / iterations))))


#############################################################

if __name__ == "__main__":

    word_embed_dim = 300
    char_embedding_dim = 15

    if not os.path.exists("cache/prep.pt"):
        print("Cannot find prep.pt; run train.py first to create it.")

    objects = torch.load("cache/prep.pt")
    word_dict = objects["word_dict"]
    char_dict = objects["char_dict"]
    reverse_word_dict = objects["reverse_word_dict"]
    max_word_len = objects["max_word_len"]
    num_words = len(word_dict)

    print("word/char dictionary built. Start making inputs.")

    if not os.path.exists("cache/data_sets.pt"):
        test_text = read_data("./test.txt")
        test_set = np.array(text2vec(test_text, char_dict, max_word_len))

        # Labels are next-word indices in word_dict with the same length as inputs
        test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])

        category = {"test": test_set, "tlabel": test_label}
        torch.save(category, "cache/data_sets.pt")
    else:
        data_sets = torch.load("cache/data_sets.pt")
        test_set = data_sets["test"]
        test_label = data_sets["tlabel"]
        train_set = data_sets["tdata"]
        train_label = data_sets["trlabel"]

    DataTuple = namedtuple("DataTuple", "test_input test_label train_input train_label")
    data = DataTuple(test_input=test_set,
                     test_label=test_label, train_label=train_label, train_input=train_set)

    print("Loaded data sets. Start building network.")

    USE_GPU = True
    cnn_batch_size = 700
    lstm_seq_len = 35
    lstm_batch_size = 20

    net = torch.load("cache/net.pkl")

    Options = namedtuple("Options", ["cnn_batch_size", "lstm_seq_len",
                                     "max_word_len", "lstm_batch_size", "word_embed_dim"])
    opt = Options(cnn_batch_size=lstm_seq_len * lstm_batch_size,
                  lstm_seq_len=lstm_seq_len,
                  max_word_len=max_word_len,
                  lstm_batch_size=lstm_batch_size,
                  word_embed_dim=word_embed_dim)

    print("Network built. Start testing.")
    test(net, data, opt)
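The helpers imported from `utilities` (`read_data`, `text2vec`, `create_word_char_dict`, `batch_generator`, …) are not part of this diff. For reference, a minimal `batch_generator` consistent with how it is called above might look like the sketch below; the real implementation may differ:

```python
def batch_generator(data, batch_size):
    # Yield consecutive, non-overlapping slices of size `batch_size`
    # along the first dimension of `data` (a tensor or numpy array).
    for start in range(0, len(data) - batch_size + 1, batch_size):
        yield data[start:start + batch_size]
```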
@@ -1,320 +0,0 @@ | |||||
no it was n't black monday | |||||
but while the new york stock exchange did n't fall apart friday as the dow jones industrial average plunged N points most of it in the final hour it barely managed to stay this side of chaos | |||||
some circuit breakers installed after the october N crash failed their first test traders say unable to cool the selling panic in both stocks and futures | |||||
the N stock specialist firms on the big board floor the buyers and sellers of last resort who were criticized after the N crash once again could n't handle the selling pressure | |||||
big investment banks refused to step up to the plate to support the beleaguered floor traders by buying big blocks of stock traders say | |||||
heavy selling of standard & poor 's 500-stock index futures in chicago <unk> beat stocks downward | |||||
seven big board stocks ual amr bankamerica walt disney capital cities\/abc philip morris and pacific telesis group stopped trading and never resumed | |||||
the <unk> has already begun | |||||
the equity market was <unk> | |||||
once again the specialists were not able to handle the imbalances on the floor of the new york stock exchange said christopher <unk> senior vice president at <unk> securities corp | |||||
<unk> james <unk> chairman of specialists henderson brothers inc. it is easy to say the specialist is n't doing his job | |||||
when the dollar is in a <unk> even central banks ca n't stop it | |||||
speculators are calling for a degree of liquidity that is not there in the market | |||||
many money managers and some traders had already left their offices early friday afternoon on a warm autumn day because the stock market was so quiet | |||||
then in a <unk> plunge the dow jones industrials in barely an hour surrendered about a third of their gains this year <unk> up a 190.58-point or N N loss on the day in <unk> trading volume | |||||
<unk> trading accelerated to N million shares a record for the big board | |||||
at the end of the day N million shares were traded | |||||
the dow jones industrials closed at N | |||||
the dow 's decline was second in point terms only to the <unk> black monday crash that occurred oct. N N | |||||
in percentage terms however the dow 's dive was the <unk> ever and the sharpest since the market fell N or N N a week after black monday | |||||
the dow fell N N on black monday | |||||
shares of ual the parent of united airlines were extremely active all day friday reacting to news and rumors about the proposed $ N billion buy-out of the airline by an <unk> group | |||||
wall street 's takeover-stock speculators or risk arbitragers had placed unusually large bets that a takeover would succeed and ual stock would rise | |||||
at N p.m. edt came the <unk> news the big board was <unk> trading in ual pending news | |||||
on the exchange floor as soon as ual stopped trading we <unk> for a panic said one top floor trader | |||||
several traders could be seen shaking their heads when the news <unk> | |||||
for weeks the market had been nervous about takeovers after campeau corp. 's cash crunch spurred concern about the prospects for future highly leveraged takeovers | |||||
and N minutes after the ual trading halt came news that the ual group could n't get financing for its bid | |||||
at this point the dow was down about N points | |||||
the market <unk> | |||||
arbitragers could n't dump their ual stock but they rid themselves of nearly every rumor stock they had | |||||
for example their selling caused trading halts to be declared in usair group which closed down N N to N N delta air lines which fell N N to N N and <unk> industries which sank N to N N | |||||
these stocks eventually reopened | |||||
but as panic spread speculators began to sell blue-chip stocks such as philip morris and international business machines to offset their losses | |||||
when trading was halted in philip morris the stock was trading at N down N N while ibm closed N N lower at N | |||||
selling <unk> because of waves of automatic stop-loss orders which are triggered by computer when prices fall to certain levels | |||||
most of the stock selling pressure came from wall street professionals including computer-guided program traders | |||||
traders said most of their major institutional investors on the other hand sat tight | |||||
now at N one of the market 's post-crash reforms took hold as the s&p N futures contract had plunged N points equivalent to around a <unk> drop in the dow industrials | |||||
under an agreement signed by the big board and the chicago mercantile exchange trading was temporarily halted in chicago | |||||
after the trading halt in the s&p N pit in chicago waves of selling continued to hit stocks themselves on the big board and specialists continued to <unk> prices down | |||||
as a result the link between the futures and stock markets <unk> apart | |||||
without the <unk> of stock-index futures the barometer of where traders think the overall stock market is headed many traders were afraid to trust stock prices quoted on the big board | |||||
the futures halt was even <unk> by big board floor traders | |||||
it <unk> things up said one major specialist | |||||
this confusion effectively halted one form of program trading stock index arbitrage that closely links the futures and stock markets and has been blamed by some for the market 's big swings | |||||
in a stock-index arbitrage sell program traders buy or sell big baskets of stocks and offset the trade in futures to lock in a price difference | |||||
when the airline information came through it <unk> every model we had for the marketplace said a managing director at one of the largest program-trading firms | |||||
we did n't even get a chance to do the programs we wanted to do | |||||
but stocks kept falling | |||||
the dow industrials were down N points at N p.m. before the <unk> halt | |||||
at N p.m. at the end of the cooling off period the average was down N points | |||||
meanwhile during the the s&p trading halt s&p futures sell orders began <unk> up while stocks in new york kept falling sharply | |||||
big board chairman john j. phelan said yesterday the circuit breaker worked well <unk> | |||||
i just think it 's <unk> at this point to get into a debate if index arbitrage would have helped or hurt things | |||||
under another post-crash system big board president richard <unk> mr. phelan was flying to <unk> as the market was falling was talking on an <unk> hot line to the other exchanges the securities and exchange commission and the federal reserve board | |||||
he <unk> out at a high-tech <unk> center on the floor of the big board where he could watch <unk> on prices and pending stock orders | |||||
at about N p.m. edt s&p futures resumed trading and for a brief time the futures and stock markets started to come back in line | |||||
buyers stepped in to the futures pit | |||||
but the <unk> of s&p futures sell orders weighed on the market and the link with stocks began to fray again | |||||
at about N the s&p market <unk> to still another limit of N points down and trading was locked again | |||||
futures traders say the s&p was <unk> that the dow could fall as much as N points | |||||
during this time small investors began ringing their brokers wondering whether another crash had begun | |||||
at prudential-bache securities inc. which is trying to cater to small investors some <unk> brokers thought this would be the final <unk> | |||||
that 's when george l. ball chairman of the prudential insurance co. of america unit took to the internal <unk> system to declare that the plunge was only mechanical | |||||
i have a <unk> that this particular decline today is something more <unk> about less | |||||
it would be my <unk> to advise clients not to sell to look for an opportunity to buy mr. ball told the brokers | |||||
at merrill lynch & co. the nation 's biggest brokerage firm a news release was prepared <unk> merrill lynch comments on market drop | |||||
the release cautioned that there are significant differences between the current environment and that of october N and that there are still attractive investment opportunities in the stock market | |||||
however jeffrey b. lane president of shearson lehman hutton inc. said that friday 's plunge is going to set back relations with customers because it <unk> the concern of volatility | |||||
and i think a lot of people will <unk> on program trading | |||||
it 's going to bring the debate right back to the <unk> | |||||
as the dow average ground to its final N loss friday the s&p pit stayed locked at its <unk> trading limit | |||||
jeffrey <unk> of program trader <unk> investment group said N s&p contracts were for sale on the close the equivalent of $ N million in stock | |||||
but there were no buyers | |||||
while friday 's debacle involved mainly professional traders rather than investors it left the market vulnerable to continued selling this morning traders said | |||||
stock-index futures contracts settled at much lower prices than indexes of the stock market itself | |||||
at those levels stocks are set up to be <unk> by index arbitragers who lock in profits by buying futures when futures prices fall and simultaneously sell off stocks | |||||
but nobody knows at what level the futures and stocks will open today | |||||
the <unk> between the stock and futures markets friday will undoubtedly cause renewed debate about whether wall street is properly prepared for another crash situation | |||||
the big board 's mr. <unk> said our <unk> performance was good | |||||
but the exchange will look at the performance of all specialists in all stocks | |||||
obviously we 'll take a close look at any situation in which we think the <unk> obligations were n't met he said | |||||
see related story fed ready to <unk> big funds wsj oct. N N | |||||
but specialists complain privately that just as in the N crash the <unk> firms big investment banks that support the market by trading big blocks of stock stayed on the sidelines during friday 's <unk> | |||||
mr. phelan said it will take another day or two to analyze who was buying and selling friday | |||||
concerning your sept. N page-one article on prince charles and the <unk> it 's a few hundred years since england has been a kingdom | |||||
it 's now the united kingdom of great britain and northern ireland <unk> <unk> northern ireland scotland and oh yes england too | |||||
just thought you 'd like to know | |||||
george <unk> | |||||
ports of call inc. reached agreements to sell its remaining seven aircraft to buyers that were n't disclosed | |||||
the agreements bring to a total of nine the number of planes the travel company has sold this year as part of a restructuring | |||||
the company said a portion of the $ N million realized from the sales will be used to repay its bank debt and other obligations resulting from the currently suspended <unk> operations | |||||
earlier the company announced it would sell its aging fleet of boeing co. <unk> because of increasing maintenance costs | |||||
a consortium of private investors operating as <unk> funding co. said it has made a $ N million cash bid for most of l.j. hooker corp. 's real-estate and <unk> holdings | |||||
the $ N million bid includes the assumption of an estimated $ N million in secured liabilities on those properties according to those making the bid | |||||
the group is led by jay <unk> chief executive officer of <unk> investment corp. in <unk> and a. boyd simpson chief executive of the atlanta-based simpson organization inc | |||||
mr. <unk> 's company specializes in commercial real-estate investment and claims to have $ N billion in assets mr. simpson is a developer and a former senior executive of l.j. hooker | |||||
the assets are good but they require more money and management than can be provided in l.j. hooker 's current situation said mr. simpson in an interview | |||||
hooker 's philosophy was to build and sell | |||||
we want to build and hold | |||||
l.j. hooker based in atlanta is operating with protection from its creditors under chapter N of the u.s. bankruptcy code | |||||
its parent company hooker corp. of sydney australia is currently being managed by a court-appointed provisional <unk> | |||||
sanford <unk> chief executive of l.j. hooker said yesterday in a statement that he has not yet seen the bid but that he would review it and bring it to the attention of the creditors committee | |||||
the $ N million bid is estimated by mr. simpson as representing N N of the value of all hooker real-estate holdings in the u.s. | |||||
not included in the bid are <unk> teller or b. altman & co. l.j. hooker 's department-store chains | |||||
the offer covers the massive N <unk> forest fair mall in cincinnati the N <unk> <unk> fashion mall in columbia s.c. and the N <unk> <unk> town center mall in <unk> <unk> | |||||
the <unk> mall opened sept. N with a <unk> 's <unk> as its <unk> the columbia mall is expected to open nov. N | |||||
other hooker properties included are a <unk> office tower in <unk> atlanta expected to be completed next february vacant land sites in florida and ohio l.j. hooker international the commercial real-estate brokerage company that once did business as merrill lynch commercial real estate plus other shopping centers | |||||
the consortium was put together by <unk> <unk> the london-based investment banking company that is a subsidiary of security pacific corp | |||||
we do n't anticipate any problems in raising the funding for the bid said <unk> campbell the head of mergers and acquisitions at <unk> <unk> in an interview | |||||
<unk> <unk> is acting as the consortium 's investment bankers | |||||
according to people familiar with the consortium the bid was <unk> project <unk> a reference to the film <unk> in which a <unk> played by actress <unk> <unk> is saved from a <unk> businessman by a police officer named john <unk> | |||||
l.j. hooker was a small <unk> company based in atlanta in N when mr. simpson was hired to push it into commercial development | |||||
the company grew modestly until N when a majority position in hooker corp. was acquired by australian developer george <unk> currently hooker 's chairman | |||||
mr. <unk> <unk> to launch an ambitious but <unk> $ N billion acquisition binge that included <unk> teller and b. altman & co. as well as majority positions in merksamer jewelers a sacramento chain <unk> inc. the <unk> retailer and <unk> inc. the southeast department-store chain | |||||
eventually mr. simpson and mr. <unk> had a falling out over the direction of the company and mr. simpson said he resigned in N | |||||
since then hooker corp. has sold its interest in the <unk> chain back to <unk> 's management and is currently attempting to sell the b. altman & co. chain | |||||
in addition robert <unk> chief executive of the <unk> chain is seeking funds to buy out the hooker interest in his company | |||||
the merksamer chain is currently being offered for sale by first boston corp | |||||
reached in <unk> mr. <unk> said that he believes the various hooker <unk> can become profitable with new management | |||||
these are n't mature assets but they have the potential to be so said mr. <unk> | |||||
managed properly and with a long-term outlook these can become investment-grade quality properties | |||||
canadian <unk> production totaled N metric tons in the week ended oct. N up N N from the preceding week 's total of N tons statistics canada a federal agency said | |||||
the week 's total was up N N from N tons a year earlier | |||||
the <unk> total was N tons up N N from N tons a year earlier | |||||
the treasury plans to raise $ N million in new cash thursday by selling about $ N billion of 52-week bills and <unk> $ N billion of maturing bills | |||||
the bills will be dated oct. N and will mature oct. N N | |||||
they will be available in minimum denominations of $ N | |||||
bids must be received by N p.m. edt thursday at the treasury or at federal reserve banks or branches | |||||
as small investors <unk> their mutual funds with phone calls over the weekend big fund managers said they have a strong defense against any wave of withdrawals cash | |||||
unlike the weekend before black monday the funds were n't <unk> with heavy withdrawal requests | |||||
and many fund managers have built up cash levels and say they will be buying stock this week | |||||
at fidelity investments the nation 's largest fund company telephone volume was up sharply but it was still at just half the level of the weekend preceding black monday in N | |||||
the boston firm said <unk> redemptions were running at less than one-third the level two years ago | |||||
as of yesterday afternoon the redemptions represented less than N N of the total cash position of about $ N billion of fidelity 's stock funds | |||||
two years ago there were massive redemption levels over the weekend and a lot of fear around said c. bruce <unk> who runs fidelity investments ' $ N billion <unk> fund | |||||
this feels more like a <unk> deal | |||||
people are n't <unk> | |||||
the test may come today | |||||
friday 's stock market sell-off came too late for many investors to act | |||||
some shareholders have held off until today because any fund exchanges made after friday 's close would take place at today 's closing prices | |||||
stock fund redemptions during the N debacle did n't begin to <unk> until after the market opened on black monday | |||||
but fund managers say they 're ready | |||||
many have raised cash levels which act as a buffer against steep market declines | |||||
mario <unk> for instance holds cash positions well above N N in several of his funds | |||||
windsor fund 's john <unk> and mutual series ' michael price said they had raised their cash levels to more than N N and N N respectively this year | |||||
even peter lynch manager of fidelity 's $ N billion <unk> fund the nation 's largest stock fund built up cash to N N or $ N million | |||||
one reason is that after two years of monthly net redemptions the fund posted net inflows of money from investors in august and september | |||||
i 've let the money build up mr. lynch said who added that he has had trouble finding stocks he likes | |||||
not all funds have raised cash levels of course | |||||
as a group stock funds held N N of assets in cash as of august the latest figures available from the investment company institute | |||||
that was modestly higher than the N N and N N levels in august and september of N | |||||
also persistent redemptions would force some fund managers to dump stocks to raise cash | |||||
but a strong level of investor withdrawals is much more unlikely this time around fund managers said | |||||
a major reason is that investors already have sharply scaled back their purchases of stock funds since black monday | |||||
<unk> sales have rebounded in recent months but monthly net purchases are still running at less than half N levels | |||||
there 's not nearly as much <unk> said john <unk> chairman of vanguard group inc. a big valley forge pa. fund company | |||||
many fund managers argue that now 's the time to buy | |||||
vincent <unk> manager of the $ N billion wellington fund added to his positions in bristol-myers squibb woolworth and dun & bradstreet friday | |||||
and today he 'll be looking to buy drug stocks like eli lilly pfizer and american home products whose dividend yields have been bolstered by stock declines | |||||
fidelity 's mr. lynch for his part snapped up southern co. shares friday after the stock got <unk> | |||||
if the market drops further today he said he 'll be buying blue chips such as bristol-myers and kellogg | |||||
if they <unk> stocks like that he said it presents an opportunity that is the kind of thing you dream about | |||||
major mutual-fund groups said phone calls were <unk> at twice the normal weekend pace yesterday | |||||
but most investors were seeking share prices and other information | |||||
trading volume was only modestly higher than normal | |||||
still fund groups are n't taking any chances | |||||
they hope to avoid the <unk> phone lines and other <unk> that <unk> some fund investors in october N | |||||
fidelity on saturday opened its N <unk> investor centers across the country | |||||
the centers normally are closed through the weekend | |||||
in addition east coast centers will open at N edt this morning instead of the normal N | |||||
t. rowe price associates inc. increased its staff of phone representatives to handle investor requests | |||||
the <unk> group noted that some investors moved money from stock funds to money-market funds | |||||
but most investors seemed to be in an information mode rather than in a transaction mode said steven <unk> a vice president | |||||
and vanguard among other groups said it was adding more phone representatives today to help investors get through | |||||
in an unusual move several funds moved to calm investors with <unk> on their <unk> phone lines | |||||
we view friday 's market decline as offering us a buying opportunity as long-term investors a recording at <unk> & co. funds said over the weekend | |||||
the <unk> group had a similar recording for investors | |||||
several fund managers expect a rough market this morning before prices stabilize | |||||
some early selling is likely to stem from investors and portfolio managers who want to lock in this year 's fat profits | |||||
stock funds have averaged a staggering gain of N N through september according to lipper analytical services inc | |||||
<unk> <unk> who runs shearson lehman hutton inc. 's $ N million sector analysis portfolio predicts the market will open down at least N points on technical factors and some panic selling | |||||
but she expects prices to rebound soon and is telling investors she expects the stock market wo n't decline more than N N to N N from recent highs | |||||
this is not a major crash she said | |||||
nevertheless ms. <unk> said she was <unk> with phone calls over the weekend from nervous shareholders | |||||
half of them are really scared and want to sell she said but i 'm trying to talk them out of it | |||||
she added if they all were bullish i 'd really be upset | |||||
the backdrop to friday 's slide was <unk> different from that of the october N crash fund managers argue | |||||
two years ago unlike today the dollar was weak interest rates were rising and the market was very <unk> they say | |||||
from the investors ' standpoint institutions and individuals learned a painful lesson by selling at the lows on black monday said stephen boesel manager of the $ N million t. rowe price growth and income fund | |||||
this time i do n't think we 'll get a panic reaction | |||||
newport corp. said it expects to report <unk> earnings of between N cents and N cents a share somewhat below analysts ' estimates of N cents to N cents | |||||
the maker of scientific instruments and laser parts said orders fell below expectations in recent months | |||||
a spokesman added that sales in the current quarter will about equal the <unk> quarter 's figure when newport reported net income of $ N million or N cents a share on $ N million in sales | |||||
<unk> from the strike by N machinists union members against boeing co. reached air carriers friday as america west airlines announced it will postpone its new service out of houston because of delays in receiving aircraft from the seattle jet maker | |||||
peter <unk> vice president for planning at the phoenix ariz. carrier said in an interview that the work <unk> at boeing now entering its 13th day has caused some turmoil in our scheduling and that more than N passengers who were booked to fly out of houston on america west would now be put on other airlines | |||||
mr. <unk> said boeing told america west that the N it was supposed to get this thursday would n't be delivered until nov. N the day after the airline had been planning to <unk> service at houston with four daily flights including three <unk> to phoenix and one <unk> to las vegas | |||||
now those routes are n't expected to begin until jan | |||||
boeing is also supposed to send to america west another N <unk> aircraft as well as a N by year 's end | |||||
those too are almost certain to arrive late | |||||
at this point no other america west flights including its new service at san antonio texas newark n.j. and <unk> calif. have been affected by the delays in boeing deliveries | |||||
nevertheless the company 's reaction <unk> the <unk> effect that a huge manufacturer such as boeing can have on other parts of the economy | |||||
it also is sure to help the machinists put added pressure on the company | |||||
i just do n't feel that the company can really stand or would want a prolonged <unk> tom baker president of machinists ' district N said in an interview yesterday | |||||
i do n't think their customers would like it very much | |||||
america west though is a smaller airline and therefore more affected by the delayed delivery of a single plane than many of its competitors would be | |||||
i figure that american and united probably have such a hard time counting all the planes in their fleets they might not miss one at all mr. <unk> said | |||||
indeed a random check friday did n't seem to indicate that the strike was having much of an effect on other airline operations | |||||
southwest airlines has a boeing N set for delivery at the end of this month and expects to have the plane on time | |||||
it 's so close to completion boeing 's told us there wo n't be a problem said a southwest spokesman | |||||
a spokesman for amr corp. said boeing has assured american airlines it will deliver a N on time later this month | |||||
american is preparing to take delivery of another N in early december and N more next year and is n't anticipating any changes in that timetable | |||||
in seattle a boeing spokesman explained that the company has been in constant communication with all of its customers and that it was impossible to predict what further disruptions might be triggered by the strike | |||||
meanwhile supervisors and <unk> employees have been trying to finish some N aircraft mostly N and N jumbo jets at the company 's <unk> wash. plant that were all but completed before the <unk> | |||||
as of friday four had been delivered and a fifth plane a N was supposed to be <unk> out over the weekend to air china | |||||
no date has yet been set to get back to the bargaining table | |||||
we want to make sure they know what they want before they come back said doug hammond the federal mediator who has been in contact with both sides since the strike began | |||||
the investment community for one has been anticipating a <unk> resolution | |||||
though boeing 's stock price was battered along with the rest of the market friday it actually has risen over the last two weeks on the strength of new orders | |||||
the market has taken two views that the labor situation will get settled in the short term and that things look very <unk> for boeing in the long term said howard <unk> an analyst at <unk> j. lawrence inc | |||||
boeing 's shares fell $ N friday to close at $ N in composite trading on the new york stock exchange | |||||
but mr. baker said he thinks the earliest a pact could be struck would be the end of this month <unk> that the company and union may resume negotiations as early as this week | |||||
still he said it 's possible that the strike could last considerably longer | |||||
i would n't expect an immediate resolution to anything | |||||
last week boeing chairman frank <unk> sent striking workers a letter saying that to my knowledge boeing 's offer represents the best overall three-year contract of any major u.s. industrial firm in recent history | |||||
but mr. baker called the letter and the company 's offer of a N N wage increase over the life of the pact plus bonuses very weak | |||||
he added that the company <unk> the union 's resolve and the workers ' <unk> with being forced to work many hours overtime | |||||
in separate developments talks have broken off between machinists representatives at lockheed corp. and the <unk> calif. aerospace company | |||||
the union is continuing to work through its expired contract however | |||||
it had planned a strike vote for next sunday but that has been pushed back indefinitely | |||||
united auto workers local N which represents N workers at boeing 's helicopter unit in delaware county pa. said it agreed to extend its contract on a <unk> basis with a <unk> notification to cancel while it continues bargaining | |||||
the accord expired yesterday | |||||
and boeing on friday said it received an order from <unk> <unk> for four model N <unk> <unk> valued at a total of about $ N million | |||||
the planes long range versions of the <unk> <unk> will be delivered with <unk> & <unk> <unk> engines | |||||
<unk> & <unk> is a unit of united technologies inc | |||||
<unk> <unk> is based in amsterdam | |||||
a boeing spokeswoman said a delivery date for the planes is still being worked out for a variety of reasons but not because of the strike | |||||
<unk> <unk> contributed to this article | |||||
<unk> ltd. said its utilities arm is considering building new electric power plants some valued at more than one billion canadian dollars us$ N million in great britain and elsewhere | |||||
<unk> <unk> <unk> 's senior vice president finance said its <unk> canadian utilities ltd. unit is reviewing <unk> projects in eastern canada and conventional electric power generating plants elsewhere including britain where the british government plans to allow limited competition in electrical generation from private-sector suppliers as part of its privatization program | |||||
the projects are big | |||||
they can be c$ N billion plus mr. <unk> said | |||||
but we would n't go into them alone and canadian utilities ' equity stake would be small he said | |||||
<unk> we 'd like to be the operator of the project and a modest equity investor | |||||
our long suit is our proven ability to operate power plants he said | |||||
mr. <unk> would n't offer <unk> regarding <unk> 's proposed british project but he said it would compete for customers with two huge british power generating companies that would be formed under the country 's plan to <unk> its massive water and electric utilities | |||||
britain 's government plans to raise about # N billion $ N billion from the sale of most of its giant water and electric utilities beginning next month | |||||
the planned electric utility sale scheduled for next year is alone expected to raise # N billion making it the world 's largest public offering | |||||
under terms of the plan independent <unk> would be able to compete for N N of customers until N and for another N N between N and N | |||||
canadian utilities had N revenue of c$ N billion mainly from its natural gas and electric utility businesses in alberta where the company serves about N customers | |||||
there seems to be a move around the world to <unk> the generation of electricity mr. <unk> said and canadian utilities hopes to capitalize on it | |||||
this is a real thrust on our utility side he said adding that canadian utilities is also <unk> projects in <unk> countries though he would be specific | |||||
canadian utilities is n't alone in exploring power generation opportunities in britain in anticipation of the privatization program | |||||
we 're certainly looking at some power generating projects in england said bruce <unk> vice president corporate strategy and corporate planning with enron corp. houston a big natural gas producer and pipeline operator | |||||
mr. <unk> said enron is considering building <unk> power plants in the u.k. capable of producing about N <unk> of power at a cost of about $ N million to $ N million | |||||
pse inc. said it expects to report third earnings of $ N million to $ N million or N cents to N cents a share | |||||
in the year-ago quarter the designer and operator of <unk> and waste heat recovery plants had net income of $ N or four cents a share on revenue of about $ N million | |||||
the company said the improvement is related to additional <unk> facilities that have been put into operation | |||||
<unk> <unk> flights are $ N to paris and $ N to london | |||||
in a centennial journal article oct. N the fares were reversed | |||||
diamond <unk> offshore partners said it had discovered gas offshore louisiana | |||||
the well <unk> at a rate of N million cubic feet of gas a day through a N <unk> opening at <unk> between N and N feet | |||||
diamond <unk> is the operator with a N N interest in the well | |||||
diamond <unk> offshore 's stock rose N cents friday to close at $ N in new york stock exchange composite trading | |||||
<unk> & broad home corp. said it formed a $ N million limited partnership subsidiary to buy land in california suitable for residential development | |||||
the partnership <unk> & broad land development venture limited partnership is a N joint venture with a trust created by institutional clients of <unk> advisory corp. a unit of <unk> financial corp. a real estate advisory management and development company with offices in chicago and beverly hills calif | |||||
<unk> & broad a home building company declined to identify the institutional investors | |||||
the land to be purchased by the joint venture has n't yet received <unk> and other approvals required for development and part of <unk> & broad 's job will be to obtain such approvals | |||||
the partnership runs the risk that it may not get the approvals for development but in return it can buy land at wholesale rather than retail prices which can result in sizable savings said bruce <unk> president and chief executive officer of <unk> & broad | |||||
there are really very few companies that have adequate capital to buy properties in a raw state for cash | |||||
typically developers option property and then once they get the administrative approvals they buy it said mr. <unk> adding that he believes the joint venture is the first of its kind | |||||
we usually operate in that conservative manner | |||||
by setting up the joint venture <unk> & broad can take the more aggressive approach of buying raw land while avoiding the negative <unk> to its own balance sheet mr. <unk> said | |||||
the company is putting up only N N of the capital although it is responsible for providing management planning and processing services to the joint venture | |||||
this is one of the best ways to assure a pipeline of land to fuel our growth at a minimum risk to our company mr. <unk> said | |||||
when the price of plastics took off in N quantum chemical corp. went along for the ride | |||||
the timing of quantum 's chief executive officer john <unk> <unk> appeared to be nothing less than inspired because he had just increased quantum 's reliance on plastics | |||||
the company <unk> much of the chemical industry as annual profit grew <unk> in two years | |||||
mr. <unk> said of the boom it 's going to last a whole lot longer than anybody thinks | |||||
but now prices have <unk> and quantum 's profit is <unk> | |||||
some securities analysts are looking for no better than break-even results from the company for the third quarter compared with year-earlier profit of $ N million or $ N a share on sales of $ N million | |||||
the stock having lost nearly a quarter of its value since sept. N closed at $ N share down $ N in new york stock exchange composite trading friday | |||||
to a degree quantum represents the new times that have arrived for producers of the so-called commodity plastics that <unk> modern life | |||||
having just passed through one of the most profitable periods in their history these producers now see their prices eroding | |||||
pricing cycles to be sure are nothing new for plastics producers | |||||
and the financial decline of some looks steep only in comparison with the <unk> period that is just behind them | |||||
we were all wonderful heroes last year says an executive at one of quantum 's competitors | |||||
now we 're at the bottom of the <unk> | |||||
at quantum which is based in new york the trouble is magnified by the company 's heavy <unk> on plastics | |||||
once known as national <unk> & chemical corp. the company <unk> the wine and spirits business and <unk> more of its resources into plastics after mr. <unk> took the chief executive 's job in N | |||||
mr. <unk> N years old declined to be interviewed for this article but he has consistently argued that over the long haul across both the <unk> and the <unk> of the plastics market quantum will <unk> through its new direction | |||||
quantum 's lot is mostly tied to polyethylene <unk> used to make garbage bags milk <unk> <unk> toys and meat packaging among other items | |||||
in the u.s. polyethylene market quantum has claimed the largest share about N N | |||||
but its competitors including dow chemical co. union carbide corp. and several oil giants have much broader business interests and so are better <unk> against price swings | |||||
when the price of polyethylene moves a mere penny a pound quantum 's annual profit <unk> by about N cents a share provided no other <unk> are changing | |||||
in recent months the price of polyethylene even more than that of other commodity plastics has taken a dive | |||||
benchmark grades which still sold for as much as N cents a pound last spring have skidded to between N cents and N cents | |||||
meanwhile the price of <unk> the chemical building block of polyethylene has n't dropped nearly so fast | |||||
that <unk> <unk> quantum badly because its own plants cover only about half of its <unk> needs | |||||
by many accounts an early hint of a price rout in the making came at the start of this year | |||||
china which had been putting in huge orders for polyethylene abruptly halted them | |||||
<unk> that excess polyethylene would soon be <unk> around the world other buyers then bet that prices had peaked and so began to draw down inventories rather than order new product | |||||
kenneth mitchell director of dow 's polyethylene business says producers were surprised to learn how much inventories had swelled throughout the distribution chain as prices <unk> up | |||||
people were even <unk> bags he says | |||||
now producers hope prices have hit bottom | |||||
they recently announced increases of a few cents a pound to take effect in the next several weeks | |||||
no one knows however whether the new posted prices will stick once producers and customers start to <unk> | |||||
one <unk> is george <unk> a <unk> analyst at oppenheimer & co. and a bear on plastics stocks | |||||
noting others ' estimates of when price increases can be sustained he remarks some say october | |||||
some say november | |||||
i say N | |||||
he argues that efforts to firm up prices will be undermined by producers ' plans to expand production capacity | |||||
a quick turnaround is crucial to quantum because its cash requirements remain heavy | |||||
the company is trying to carry out a three-year $ N billion <unk> program started this year | |||||
at the same time its annual payments on long-term debt will more than double from a year ago to about $ N million largely because of debt taken on to pay a $ <unk> special dividend earlier this year | |||||
quantum described the payout at the time as a way for it to share the <unk> with its holders because its stock price was n't reflecting the huge profit increases | |||||
some analysts saw the payment as an effort also to <unk> takeover speculation | |||||
whether a cash crunch might eventually force the company to cut its quarterly dividend raised N N to N cents a share only a year ago has become a topic of intense speculation on wall street since mr. <unk> <unk> dividend questions in a sept. N meeting with analysts | |||||
some viewed his response that company directors review the dividend regularly as nothing more than the standard line from executives |
@@ -1,263 +0,0 @@ | |||||
import os
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from model import charLM
from test import test
from utilities import *


def preprocess():
    word_dict, char_dict = create_word_char_dict("charlm.txt", "train.txt", "test.txt")
    num_words = len(word_dict)
    num_char = len(char_dict)

    # special character ids for beginning/end of word and padding
    char_dict["BOW"] = num_char + 1
    char_dict["EOW"] = num_char + 2
    char_dict["PAD"] = 0

    # dict of (int, string)
    reverse_word_dict = {value: key for key, value in word_dict.items()}
    max_word_len = max([len(word) for word in word_dict])

    objects = {
        "word_dict": word_dict,
        "char_dict": char_dict,
        "reverse_word_dict": reverse_word_dict,
        "max_word_len": max_word_len
    }
    torch.save(objects, "cache/prep.pt")
    print("Preprocess done.")


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


def train(net, data, opt):
    """
    :param net: the PyTorch model (charLM)
    :param data: namedtuple of numpy arrays (DataTuple)
    :param opt: named tuple of hyperparameters (Options)
    1. random seed
    2. define local input
    3. training setting: learning rate, loss, etc.
    4. main loop over epochs
    5. batchify
    6. validation
    7. save models
    """
    torch.manual_seed(1024)

    train_input = torch.from_numpy(data.train_input)
    train_label = torch.from_numpy(data.train_label)
    valid_input = torch.from_numpy(data.valid_input)
    valid_label = torch.from_numpy(data.valid_label)

    # [num_seq, seq_len, max_word_len+2]
    num_seq = train_input.size()[0] // opt.lstm_seq_len
    train_input = train_input[:num_seq * opt.lstm_seq_len, :]
    train_input = train_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

    num_seq = valid_input.size()[0] // opt.lstm_seq_len
    valid_input = valid_input[:num_seq * opt.lstm_seq_len, :]
    valid_input = valid_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

    num_epoch = opt.epochs
    num_iter_per_epoch = train_input.size()[0] // opt.lstm_batch_size

    learning_rate = opt.init_lr
    old_PPL = 100000
    best_PPL = 100000

    # CrossEntropyLoss applies log-softmax + NLL internally
    criterion = nn.CrossEntropyLoss()

    # word_emb_dim == hidden_size (number of LSTM hidden units)
    hidden = (to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)),
              to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)))

    for epoch in range(num_epoch):

        ################ Validation ####################
        net.eval()
        loss_batch = []
        PPL_batch = []
        iterations = valid_input.size()[0] // opt.lstm_batch_size

        valid_generator = batch_generator(valid_input, opt.lstm_batch_size)
        vlabel_generator = batch_generator(valid_label, opt.lstm_batch_size * opt.lstm_seq_len)

        for t in range(iterations):
            batch_input = valid_generator.__next__()
            batch_label = vlabel_generator.__next__()

            hidden = [state.detach() for state in hidden]
            valid_output, hidden = net(to_var(batch_input), hidden)

            length = valid_output.size()[0]

            # [num_sample-1, len(word_dict)] vs [num_sample-1]
            valid_loss = criterion(valid_output, to_var(batch_label))

            PPL = torch.exp(valid_loss.data)

            loss_batch.append(float(valid_loss))
            PPL_batch.append(float(PPL))

        PPL = np.mean(PPL_batch)
        print("[epoch {}] valid PPL={}".format(epoch, PPL))
        print("valid loss={}".format(np.mean(loss_batch)))
        print("PPL decrease={}".format(float(old_PPL - PPL)))

        # Preserve the best model
        if best_PPL > PPL:
            best_PPL = PPL
            torch.save(net.state_dict(), "cache/models.pt")
            torch.save(net, "cache/net.pkl")

        # Adjust the learning rate: halve it if validation PPL improved by <= 1.0
        if float(old_PPL - PPL) <= 1.0:
            learning_rate /= 2
            print("halved lr:{}".format(learning_rate))

        old_PPL = PPL

        ##################################################
        #################### Training ####################
        net.train()
        optimizer = optim.SGD(net.parameters(),
                              lr=learning_rate,
                              momentum=0.85)

        # split the first dim
        input_generator = batch_generator(train_input, opt.lstm_batch_size)
        label_generator = batch_generator(train_label, opt.lstm_batch_size * opt.lstm_seq_len)

        for t in range(num_iter_per_epoch):
            batch_input = input_generator.__next__()
            batch_label = label_generator.__next__()

            # detach hidden state of LSTM from last batch
            hidden = [state.detach() for state in hidden]

            output, hidden = net(to_var(batch_input), hidden)
            # [num_word, vocab_size]

            loss = criterion(output, to_var(batch_label))

            net.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(net.parameters(), 5, norm_type=2)
            optimizer.step()

            if (t + 1) % 100 == 0:
                print("[epoch {} step {}] train loss={}, Perplexity={}".format(epoch + 1,
                                                                               t + 1, float(loss.data),
                                                                               float(np.exp(loss.data))))

    torch.save(net.state_dict(), "cache/models.pt")
    print("Training finished.")


################################################################

if __name__ == "__main__":

    word_embed_dim = 300
    char_embedding_dim = 15

    if not os.path.exists("cache/prep.pt"):
        preprocess()

    objects = torch.load("cache/prep.pt")
    word_dict = objects["word_dict"]
    char_dict = objects["char_dict"]
    reverse_word_dict = objects["reverse_word_dict"]
    max_word_len = objects["max_word_len"]
    num_words = len(word_dict)

    print("word/char dictionary built. Start making inputs.")

    if not os.path.exists("cache/data_sets.pt"):

        train_text = read_data("./train.txt")
        valid_text = read_data("./charlm.txt")
        test_text = read_data("./test.txt")

        train_set = np.array(text2vec(train_text, char_dict, max_word_len))
        valid_set = np.array(text2vec(valid_text, char_dict, max_word_len))
        test_set = np.array(text2vec(test_text, char_dict, max_word_len))

        # Labels are next-word indices in word_dict with the same length as inputs
        train_label = np.array([word_dict[w] for w in train_text[1:]] + [word_dict[train_text[-1]]])
        valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]])
        test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])

        category = {"tdata": train_set, "vdata": valid_set, "test": test_set,
                    "trlabel": train_label, "vlabel": valid_label, "tlabel": test_label}
        torch.save(category, "cache/data_sets.pt")
    else:
        data_sets = torch.load("cache/data_sets.pt")
        train_set = data_sets["tdata"]
        valid_set = data_sets["vdata"]
        test_set = data_sets["test"]
        train_label = data_sets["trlabel"]
        valid_label = data_sets["vlabel"]
        test_label = data_sets["tlabel"]

    DataTuple = namedtuple("DataTuple",
                           "train_input train_label valid_input valid_label test_input test_label")
    data = DataTuple(train_input=train_set,
                     train_label=train_label,
                     valid_input=valid_set,
                     valid_label=valid_label,
                     test_input=test_set,
                     test_label=test_label)

    print("Loaded data sets. Start building network.")

    USE_GPU = True
    cnn_batch_size = 700
    lstm_seq_len = 35
    lstm_batch_size = 20
    # cnn_batch_size == lstm_seq_len * lstm_batch_size

    net = charLM(char_embedding_dim,
                 word_embed_dim,
                 num_words,
                 len(char_dict),
                 use_gpu=USE_GPU)

    for param in net.parameters():
        nn.init.uniform(param.data, -0.05, 0.05)

    Options = namedtuple("Options", [
        "cnn_batch_size", "init_lr", "lstm_seq_len",
        "max_word_len", "lstm_batch_size", "epochs",
        "word_embed_dim"])
    opt = Options(cnn_batch_size=lstm_seq_len * lstm_batch_size,
                  init_lr=1.0,
                  lstm_seq_len=lstm_seq_len,
                  max_word_len=max_word_len,
                  lstm_batch_size=lstm_batch_size,
                  epochs=35,
                  word_embed_dim=word_embed_dim)

    print("Network built. Start training.")

    # You can stop training at any time with Ctrl+C
    try:
        train(net, data, opt)
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    torch.save(net, "cache/net.pkl")
    print("save net")

    test(net, data, opt)
@@ -1,360 +0,0 @@ | |||||
aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter | |||||
pierre <unk> N years old will join the board as a nonexecutive director nov. N | |||||
mr. <unk> is chairman of <unk> n.v. the dutch publishing group | |||||
rudolph <unk> N years old and former chairman of consolidated gold fields plc was named a nonexecutive director of this british industrial conglomerate | |||||
a form of asbestos once used to make kent cigarette filters has caused a high percentage of cancer deaths among a group of workers exposed to it more than N years ago researchers reported | |||||
the asbestos fiber <unk> is unusually <unk> once it enters the <unk> with even brief exposures to it causing symptoms that show up decades later researchers said | |||||
<unk> inc. the unit of new york-based <unk> corp. that makes kent cigarettes stopped using <unk> in its <unk> cigarette filters in N | |||||
although preliminary findings were reported more than a year ago the latest results appear in today 's new england journal of medicine a forum likely to bring new attention to the problem | |||||
a <unk> <unk> said this is an old story | |||||
we 're talking about years ago before anyone heard of asbestos having any questionable properties | |||||
there is no asbestos in our products now | |||||
neither <unk> nor the researchers who studied the workers were aware of any research on smokers of the kent cigarettes | |||||
we have no useful information on whether users are at risk said james a. <unk> of boston 's <unk> cancer institute | |||||
dr. <unk> led a team of researchers from the national cancer institute and the medical schools of harvard university and boston university | |||||
the <unk> spokeswoman said asbestos was used in very modest amounts in making paper for the filters in the early 1950s and replaced with a different type of <unk> in N | |||||
from N to N N billion kent cigarettes with the filters were sold the company said | |||||
among N men who worked closely with the substance N have died more than three times the expected number | |||||
four of the five surviving workers have <unk> diseases including three with recently <unk> cancer | |||||
the total of N deaths from malignant <unk> lung cancer and <unk> was far higher than expected the researchers said | |||||
the <unk> rate is a striking finding among those of us who study <unk> diseases said dr. <unk> | |||||
the percentage of lung cancer deaths among the workers at the west <unk> mass. paper factory appears to be the highest for any asbestos workers studied in western industrialized countries he said | |||||
the plant which is owned by <unk> & <unk> co. was under contract with <unk> to make the cigarette filters | |||||
the finding probably will support those who argue that the u.s. should regulate the class of asbestos including <unk> more <unk> than the common kind of asbestos <unk> found in most schools and other buildings dr. <unk> said | |||||
the u.s. is one of the few industrialized nations that does n't have a higher standard of regulation for the smooth <unk> fibers such as <unk> that are classified as <unk> according to <unk> t. <unk> a professor of <unk> at the university of vermont college of medicine | |||||
more common <unk> fibers are <unk> and are more easily rejected by the body dr. <unk> explained | |||||
in july the environmental protection agency imposed a gradual ban on virtually all uses of asbestos | |||||
by N almost all remaining uses of <unk> asbestos will be outlawed | |||||
about N workers at a factory that made paper for the kent filters were exposed to asbestos in the 1950s | |||||
areas of the factory were particularly dusty where the <unk> was used | |||||
workers dumped large <unk> <unk> of the imported material into a huge <unk> poured in cotton and <unk> fibers and <unk> mixed the dry fibers in a process used to make filters | |||||
workers described clouds of blue dust that hung over parts of the factory even though <unk> fans <unk> the area | |||||
there 's no question that some of those workers and managers contracted <unk> diseases said <unk> phillips vice president of human resources for <unk> & <unk> | |||||
but you have to recognize that these events took place N years ago | |||||
it has no bearing on our work force today | |||||
yields on money-market mutual funds continued to slide amid signs that portfolio managers expect further declines in interest rates | |||||
the average seven-day compound yield of the N taxable funds tracked by <unk> 's money fund report eased a fraction of a percentage point to N N from N N for the week ended tuesday | |||||
compound yields assume reinvestment of dividends and that the current yield continues for a year | |||||
average maturity of the funds ' investments <unk> by a day to N days the longest since early august according to donoghue 's | |||||
longer maturities are thought to indicate declining interest rates because they permit portfolio managers to retain relatively higher rates for a longer period | |||||
shorter maturities are considered a sign of rising rates because portfolio managers can capture higher rates sooner | |||||
the average maturity for funds open only to institutions considered by some to be a stronger indicator because those managers watch the market closely reached a high point for the year N days | |||||
nevertheless said <unk> <unk> <unk> editor of money fund report yields may <unk> up again before they <unk> down because of recent rises in short-term interest rates | |||||
the yield on six-month treasury bills sold at monday 's auction for example rose to N N from N N | |||||
despite recent declines in yields investors continue to pour cash into money funds | |||||
assets of the N taxable funds grew by $ N billion during the latest week to $ N billion | |||||
typically money-fund yields beat comparable short-term investments because portfolio managers can vary maturities and go after the highest rates | |||||
the top money funds are currently yielding well over N N | |||||
dreyfus world-wide dollar the <unk> fund had a seven-day compound yield of N N during the latest week down from N N a week earlier | |||||
it invests heavily in dollar-denominated securities overseas and is currently <unk> management fees which boosts its yield | |||||
the average seven-day simple yield of the N funds was N N down from N N | |||||
the 30-day simple yield fell to an average N N from N N the 30-day compound yield slid to an average N N from N N | |||||
j.p. <unk> vice chairman of <unk> grace & co. which holds a N N interest in this <unk> company was elected a director | |||||
he succeeds <unk> d. <unk> formerly a <unk> grace vice chairman who resigned | |||||
<unk> grace holds three of grace energy 's seven board seats | |||||
pacific first financial corp. said shareholders approved its acquisition by royal <unk> ltd. of toronto for $ N a share or $ N million | |||||
the thrift holding company said it expects to obtain regulatory approval and complete the transaction by year-end | |||||
<unk> international inc. said its <unk> & <unk> unit completed the sale of its <unk> controls operations to <unk> s.p a. for $ N million | |||||
<unk> is an italian state-owned holding company with interests in the mechanical engineering industry | |||||
<unk> controls based in <unk> ohio makes computerized industrial controls systems | |||||
it employs N people and has annual revenue of about $ N million | |||||
the federal government suspended sales of u.s. savings bonds because congress has n't lifted the ceiling on government debt | |||||
until congress acts the government has n't any authority to issue new debt obligations of any kind the treasury said | |||||
the government 's borrowing authority dropped at midnight tuesday to $ N trillion from $ N trillion | |||||
legislation to lift the debt ceiling is <unk> in the fight over cutting capital-gains taxes | |||||
the house has voted to raise the ceiling to $ N trillion but the senate is n't expected to act until next week at the earliest | |||||
the treasury said the u.s. will default on nov. N if congress does n't act by then | |||||
clark j. <unk> was named senior vice president and general manager of this u.s. sales and marketing arm of japanese auto maker mazda motor corp | |||||
in the new position he will oversee mazda 's u.s. sales service parts and marketing operations | |||||
previously mr. <unk> N years old was general marketing manager of chrysler corp. 's chrysler division | |||||
he had been a sales and marketing executive with chrysler for N years | |||||
when it 's time for their <unk> <unk> the nation 's manufacturing <unk> typically jet off to the <unk> <unk> of resort towns like <unk> <unk> and hot springs | |||||
not this year | |||||
the national association of manufacturers settled on the <unk> capital of indianapolis for its fall board meeting | |||||
and the city decided to treat its guests more like royalty or rock stars than factory owners | |||||
the idea of course to prove to N corporate decision makers that the buckle on the <unk> belt is n't so <unk> after all that it 's a good place for a company to expand | |||||
on the receiving end of the message were officials from giants like du pont and <unk> along with lesser <unk> like <unk> steel and the valley queen <unk> factory | |||||
for <unk> the executives joined mayor william h. <unk> iii for an evening of the indianapolis <unk> <unk> and a guest <unk> victor <unk> | |||||
champagne and <unk> followed | |||||
the next morning with a police <unk> <unk> of executives and their wives <unk> to the indianapolis motor <unk> <unk> by traffic or red lights | |||||
the governor could n't make it so the <unk> governor welcomed the special guests | |||||
a buffet breakfast was held in the museum where food and drinks are banned to everyday visitors | |||||
then in the guests ' honor the <unk> <unk> out four drivers crews and even the official indianapolis N announcer for a <unk> exhibition race | |||||
after the race fortune N executives <unk> like <unk> over the cars and drivers | |||||
no <unk> the drivers pointed out they still had space on their machines for another sponsor 's name or two | |||||
back downtown the <unk> squeezed in a few meetings at the hotel before <unk> the buses again | |||||
this time it was for dinner and <unk> a block away | |||||
under the stars and <unk> of the <unk> indiana <unk> <unk> nine of the hottest chefs in town fed them indiana <unk> <unk> <unk> <unk> <unk> <unk> and <unk> <unk> with a <unk> <unk> | |||||
knowing a <unk> and free <unk> when they eat one the executives gave the chefs a standing <unk> | |||||
more than a few <unk> say the <unk> treatment <unk> them to return to a <unk> city for future meetings | |||||
but for now they 're looking forward to their winter meeting <unk> in february | |||||
south korea registered a trade deficit of $ N million in october reflecting the country 's economic <unk> according to government figures released wednesday | |||||
preliminary <unk> by the trade and industry ministry showed another trade deficit in october the fifth monthly setback this year casting a cloud on south korea 's <unk> economy | |||||
exports in october stood at $ N billion a mere N N increase from a year earlier while imports increased sharply to $ N billion up N N from last october | |||||
south korea 's economic boom which began in N stopped this year because of prolonged labor disputes trade conflicts and sluggish exports | |||||
government officials said exports at the end of the year would remain under a government target of $ N billion | |||||
despite the gloomy forecast south korea has recorded a trade surplus of $ N million so far this year | |||||
from january to october the nation 's accumulated exports increased N N from the same period last year to $ N billion | |||||
imports were at $ N billion up N N | |||||
newsweek trying to keep pace with rival time magazine announced new advertising rates for N and said it will introduce a new incentive plan for advertisers | |||||
the new ad plan from newsweek a unit of the washington post co. is the second incentive plan the magazine has offered advertisers in three years | |||||
plans that give advertisers discounts for maintaining or increasing ad spending have become permanent <unk> at the news <unk> and underscore the fierce competition between newsweek time warner inc. 's time magazine and <unk> b. <unk> 's u.s. news & world report | |||||
alan <unk> recently named newsweek president said newsweek 's ad rates would increase N N in january | |||||
a full <unk> page in newsweek will cost $ N | |||||
in mid-october time magazine lowered its guaranteed circulation rate base for N while not increasing ad page rates with a lower circulation base time 's ad rate will be effectively N N higher per subscriber a full page in time costs about $ N | |||||
u.s. news has yet to announce its N ad rates | |||||
newsweek said it will introduce the circulation credit plan which <unk> space credits to advertisers on renewal advertising | |||||
the magazine will reward with page bonuses advertisers who in N meet or exceed their N spending as long as they spent $ N in N and $ N in N | |||||
mr. <unk> said the plan is not an attempt to shore up a decline in ad pages in the first nine months of N newsweek 's ad pages totaled N a drop of N N from last year according to publishers information bureau | |||||
what matters is what advertisers are paying per page and in that department we are doing fine this fall said mr. <unk> | |||||
both newsweek and u.s. news have been gaining circulation in recent years without heavy use of electronic <unk> to subscribers such as telephones or watches | |||||
however none of the big three <unk> recorded circulation gains recently | |||||
according to audit bureau of <unk> time the largest <unk> had average circulation of N a decrease of N N | |||||
newsweek 's circulation for the first six months of N was N flat from the same period last year | |||||
u.s. news ' circulation in the same time was N down N N | |||||
new england electric system bowed out of the bidding for public service co. of new hampshire saying that the risks were too high and the potential <unk> too far in the future to justify a higher offer | |||||
the move leaves united illuminating co. and northeast utilities as the remaining outside bidders for ps of new hampshire which also has proposed an internal reorganization plan in chapter N bankruptcy proceedings under which it would remain an independent company | |||||
new england electric based in <unk> mass. had offered $ N billion to acquire ps of new hampshire well below the $ N billion value united illuminating places on its bid and the $ N billion northeast says its bid is worth | |||||
united illuminating is based in new haven conn. and northeast is based in hartford conn | |||||
ps of new hampshire <unk> n.h. values its internal reorganization plan at about $ N billion | |||||
john rowe president and chief executive officer of new england electric said the company 's return on equity could suffer if it made a higher bid and its forecasts related to ps of new hampshire such as growth in electricity demand and improved operating <unk> did n't come true | |||||
when we <unk> raising our bid the risks seemed substantial and persistent over the next five years and the rewards seemed a long way out | |||||
that got hard to take he added | |||||
mr. rowe also noted that political concerns also worried new england electric | |||||
no matter who owns ps of new hampshire after it emerges from bankruptcy proceedings its rates will be among the highest in the nation he said | |||||
that attracts attention | |||||
it was just another one of the risk factors that led to the company 's decision to withdraw from the bidding he added | |||||
wilbur ross jr. of rothschild inc. the financial adviser to the troubled company 's equity holders said the withdrawal of new england electric might speed up the reorganization process | |||||
the fact that new england proposed lower rate increases N N over seven years against around N N boosts proposed by the other two outside bidders complicated negotiations with state officials mr. ross asserted | |||||
now the field is less <unk> he added | |||||
separately the federal energy regulatory commission turned down for now a request by northeast seeking approval of its possible purchase of ps of new hampshire | |||||
northeast said it would <unk> its request and still hopes for an <unk> review by the ferc so that it could complete the purchase by next summer if its bid is the one approved by the bankruptcy court | |||||
ps of new hampshire shares closed yesterday at $ N off N cents in new york stock exchange composite trading | |||||
norman <unk> N years old and former president and chief operating officer of toys r us inc. and frederick <unk> jr. N chairman of <unk> banking corp. were elected directors of this consumer electronics and appliances retailing chain | |||||
they succeed daniel m. <unk> retired circuit city executive vice president and robert r. <unk> u.s. treasury undersecretary on the <unk> board | |||||
commonwealth edison co. was ordered to refund about $ N million to its current and former <unk> for illegal rates collected for cost overruns on a nuclear power plant | |||||
the refund was about $ N million more than previously ordered by the illinois commerce commission and trade groups said it may be the largest ever required of a state or local utility | |||||
state court judge richard curry ordered edison to make average refunds of about $ N to $ N each to edison customers who have received electric service since april N including about two million customers who have moved during that period | |||||
judge curry ordered the refunds to begin feb. N and said that he would n't <unk> any appeals or other attempts to block his order by commonwealth edison | |||||
the refund pool may not be held <unk> through another round of appeals judge curry said | |||||
commonwealth edison said it is already appealing the underlying commission order and is considering appealing judge curry 's order | |||||
the exact amount of the refund will be determined next year based on actual <unk> made until dec. N of this year | |||||
commonwealth edison said the ruling could force it to slash its N earnings by $ N a share | |||||
for N commonwealth edison reported earnings of $ N million or $ N a share | |||||
a commonwealth edison spokesman said that tracking down the two million customers whose addresses have changed during the past N N years would be an administrative nightmare | |||||
in new york stock exchange composite trading yesterday commonwealth edison closed at $ N down N cents | |||||
the $ N billion <unk> N plant near <unk> ill. was completed in N | |||||
in a disputed N ruling the commerce commission said commonwealth edison could raise its electricity rates by $ N million to pay for the plant | |||||
but state courts upheld a challenge by consumer groups to the commission 's rate increase and found the rates illegal | |||||
the illinois supreme court ordered the commission to audit commonwealth edison 's construction expenses and refund any <unk> expenses | |||||
the utility has been collecting for the plant 's construction cost from its N million customers subject to a refund since N | |||||
in august the commission ruled that between $ N million and $ N million of the plant 's construction cost was <unk> and should be <unk> plus interest | |||||
in his ruling judge curry added an additional $ N million to the commission 's calculations | |||||
last month judge curry set the interest rate on the refund at N N | |||||
commonwealth edison now faces an additional <unk> refund on its <unk> rate <unk> <unk> that the illinois appellate court has estimated at $ N million | |||||
and consumer groups hope that judge curry 's <unk> N order may set a precedent for a second nuclear rate case involving commonwealth edison 's <unk> N plant | |||||
commonwealth edison is seeking about $ N million in rate increases to pay for <unk> N | |||||
the commission is expected to rule on the <unk> N case by year end | |||||
last year commonwealth edison had to refund $ N million for poor performance of its <unk> i nuclear plant | |||||
japan 's domestic sales of cars trucks and buses in october rose N N from a year earlier to N units a record for the month the japan automobile dealers ' association said | |||||
the strong growth followed year-to-year increases of N N in august and N N in september | |||||
the monthly sales have been setting records every month since march | |||||
october sales compared with the previous month inched down N N | |||||
sales of passenger cars grew N N from a year earlier to N units | |||||
sales of medium-sized cars which benefited from price reductions arising from introduction of the consumption tax more than doubled to N units from N in october N | |||||
texas instruments japan ltd. a unit of texas instruments inc. said it opened a plant in south korea to manufacture control devices | |||||
the new plant located in <unk> about N miles from seoul will help meet increasing and diversifying demand for control products in south korea the company said | |||||
the plant will produce control devices used in motor vehicles and household appliances | |||||
the survival of spinoff cray computer corp. as a fledgling in the supercomputer business appears to depend heavily on the creativity and <unk> of its chairman and chief designer seymour cray | |||||
not only is development of the new company 's initial machine tied directly to mr. cray so is its balance sheet | |||||
documents filed with the securities and exchange commission on the pending spinoff disclosed that cray research inc. will withdraw the almost $ N million in financing it is providing the new firm if mr. cray leaves or if the <unk> project he heads is scrapped | |||||
the documents also said that although the <unk> mr. cray has been working on the project for more than six years the cray-3 machine is at least another year away from a fully operational prototype | |||||
moreover there have been no orders for the cray-3 so far though the company says it is talking with several prospects | |||||
while many of the risks were anticipated when <unk> cray research first announced the spinoff in may the <unk> it attached to the financing had n't been made public until yesterday | |||||
we did n't have much of a choice cray computer 's chief financial officer gregory <unk> said in an interview | |||||
the theory is that seymour is the chief designer of the cray-3 and without him it could not be completed | |||||
cray research did not want to fund a project that did not include seymour | |||||
the documents also said that cray computer anticipates <unk> perhaps another $ N million in financing beginning next september | |||||
but mr. <unk> called that a <unk> scenario | |||||
the filing on the details of the spinoff caused cray research stock to jump $ N yesterday to close at $ N in new york stock exchange composite trading | |||||
analysts noted yesterday that cray research 's decision to link its $ N million <unk> note to mr. cray 's presence will complicate a valuation of the new company | |||||
it has to be considered as an additional risk for the investor said gary p. <unk> of <unk> group inc. minneapolis | |||||
cray computer will be a concept stock he said | |||||
you either believe seymour can do it again or you do n't | |||||
besides the designer 's age other risk factors for mr. cray 's new company include the cray-3 's tricky <unk> chip technology | |||||
the sec documents describe those chips which are made of <unk> <unk> as being so fragile and minute they will require special <unk> handling equipment | |||||
in addition the cray-3 will contain N processors twice as many as the largest current supercomputer | |||||
cray computer also will face intense competition not only from cray research which has about N N of the world-wide supercomputer market and which is expected to roll out the <unk> machine a direct competitor with the cray-3 in N | |||||
the spinoff also will compete with international business machines corp. and japan 's big three hitachi ltd. nec corp. and fujitsu ltd | |||||
the new company said it believes there are fewer than N potential customers for <unk> priced between $ N million and $ N million presumably the cray-3 price range | |||||
under terms of the spinoff cray research stockholders are to receive one cray computer share for every two cray research shares they own in a distribution expected to occur in about two weeks | |||||
no price for the new shares has been set | |||||
instead the companies will leave it up to the marketplace to decide | |||||
cray computer has applied to trade on nasdaq | |||||
analysts calculate cray computer 's initial book value at about $ N a share | |||||
along with the note cray research is <unk> about $ N million in assets primarily those related to the cray-3 development which has been a drain on cray research 's earnings | |||||
<unk> balance sheets clearly show why cray research favored the spinoff | |||||
without the cray-3 research and development expenses the company would have been able to report a profit of $ N million for the first half of N rather than the $ N million it posted | |||||
on the other hand had it existed then cray computer would have incurred a $ N million loss | |||||
mr. cray who could n't be reached for comment will work for the new colorado springs colo. company as an independent contractor the arrangement he had with cray research | |||||
regarded as the father of the supercomputer mr. cray was paid $ N at cray research last year | |||||
at cray computer he will be paid $ N | |||||
besides messrs. cray and <unk> other senior management at the company includes neil <unk> N president and chief executive officer joseph m. <unk> N vice president engineering malcolm a. <unk> N vice president software and douglas r. <unk> N vice president hardware | |||||
all came from cray research | |||||
cray computer which currently employs N people said it expects a work force of N by the end of N | |||||
john r. stevens N years old was named senior executive vice president and chief operating officer both new positions | |||||
he will continue to report to donald <unk> president and chief executive officer | |||||
mr. stevens was executive vice president of this <unk> holding company | |||||
arthur a. hatch N was named executive vice president of the company | |||||
he was previously president of the company 's eastern edison co. unit | |||||
john d. <unk> N was named to succeed mr. hatch as president of eastern edison | |||||
previously he was vice president of eastern edison | |||||
robert p. <unk> N was named senior vice president of eastern utilities | |||||
he was previously vice president | |||||
the u.s. claiming some success in its trade <unk> removed south korea taiwan and saudi arabia from a list of countries it is closely watching for allegedly failing to honor u.s. patents <unk> and other <unk> rights | |||||
however five other countries china thailand india brazil and mexico will remain on that so-called priority watch list as a result of an interim review u.s. trade representative carla hills announced | |||||
under the new u.s. trade law those countries could face accelerated <unk> investigations and stiff trade sanctions if they do n't improve their protection of intellectual property by next spring | |||||
mrs. hills said many of the N countries that she placed under <unk> degrees of scrutiny have made genuine progress on this touchy issue | |||||
she said there is growing <unk> around the world that <unk> of <unk> rights <unk> all trading nations and particularly the creativity and <unk> of an <unk> country 's own citizens | |||||
u.s. trade negotiators argue that countries with inadequate <unk> for <unk> rights could be hurting themselves by discouraging their own scientists and authors and by <unk> u.s. high-technology firms from investing or marketing their best products there | |||||
mrs. hills <unk> south korea for creating an <unk> task force and special enforcement teams of police officers and prosecutors trained to pursue movie and book <unk> | |||||
seoul also has instituted effective <unk> procedures to aid these teams she said | |||||
taiwan has improved its standing with the u.s. by <unk> a <unk> copyright agreement <unk> its trademark law and introducing legislation to protect foreign movie producers from unauthorized <unk> of their films | |||||
that measure could <unk> taipei 's growing number of small <unk> <unk> to pay movie producers for showing their films | |||||
saudi arabia for its part has vowed to enact a copyright law compatible with international standards and to apply the law to computer software as well as to literary works mrs. hills said | |||||
these three countries are n't completely off the hook though | |||||
they will remain on a <unk> list that includes N other countries | |||||
those countries including japan italy canada greece and spain are still of some concern to the u.s. but are deemed to pose <unk> problems for american patent and copyright owners than those on the priority list | |||||
gary hoffman a washington lawyer specializing in <unk> cases said the threat of u.s. <unk> combined with a growing recognition that protecting intellectual property is in a country 's own interest prompted the improvements made by south korea taiwan and saudi arabia | |||||
what this tells us is that u.s. trade law is working he said | |||||
he said mexico could be one of the next countries to be removed from the priority list because of its efforts to craft a new patent law | |||||
mrs. hills said that the u.s. is still concerned about disturbing developments in turkey and continuing slow progress in malaysia | |||||
she did n't elaborate although earlier u.s. trade reports have complained of videocassette <unk> in malaysia and <unk> for u.s. pharmaceutical patents in turkey | |||||
the N trade act requires mrs. hills to issue another review of the performance of these countries by april N | |||||
so far mrs. hills has n't deemed any cases bad enough to merit an accelerated investigation under the so-called special N provision of the act | |||||
argentina said it will ask creditor banks to <unk> its foreign debt of $ N billion the <unk> in the developing world | |||||
the declaration by economy minister <unk> <unk> is believed to be the first time such an action has been called for by an <unk> official of such <unk> | |||||
the latin american nation has paid very little on its debt since early last year | |||||
argentina <unk> to reach a reduction of N N in the value of its external debt mr. <unk> said through his spokesman <unk> <unk> | |||||
mr. <unk> met in august with u.s. assistant treasury secretary david mulford | |||||
<unk> negotiator carlos <unk> was in washington and new york this week to meet with banks | |||||
mr. <unk> recently has said the government of president carlos <unk> who took office july N feels a significant reduction of principal and interest is the only way the debt problem may be solved | |||||
but he has not said before that the country wants half the debt <unk> | |||||
during its centennial year the wall street journal will report events of the past century that stand as milestones of american business history | |||||
three computers that changed the face of personal computing were launched in N | |||||
that year the apple ii commodore pet and tandy <unk> came to market | |||||
the computers were crude by today 's standards | |||||
apple ii owners for example had to use their television sets as screens and <unk> data on <unk> | |||||
but apple ii was a major advance from apple i which was built in a garage by stephen <unk> and steven jobs for <unk> such as the <unk> computer club | |||||
in addition the apple ii was an affordable $ N | |||||
crude as they were these early pcs triggered explosive product development in desktop models for the home and office | |||||
big mainframe computers for business had been around for years | |||||
but the new N pcs unlike earlier <unk> types such as the <unk> <unk> and <unk> had <unk> and could store about two pages of data in their memories | |||||
current pcs are more than N times faster and have memory capacity N times greater than their N counterparts | |||||
there were many pioneer pc <unk> | |||||
william gates and paul allen in N developed an early <unk> system for pcs and gates became an industry billionaire six years after ibm adapted one of these versions in N | |||||
alan f. <unk> currently chairman of seagate technology led the team that developed the disk drives for pcs | |||||
dennis <unk> and dale <unk> two atlanta engineers were <unk> of the internal <unk> that allow pcs to share data via the telephone | |||||
ibm the world leader in computers did n't offer its first pc until august N as many other companies entered the market | |||||
today pc shipments annually total some $ N billion world-wide | |||||
<unk> <unk> & co. an australian pharmaceuticals company said its <unk> inc. affiliate acquired <unk> inc. for $ N million | |||||
<unk> is a new <unk> pharmaceuticals concern that sells products under the <unk> label | |||||
<unk> said it owns N N of <unk> 's voting stock and has an agreement to acquire an additional N N | |||||
that stake together with its convertible preferred stock holdings gives <unk> the right to increase its interest to N N of <unk> 's voting stock | |||||
oil production from australia 's bass <unk> fields will be raised by N barrels a day to about N barrels with the launch of the <unk> field the first of five small fields scheduled to be brought into production before the end of N | |||||
esso australia ltd. a unit of new york-based exxon corp. and broken hill <unk> operate the fields in a joint venture | |||||
esso said the <unk> field started production tuesday | |||||
output will be gradually increased until it reaches about N barrels a day | |||||
the field has reserves of N million barrels | |||||
reserves for the five new fields total N million barrels | |||||
the <unk> and <unk> fields are expected to start producing early next year and the <unk> and <unk> fields later next year | |||||
esso said the fields were developed after the australian government decided in N to make the first N million barrels from new fields free of <unk> tax | |||||
<unk> <unk> corp. said it completed the $ N million sale of its southern optical subsidiary to a group led by the unit 's president thomas r. sloan and other managers | |||||
following the acquisition of <unk> <unk> by a buy-out group led by shearson lehman hutton earlier this year the maker of <unk> <unk> decided to <unk> itself of certain of its <unk> businesses | |||||
the sale of southern optical is a part of the program | |||||
the white house said president bush has approved duty-free treatment for imports of certain types of watches that are n't produced in significant quantities in the u.s. the virgin islands and other u.s. <unk> | |||||
the action came in response to a petition filed by <unk> inc. for changes in the u.s. <unk> system of preferences for imports from developing nations | |||||
previously watch imports were denied such duty-free treatment | |||||
<unk> had requested duty-free treatment for many types of watches covered by N different u.s. tariff <unk> | |||||
the white house said mr. bush decided to grant duty-free status for N categories but turned down such treatment for other types of watches because of the potential for material injury to watch producers located in the u.s. and the virgin islands | |||||
<unk> is a major u.s. producer and seller of watches including <unk> <unk> watches assembled in the philippines and other developing nations covered by the u.s. tariff preferences | |||||
u.s. trade officials said the philippines and thailand would be the main beneficiaries of the president 's action | |||||
imports of the types of watches that now will be eligible for duty-free treatment totaled about $ N million in N a relatively small share of the $ N billion in u.s. watch imports that year according to an aide to u.s. trade representative carla hills | |||||
magna international inc. 's chief financial officer james mcalpine resigned and its chairman frank <unk> is stepping in to help turn the <unk> manufacturer around the company said | |||||
mr. <unk> will direct an effort to reduce overhead and curb capital spending until a more satisfactory level of profit is achieved and maintained magna said | |||||
stephen <unk> currently vice president finance will succeed mr. mcalpine | |||||
an ambitious expansion has left magna with excess capacity and a heavy debt load as the automotive industry enters a downturn | |||||
the company has reported declines in operating profit in each of the past three years despite steady sales growth | |||||
magna recently cut its quarterly dividend in half and the company 's class a shares are <unk> far below their 52-week high of N canadian dollars us$ N | |||||
on the toronto stock exchange yesterday magna shares closed up N canadian cents to c$ N | |||||
mr. <unk> founder and controlling shareholder of magna resigned as chief executive officer last year to seek unsuccessfully a seat in canada 's parliament | |||||
analysts said mr. <unk> wants to resume a more influential role in running the company | |||||
they expect him to cut costs throughout the organization | |||||
the company said mr. <unk> will personally direct the restructuring <unk> by <unk> <unk> president and chief executive | |||||
neither they nor mr. mcalpine could be reached for comment | |||||
magna said mr. mcalpine resigned to pursue a consulting career with magna as one of his clients | |||||
lord <unk> <unk> chairman of english china <unk> plc was named a nonexecutive director of this british chemical company | |||||
japanese investors nearly <unk> bought up two new mortgage <unk> mutual funds totaling $ N million the u.s. federal national mortgage association said | |||||
the purchases show the strong interest of japanese investors in u.s. <unk> instruments fannie mae 's chairman david o. maxwell said at a news conference | |||||
he said more than N N of the funds were placed with japanese institutional investors | |||||
the rest went to investors from france and hong kong | |||||
earlier this year japanese investors snapped up a similar $ N million mortgage-backed securities mutual fund | |||||
that fund was put together by blackstone group a new york investment bank | |||||
the latest two funds were assembled jointly by goldman sachs & co. of the u.s. and japan 's daiwa securities co | |||||
the new seven-year funds one offering a fixed-rate return and the other with a floating-rate return linked to the london interbank offered rate offer two key advantages to japanese investors | |||||
first they are designed to eliminate the risk of prepayment mortgage-backed securities can be retired early if interest rates decline and such prepayment forces investors to <unk> their money at lower rates | |||||
second they channel monthly mortgage payments into semiannual payments reducing the administrative burden on investors | |||||
by addressing those problems mr. maxwell said the new funds have become extremely attractive to japanese and other investors outside the u.s. | |||||
such devices have boosted japanese investment in mortgage-backed securities to more than N N of the $ N billion in such instruments outstanding and their purchases are growing at a rapid rate | |||||
they also have become large purchasers of fannie mae 's corporate debt buying $ N billion in fannie mae bonds during the first nine months of the year or almost a <unk> of the total amount issued | |||||
james l. <unk> <unk> executive vice president was named a director of this oil concern expanding the board to N members | |||||
ltv corp. said a federal bankruptcy court judge agreed to extend until march N N the period in which the steel aerospace and energy products company has the exclusive right to file a reorganization plan | |||||
the company is operating under chapter N of the federal bankruptcy code giving it court protection from creditors ' lawsuits while it attempts to work out a plan to pay its debts | |||||
italian chemical giant montedison <unk> through its montedison acquisition n.v. indirect unit began its $ <unk> tender offer for all the common shares outstanding of erbamont n.v. a maker of pharmaceuticals incorporated in the netherlands | |||||
the offer advertised in today 's editions of the wall street journal is scheduled to expire at the end of november | |||||
montedison currently owns about N N of erbamont 's common shares outstanding | |||||
the offer is being launched <unk> to a previously announced agreement between the companies | |||||
japan 's reserves of gold convertible foreign currencies and special drawing rights fell by a hefty $ N billion in october to $ N billion the finance ministry said | |||||
the total marks the sixth consecutive monthly decline | |||||
the <unk> downturn reflects the intensity of bank of japan <unk> intervention since june when the u.s. currency temporarily surged above the N yen level | |||||
the announcement follows a sharper $ N billion decline in the country 's foreign reserves in september to $ N billion | |||||
pick a country any country | |||||
it 's the latest investment craze sweeping wall street a rash of new closed-end country funds those publicly traded portfolios that invest in stocks of a single foreign country | |||||
no fewer than N country funds have been launched or registered with regulators this year triple the level of all of N according to charles e. simon & co. a washington-based research firm | |||||
the turf recently has ranged from chile to <unk> to portugal | |||||
next week the philippine fund 's launch will be capped by a visit by philippine president <unk> aquino the first time a head of state has kicked off an issue at the big board here | |||||
the next province | |||||
anything 's possible how about the new guinea fund <unk> george foot a managing partner at <unk> management associates of <unk> mass | |||||
the recent explosion of country funds <unk> the closed-end fund mania of the 1920s mr. foot says when narrowly focused funds grew wildly popular | |||||
they fell into <unk> after the N crash | |||||
unlike traditional <unk> mutual funds most of these <unk> portfolios are the closed-end type issuing a fixed number of shares that trade publicly | |||||
the surge brings to nearly N the number of country funds that are or soon will be listed in new york or london | |||||
these funds now account for several billions of dollars in assets | |||||
people are looking to stake their claims now before the number of available nations runs out says michael porter an analyst at smith barney harris upham & co. new york | |||||
behind all the <unk> is some <unk> competition | |||||
as individual investors have turned away from the stock market over the years securities firms have scrambled to find new products that brokers find easy to sell | |||||
and the firms are stretching their <unk> far and wide to do it | |||||
financial planners often urge investors to diversify and to hold a <unk> of international securities | |||||
and many emerging markets have <unk> more mature markets such as the u.s. and japan | |||||
country funds offer an easy way to get a taste of foreign stocks without the hard research of seeking out individual companies | |||||
but it does n't take much to get burned | |||||
political and currency gyrations can <unk> the funds | |||||
another concern the funds ' share prices tend to swing more than the broader market | |||||
when the stock market dropped nearly N N oct. N for instance the mexico fund plunged about N N and the spain fund fell N N | |||||
and most country funds were clobbered more than most stocks after the N crash | |||||
what 's so wild about the funds ' frenzy right now is that many are trading at historically fat premiums to the value of their underlying portfolios | |||||
after trading at an average discount of more than N N in late N and part of last year country funds currently trade at an average premium of N N | |||||
the reason share prices of many of these funds this year have climbed much more sharply than the foreign stocks they hold | |||||
it 's probably worth paying a premium for funds that invest in markets that are partially closed to foreign investors such as south korea some specialists say | |||||
but some european funds recently have skyrocketed spain fund has surged to a startling N N premium | |||||
it has been targeted by japanese investors as a good long-term play tied to N 's european economic integration | |||||
and several new funds that are n't even fully invested yet have jumped to trade at big premiums | |||||
i 'm very alarmed to see these rich <unk> says smith barney 's mr. porter | |||||
the newly <unk> premiums reflect the increasingly global marketing of some country funds mr. porter suggests | |||||
unlike many u.s. investors those in asia or europe seeking <unk> exposure may be less <unk> to paying higher prices for country funds | |||||
there may be an international viewpoint cast on the funds listed here mr. porter says | |||||
nonetheless plenty of u.s. analysts and money managers are <unk> at the <unk> trading levels of some country funds | |||||
they argue that u.s. investors often can buy american depositary receipts on the big stocks in many funds these so-called adrs represent shares of foreign companies traded in the u.s. | |||||
that way investors can essentially buy the funds without paying the premium | |||||
for people who insist on jumping in now to buy the funds <unk> 's mr. foot says the only advice i have for these folks is that those who come to the party late had better be ready to leave quickly | |||||
the u.s. and soviet union are holding technical talks about possible repayment by moscow of $ N million in <unk> russian debts owed to the u.s. government the state department said
@@ -1,82 +0,0 @@
import re

import torch
import torch.nn.functional as F
def batch_generator(x, batch_size): | |||||
# x: [num_words, in_channel, height, width] | |||||
# partitions x into batches | |||||
num_step = x.size()[0] // batch_size | |||||
for t in range(num_step): | |||||
yield x[t * batch_size:(t + 1) * batch_size] | |||||
def text2vec(words, char_dict, max_word_len): | |||||
""" Return list of list of int """ | |||||
word_vec = [] | |||||
for word in words: | |||||
vec = [char_dict[ch] for ch in word] | |||||
if len(vec) < max_word_len: | |||||
vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))] | |||||
vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]] | |||||
word_vec.append(vec) | |||||
return word_vec | |||||
def seq2vec(input_words, char_embedding, char_embedding_dim, char_table): | |||||
""" convert the input strings into character embeddings """ | |||||
# input_words == list of string | |||||
# char_embedding == torch.nn.Embedding | |||||
# char_embedding_dim == int | |||||
# char_table == list of unique chars | |||||
# Returns: tensor of shape [len(input_words), char_embedding_dim, max_word_len+2] | |||||
max_word_len = max([len(word) for word in input_words]) | |||||
print("max_word_len={}".format(max_word_len)) | |||||
tensor_list = [] | |||||
start_column = torch.ones(char_embedding_dim, 1) | |||||
end_column = torch.ones(char_embedding_dim, 1) | |||||
for word in input_words: | |||||
        # look up the character embedding for every character in the word
word_encoding = char_embedding_lookup(word, char_embedding, char_table) | |||||
# add start and end columns | |||||
word_encoding = torch.cat([start_column, word_encoding, end_column], 1) | |||||
# zero-pad right columns | |||||
word_encoding = F.pad(word_encoding, (0, max_word_len - word_encoding.size()[1] + 2)).data | |||||
        # add a leading dimension so the per-word encodings can be concatenated
word_encoding = word_encoding.unsqueeze(0) | |||||
tensor_list.append(word_encoding) | |||||
return torch.cat(tensor_list, 0) | |||||
def read_data(file_name): | |||||
# Return: list of strings | |||||
with open(file_name, 'r') as f: | |||||
corpus = f.read().lower() | |||||
corpus = re.sub(r"<unk>", "unk", corpus) | |||||
return corpus.split() | |||||
def get_char_dict(vocabulary): | |||||
# vocabulary == dict of (word, int) | |||||
# Return: dict of (char, int), starting from 1 | |||||
char_dict = dict() | |||||
count = 1 | |||||
for word in vocabulary: | |||||
for ch in word: | |||||
if ch not in char_dict: | |||||
char_dict[ch] = count | |||||
count += 1 | |||||
return char_dict | |||||
def create_word_char_dict(*file_name): | |||||
text = [] | |||||
for file in file_name: | |||||
text += read_data(file) | |||||
word_dict = {word: ix for ix, word in enumerate(set(text))} | |||||
char_dict = get_char_dict(word_dict) | |||||
    return word_dict, char_dict
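
# A minimal usage sketch (hypothetical, not part of the original file) showing how
# the helpers above fit together. The file paths and the ids reserved for the
# special BOW/EOW/PAD characters are assumptions made only for illustration.
if __name__ == "__main__":
    words = read_data("data/train.txt")
    word_dict, char_dict = create_word_char_dict("data/train.txt", "data/valid.txt")
    # text2vec expects BOW/EOW/PAD entries in char_dict, so reserve ids for them
    for special in ("BOW", "EOW", "PAD"):
        char_dict.setdefault(special, len(char_dict) + 1)
    max_word_len = max(len(w) for w in words)
    vectors = text2vec(words[:10], char_dict, max_word_len)
    # each vector has length max_word_len + 2 (BOW + padded character ids + EOW)
    print(len(vectors), len(vectors[0]))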
@@ -1,336 +0,0 @@
consumers may want to move their telephones a little closer to the tv set | |||||
<unk> <unk> watching abc 's monday night football can now vote during <unk> for the greatest play in N years from among four or five <unk> <unk> | |||||
two weeks ago viewers of several nbc <unk> consumer segments started calling a N number for advice on various <unk> issues | |||||
and the new syndicated reality show hard copy records viewers ' opinions for possible airing on the next day 's show | |||||
interactive telephone technology has taken a new leap in <unk> and television programmers are racing to exploit the possibilities | |||||
eventually viewers may grow <unk> with the technology and <unk> the cost | |||||
but right now programmers are figuring that viewers who are busy dialing up a range of services may put down their <unk> control <unk> and stay <unk> | |||||
we 've been spending a lot of time in los angeles talking to tv production people says mike parks president of call interactive which supplied technology for both abc sports and nbc 's consumer minutes | |||||
with the competitiveness of the television market these days everyone is looking for a way to get viewers more excited | |||||
one of the leaders behind the expanded use of N numbers is call interactive a joint venture of giants american express co. and american telephone & telegraph co | |||||
formed in august the venture <unk> at&t 's newly expanded N service with N <unk> computers in american express 's omaha neb. service center | |||||
other long-distance carriers have also begun marketing enhanced N service and special consultants are <unk> up to exploit the new tool | |||||
blair entertainment a new york firm that advises tv stations and sells ads for them has just formed a subsidiary N blair to apply the technology to television | |||||
the use of N toll numbers has been expanding rapidly in recent years | |||||
for a while <unk> <unk> lines and services that <unk> children to dial and <unk> movie or music information earned the service a somewhat <unk> image but new legal restrictions are aimed at trimming excesses | |||||
the cost of a N call is set by the <unk> abc sports for example with the cheapest starting at N cents | |||||
billing is included in a caller 's regular phone bill | |||||
from the fee the local phone company and the long-distance carrier extract their costs to carry the call passing the rest of the money to the <unk> which must cover advertising and other costs | |||||
in recent months the technology has become more flexible and able to handle much more volume | |||||
before callers of N numbers would just listen and not talk or they 'd vote yes or no by calling one of two numbers | |||||
people in the phone business call this technology N <unk> | |||||
now callers are led through complex <unk> of choices to retrieve information they want and the hardware can process N calls in N seconds | |||||
up to now N numbers have mainly been used on local tv stations and cable channels | |||||
<unk> used one to give away the house that rock star jon <unk> <unk> grew up in | |||||
for several years turner broadcasting system 's cable news network has invited viewers to respond <unk> to <unk> issues should the u.s. military intervene in panama but even the hottest <unk> on <unk> <unk> only about N calls | |||||
the newest uses of the <unk> technology demonstrate the growing variety of applications | |||||
capital cities\/abc inc. cbs inc. and general electric co. 's national broadcasting co. unit are expected to announce soon a joint campaign to raise awareness about <unk> | |||||
the subject will be written into the <unk> of prime-time shows and viewers will be given a N number to call | |||||
callers will be sent educational booklets and the call 's modest cost will be an immediate method of raising money | |||||
other network applications have very different goals | |||||
abc sports was looking for ways to lift <unk> <unk> ratings for monday night football | |||||
kurt <unk> abc sports 's marketing director says that now tens of thousands of fans call its N number each week to vote for the best <unk> return <unk> <unk> etc | |||||
profit from the calls goes to charity but abc sports also uses the calls as a sales tool after <unk> callers for voting frank <unk> offers a football <unk> for $ N and N N of callers stay on the line to order it | |||||
jackets may be sold next | |||||
meanwhile nbc sports recently began scores plus a <unk> 24-hour N line providing a complex array of scores analysis and fan news | |||||
a spokesman said its purpose is to bolster the impression that nbc sports is always there for people | |||||
nbc 's <unk> consumer minutes have increased advertiser spending during the day the network 's weakest period | |||||
each <unk> matches a sponsor and a topic on <unk> unilever n.v. 's <unk> bros. sponsors tips on diet and exercise followed by a <unk> <unk> bros. commercial | |||||
viewers can call a N number for additional advice which will be tailored to their needs based on the numbers they <unk> press one if you 're pregnant etc | |||||
if the caller stays on the line and leaves a name and address for the sponsor coupons and a newsletter will be <unk> and the sponsor will be able to gather a list of desirable potential customers | |||||
<unk> <unk> an <unk> vice president says nbc has been able to charge premium rates for this ad time | |||||
she would n't say what the premium is but it 's believed to be about N N above regular <unk> rates | |||||
we were able to get advertisers to use their promotion budget for this because they get a chance to do <unk> says ms. <unk> | |||||
and we were able to attract some new advertisers because this is something new | |||||
mr. parks of call interactive says tv executives are considering the use of N numbers for talk shows game shows news and opinion surveys | |||||
experts are predicting a big influx of new shows in N when a service called automatic number information will become widely available | |||||
this service <unk> each caller 's phone number and it can be used to generate instant mailing lists | |||||
hard copy the new syndicated tabloid show from paramount pictures will use its N number for additional purposes that include research says executive producer mark b. von s. <unk> | |||||
for a piece on local heroes of world war ii we can ask people to leave the name and number of anyone they know who won a <unk> he says | |||||
that 'll save us time and get people involved | |||||
but mr. <unk> sees much bigger changes ahead | |||||
these are just baby steps toward real interactive video which i believe will be the biggest thing yet to affect television he says | |||||
although it would be costly to shoot multiple versions tv programmers could let audiences vote on different <unk> for a movie | |||||
fox broadcasting <unk> with this concept last year when viewers of married with children voted on whether al should say i love you to <unk> on <unk> 's day | |||||
someday viewers may also choose different <unk> of news coverage | |||||
a <unk> by phone could let you decide i 'm interested in just the beginning of story no. N and i want story no. N in <unk> mr. <unk> says | |||||
you 'll start to see shows where viewers program the program | |||||
integrated resources inc. the troubled financial-services company that has been trying to sell its core companies to restructure debt said talks with a potential buyer ended | |||||
integrated did n't identify the party or say why the talks failed | |||||
last week another potential buyer <unk> financial group which had agreed in august to purchase most of integrated 's core companies for $ N million ended talks with integrated | |||||
integrated said that it would continue to pursue other alternatives to sell the five core companies and that a group of senior executives plans to make a proposal to purchase three of the companies integrated resources equity corp. resources trust co. and integrated resources asset management corp | |||||
a price was n't disclosed | |||||
integrated also said it expects to report a second-quarter loss wider than the earlier estimate of about $ N million | |||||
the company did n't disclose the new estimate but said the change was related to integrated 's failure to sell its core businesses as well as other events which it did n't detail that occurred after its announcement last week that it was in talks with the unidentified prospective buyer | |||||
meanwhile a number of top sales producers from integrated resources equity will meet this afternoon in chicago to discuss their options | |||||
the unit is a <unk> constructed group of about N independent brokers and financial planners who sell insurance annuities limited partnerships mutual funds and other investments for integrated and other firms | |||||
the sales force is viewed as a critical asset in integrated 's attempt to sell its core companies | |||||
<unk> cited concerns about how long integrated would be able to hold together the sales force as one reason its talks with integrated failed | |||||
in composite trading on the new york stock exchange yesterday integrated closed at $ N a share down N cents | |||||
integrated has been struggling to avoid a bankruptcy-law filing since june when it failed to make interest payments on nearly $ N billion of debt | |||||
integrated senior and junior creditors are owed a total of about $ N billion | |||||
an earthquake struck northern california killing more than N people | |||||
the violent temblor which lasted about N seconds and registered N on the richter scale also caused the collapse of a <unk> section of the san <unk> bay bridge and shook candlestick park | |||||
the tremor was centered near <unk> southeast of san francisco and was felt as far as N miles away | |||||
numerous injuries were reported | |||||
some buildings collapsed gas and water lines <unk> and fires <unk> | |||||
the quake which also caused damage in san jose and berkeley knocked out electricity and telephones <unk> roadways and disrupted subway service in the bay area | |||||
major injuries were n't reported at candlestick park where the third game of baseball 's world series was canceled and fans <unk> from the stadium | |||||
bush vowed to veto a bill allowing federal financing for abortions in cases of rape and incest saying tax dollars should n't be used to compound a violent act with the taking of an <unk> life | |||||
his pledge in a letter to democratic sen. byrd came ahead of an expected senate vote on spending legislation containing the provision | |||||
east germany 's politburo met amid speculation that the ruling body would oust hard-line leader honecker whose rule has been challenged by mass emigration and calls for democratic freedoms | |||||
meanwhile about N refugees flew to <unk> west germany from warsaw the first <unk> in east germany 's <unk> exodus | |||||
the world psychiatric association voted at an <unk> <unk> to <unk> <unk> the soviet union | |||||
moscow which left the group in N to avoid <unk> over allegations that political <unk> were being certified as <unk> could be suspended if the <unk> of <unk> against <unk> is discovered during a review within a year | |||||
nasa postponed the <unk> of the space shuttle atlantis because of rain near the site of the launch <unk> in <unk> <unk> fla | |||||
the flight was <unk> for today | |||||
the spacecraft 's five <unk> are to <unk> the <unk> galileo space probe on an <unk> mission to jupiter | |||||
senate democratic leaders said they had enough votes to defeat a proposed constitutional amendment to ban flag burning | |||||
the amendment is aimed at <unk> a supreme court ruling that threw out the conviction of a texas <unk> on grounds that his freedom of speech was violated | |||||
federal researchers said lung-cancer mortality rates for people under N years of age have begun to decline particularly for white males | |||||
the national cancer institute also projected that overall u.s. mortality rates from lung cancer should begin to drop in several years if cigarette smoking continues to <unk> | |||||
bush met with south korean president roh who indicated that seoul plans to further ease trade rules to ensure that its economy becomes as open as the other industrialized nations by the mid-1990s | |||||
bush assured roh that the u.s. would stand by its security commitments as long as there is a threat from communist north korea | |||||
the bush administration is seeking an understanding with congress to ease restrictions on u.s. involvement in foreign coups that might result in the death of a country 's leader | |||||
a white house spokesman said that while bush would n't alter a longstanding ban on such involvement there 's a <unk> needed on its interpretation | |||||
india 's gandhi called for parliamentary elections next month | |||||
the balloting considered a test for the prime minister and the ruling congress i party comes amid charges of <unk> leadership and government corruption | |||||
gandhi 's family has ruled independent india for all but five years of its <unk> history | |||||
the soviet union <unk> from a u.n. general assembly vote to reject israel 's credentials | |||||
it was the first time in seven years that moscow has n't joined efforts led by <unk> nations to <unk> israel from the world body and was viewed as a sign of improving <unk> ties | |||||
israel was <unk> by a vote of N with N <unk> | |||||
black activist walter sisulu said the african national congress would n't reject violence as a way to pressure the south african government into concessions that might lead to negotiations over apartheid | |||||
the <unk> sisulu was among eight black political activists freed sunday from prison | |||||
london has concluded that <unk> president <unk> was n't responsible for the execution of six british <unk> in world war ii although he probably was aware of the <unk> | |||||
the report by the defense ministry also rejected allegations that britain covered up evidence of <unk> 's activities as a german army officer | |||||
an international group approved a formal ban on ivory trade despite objections from southern african governments which threatened to find alternative channels for selling elephant <unk> | |||||
the move by the convention on trade in endangered <unk> meeting in switzerland places the elephant on the <unk> list | |||||
an <unk> in colombia killed a federal judge on a <unk> street | |||||
an <unk> caller to a local radio station said cocaine traffickers had <unk> the <unk> in <unk> for the <unk> of <unk> wanted on drug charges in the u.s. | |||||
<unk> leader <unk> met with egypt 's president <unk> and the two officials pledged to respect each other 's laws security and stability | |||||
they stopped short of <unk> diplomatic ties <unk> in N | |||||
the reconciliation talks in the <unk> desert town of <unk> followed a meeting monday in the egyptian resort of <unk> <unk> | |||||
<unk> group inc. revised its exchange offer for $ N million face amount of N N senior subordinated debt due N and extended the offer to oct. N from oct. N | |||||
the <unk> n.j. company said holders would receive for each $ N face amount $ N face amount of a new issue of secured senior subordinated notes convertible into common stock at an initial rate of $ N a share and N common shares | |||||
the new notes will bear interest at N N through july N N and thereafter at N N | |||||
under the original proposal the maker of specialty coatings and a developer of <unk> technologies offered $ N of notes due N N common shares and $ N in cash for each $ N face amount | |||||
completion of the exchange offer is subject to the tender of at least N N of the debt among other things | |||||
<unk> which said it does n't plan to further extend the offer said it received $ N face amount of debt under the original offer | |||||
the stock of ual corp. continued to be <unk> amid signs that british airways may <unk> at any <unk> <unk> of the aborted $ N billion buy-out of united airlines ' parent | |||||
ual stock plummeted a further $ N to $ N on volume of more than N million shares in new york stock exchange composite trading | |||||
the plunge followed a drop of $ N monday amid indications the takeover may take weeks to be revived | |||||
the stock has fallen $ N or N N in the three trading days since announcement of the collapse of the $ 300-a-share takeover jolted the entire stock market into its <unk> plunge ever | |||||
this is a total <unk> for takeover-stock traders one investment banker said | |||||
los angeles financier marvin davis who put united in play with a $ N billion bid two months ago last night <unk> both a ray of hope and an extra element of uncertainty by saying he remains interested in acquiring ual | |||||
but he dropped his earlier $ 300-a-share <unk> bid saying he must first explore bank financing | |||||
even as citicorp and chase manhattan corp. scrambled to line up bank financing for a revised version of the <unk> labor-management bid british airways a N N partner in the buying group indicated it wants to start from <unk> | |||||
its partners are united 's pilots who were to own N N and ual management at N N | |||||
adding <unk> to injury united 's <unk> machinists ' union which helped scuttle financing for the first bid yesterday asked ual chairman stephen wolf and other ual directors to resign | |||||
a similar demand was made by a group that represents some of united 's N <unk> employees | |||||
john <unk> machinists union general vice president attacked mr. wolf as greedy and irresponsible for pursuing the buy-out | |||||
although mr. wolf and john pope ual 's chief financial officer stood to <unk> $ N million for stock and options in the buy-out ual executives planned to reinvest only $ N million in the new company | |||||
the blue-collar machinists longtime rivals of the white-collar pilots say the <unk> would load the company with debt and weaken its finances | |||||
confusion about the two banks ' <unk> efforts to round up financing for a new bid that the ual board has n't even seen yet helped send ual stock <unk> downward | |||||
and rumors of forced selling by takeover-stock traders triggered a <unk> <unk> in the dow jones industrial average around N a.m. edt yesterday | |||||
yesterday 's selling began after a japanese news agency reported that japanese banks which balked at the first bid were ready to reject a revised version at around $ N a share or $ N billion | |||||
several reports as the day <unk> gave vague or <unk> indications about whether banks would sign up | |||||
citicorp for example said only that it had <unk> of interest of a transaction from both the borrowers and the banks but did n't have an agreement | |||||
late in the day mr. wolf issued a <unk> statement calling mr. <unk> 's blast divisive and <unk> for | |||||
but he gave few details on the progress toward a new bid saying only we are working toward a revised proposal for majority employee ownership | |||||
meanwhile in another sign that a new bid is n't imminent it was learned that the ual board held a telephone meeting monday to hear an update on the situation but that a formal board meeting is n't likely to be <unk> until early next week | |||||
in london british airways chairman lord king was quoted in the times as declaring he is not prepared to take my shareholders into a <unk> deal | |||||
observers said it appeared that british air was angered at the way the bid has <unk> into confusion as well as by the banks ' effort to round up financing for what one called a deal that is n't a deal | |||||
the effort to revive the bid was complicated by the <unk> nature of the <unk> buying group | |||||
the pilots were meeting outside chicago yesterday | |||||
but british air which was to have supplied $ N million out of $ N million in equity financing apparently was n't involved in the second proposal and could well reject it even if banks obtain financing | |||||
a group of united 's <unk> employees said in a statement the fact that wolf and other officers were going to line their pockets with literally millions of dollars while <unk> severe pay cuts on the <unk> employees of united is not only <unk> but <unk> | |||||
the machinists also asked for an investigation by the securities and exchange commission into possible <unk> violations in the original bid for ual by mr. davis as well as in the response by ual | |||||
last week just before the bank commitments were due the union asked the u.s. labor department to study whether the bid violated legal standards of fairness governing employee investment funds | |||||
in his statement mr. wolf said we continue to believe our approach is sound and that it is far better for all employees than the alternative of having an outsider own the company with employees paying for it just the same | |||||
mr. wolf has <unk> merger advice from a major wall street securities firm relying instead only on a takeover lawyer peter <unk> of <unk> <unk> slate <unk> & flom | |||||
the huge drop in ual stock prompted one takeover stock trader george <unk> managing partner of <unk> <unk> & co. to deny publicly rumors that his firm was going out of business | |||||
mr. <unk> said that despite losses on ual stock his firm 's health is excellent | |||||
the stock 's decline also has left the ual board in a <unk> | |||||
although it may not be legally obligated to sell the company if the buy-out group ca n't revive its bid it may have to explore alternatives if the buyers come back with a bid much lower than the group 's original $ 300-a-share proposal | |||||
at a meeting sept. N to consider the labor-management bid the board also was informed by its investment adviser first boston corp. of interest expressed by buy-out funds including kohlberg kravis roberts & co. and <unk> little & co. as well as by robert bass morgan stanley 's buy-out fund and pan am corp | |||||
the takeover-stock traders were hoping that mr. davis or one of the other interested parties might <unk> with the situation in disarray or that the board might consider a recapitalization | |||||
meanwhile japanese bankers said they were still <unk> about accepting citicorp 's latest proposal | |||||
macmillan inc. said it plans a public offering of N million shares of its berlitz international inc. unit at $ N to $ N a share | |||||
the offering for the language school unit was announced by robert maxwell chairman and chief executive officer of london-based maxwell communication corp. which owns macmillan | |||||
after the offering is completed macmillan will own about N N of the berlitz common stock outstanding | |||||
five million shares will be offered in the u.s. and N million additional shares will be offered in <unk> international offerings outside the u.s. | |||||
goldman sachs & co. will manage the offering | |||||
macmillan said berlitz intends to pay quarterly dividends on the stock | |||||
the company said it expects to pay the first dividend of N cents a share in the N first quarter | |||||
berlitz will borrow an amount equal to its expected net proceeds from the offerings plus $ N million in connection with a credit agreement with lenders | |||||
the total borrowing will be about $ N million the company said | |||||
proceeds from the borrowings under the credit agreement will be used to pay an $ N million cash dividend to macmillan and to lend the remainder of about $ N million to maxwell communications in connection with a <unk> note | |||||
proceeds from the offering will be used to repay borrowings under the short-term parts of a credit agreement | |||||
berlitz which is based in princeton n.j. provides language instruction and translation services through more than N language centers in N countries | |||||
in the past five years more than N N of its sales have been outside the u.s. | |||||
macmillan has owned berlitz since N | |||||
in the first six months of this year berlitz posted net income of $ N million on sales of $ N million compared with net income of $ N million on sales of $ N million | |||||
right away you notice the following things about a philip glass concert | |||||
it attracts people with funny hair or with no hair in front of me a girl with <unk> <unk> sat <unk> a boy who had <unk> his | |||||
whoever constitute the local left bank come out in force dressed in black along with a <unk> of <unk> who want to be on the cutting edge | |||||
people in glass houses tend to look <unk> | |||||
and if still <unk> at the evening 's end you notice something else the audience at first <unk> and <unk> by the music releases its <unk> feelings in collective <unk> | |||||
currently in the middle of a <unk> <unk> tour as a solo <unk> mr. glass has left behind his <unk> equipment and <unk> in favor of going it alone | |||||
he sits down at the piano and plays | |||||
and plays | |||||
either one likes it or one does n't | |||||
the typical glass audience which is more likely to be composed of music students than their teachers certainly does | |||||
the work though sounds like <unk> for <unk> | |||||
philip glass is the <unk> and his music the new clothes of the <unk> | |||||
his success is easy to understand | |||||
<unk> introducing and explaining his pieces mr. glass looks and sounds more like a <unk> <unk> describing his work than a classical <unk> playing a recital | |||||
the piano <unk> which have been labeled <unk> as <unk> <unk> <unk> cyclical <unk> and <unk> are <unk> <unk> therefore <unk> <unk> <unk> therefore <unk> and <unk> <unk> but <unk> therefore both pretty and <unk> | |||||
it is music for people who want to hear something different but do n't want to work especially hard at the task | |||||
it is <unk> listening for the now generation | |||||
mr. glass has <unk> the famous <unk> <unk> less is more | |||||
his more is always less | |||||
far from being <unk> the music <unk> <unk> us with apparent <unk> not so <unk> <unk> in the <unk> of N time <unk> <unk> and <unk> or <unk> <unk> <unk> | |||||
but the music has its <unk> and mr. glass has constructed his solo program around a move from the simple to the relatively complex | |||||
opening N from <unk> <unk> the audience to the glass technique never <unk> too far from the piano 's center mr. glass works in the two <unk> on either side of middle c and his fingers seldom leave the <unk> | |||||
there is a <unk> musical style here but not a particular performance style | |||||
the music is not especially <unk> indeed it 's hard to imagine a bad performance of it | |||||
nothing <unk> no <unk> no <unk> <unk> problems challenge the performer | |||||
we hear we may think inner voices but they all seem to be saying the same thing | |||||
with planet news music meant to <unk> <unk> of allen <unk> 's wichita <unk> <unk> mr. glass gets going | |||||
his hands sit <unk> apart on the <unk> | |||||
seventh <unk> make you feel as though he may break into a very slow <unk> <unk> | |||||
the <unk> <unk> but there is little <unk> even though his fingers begin to <unk> over more of the <unk> | |||||
contrasts predictably <unk> first the music is loud then it becomes soft then you realize it becomes <unk> again | |||||
the fourth <unk> play an <unk> from <unk> on the beach is like a <unk> but it does n't seem to move much beyond its <unk> ground in three blind mice | |||||
when mr. glass decides to get really fancy he <unk> his hands and hits a <unk> bass note with his right hand | |||||
he does this in at least three of his solo pieces | |||||
you might call it a <unk> or a <unk> <unk> | |||||
in mad rush which came from a commission to write a piece of <unk> length mr. glass <unk> and <unk> confessed that this was no problem for me an a <unk> with a b section several times before the piece ends <unk> | |||||
not only is the typical <unk> <unk> it is also often multiple in its context s | |||||
mad rush began its life as the <unk> to the <unk> lama 's first public address in the u.s. when mr. glass played it on the <unk> at new york 's <unk> of st. john the <unk> | |||||
later it was performed on radio <unk> in germany and then <unk> <unk> took it for one of her dance pieces | |||||
the point is that any piece can be used as background music for virtually anything | |||||
the evening ended with mr. glass 's <unk> another multiple work | |||||
parts N N and N come from the <unk> of <unk> morris 's <unk> film the thin blue line and the two other parts from <unk> music to two separate <unk> of the <unk> story of the same name | |||||
when used as background in this way the music has an appropriate <unk> as when a <unk> phrase a <unk> minor third <unk> the seemingly endless <unk> of reports interviews and <unk> of witnesses in the morris film | |||||
served up as a solo however the music lacks the <unk> provided by a context within another medium | |||||
<unk> of mr. glass may agree with the critic richard <unk> 's sense that the N music in twelve parts is as <unk> and <unk> as the <unk> <unk> | |||||
but while making the obvious point that both <unk> develop variations from themes this comparison <unk> the intensely <unk> nature of mr. glass 's music | |||||
its supposedly <unk> <unk> <unk> a <unk> that makes one <unk> for the <unk> of <unk> <unk> the <unk> radical <unk> of <unk> and <unk> and what in <unk> even seems like <unk> in <unk> | |||||
mr. <unk> is professor of english at southern <unk> university and editor of the southwest review | |||||
honeywell inc. said it hopes to complete shortly the first of two sales of shares in its japanese joint venture <unk> for about $ N million | |||||
the company would n't disclose the buyer of the initial N N stake | |||||
proceeds of the sale expected to be completed next week would be used to repurchase as many as N million shares of honeywell stock the company said | |||||
honeywell said it is negotiating the sale of a second stake in <unk> but indicated it intends to hold at least N N of the joint venture 's stock long term | |||||
a N N stake would allow honeywell to include <unk> earnings in its results | |||||
honeywell previously said it intended to reduce its holding in the japanese concern as part of a restructuring plan which also calls for a reduction of <unk> on weapons sales | |||||
yesterday a spokeswoman said the company was pleased with our progress in that regard and hopes to provide additional details soon | |||||
honeywell said its defense and marine systems group incurred delays in shipping some undisclosed contracts during the third quarter resulting in lower operating profit for that business | |||||
overall honeywell reported earnings of $ N million or $ N a share for the three months ended oct. N compared with a loss of $ N million or N cents a share a year earlier | |||||
the previous period 's results included a $ N million pretax charge related to <unk> contract costs and a $ N million pretax gain on real estate sales | |||||
sales for the latest quarter were flat at $ N billion | |||||
for the nine months honeywell reported earnings of $ N million or $ N a share compared with earnings of $ N million or $ N a share a year earlier | |||||
sales declined slightly to $ N billion | |||||
once again your editorial page <unk> the law to conform to your almost <unk> <unk> | |||||
in an <unk> of little <unk> to his central point about private enforcement suits by environmental groups michael s. <unk> <unk> your readers the clean water act is written upon the <unk> the <unk> rather that nothing but zero risk will do it <unk> a legal standard of zero <unk> <unk> environmental <unk> sept. N | |||||
this statement surely <unk> your editorial viewpoint that environmental protection is generally silly or excessive but it is simply wrong | |||||
the clean water act contains no legal standard of zero <unk> | |||||
it requires that <unk> of <unk> into the waters of the united states be authorized by permits that reflect the <unk> limitations developed under section N | |||||
whatever may be the problems with this system it <unk> reflects zero risk or zero <unk> | |||||
perhaps mr. <unk> was confused by congress 's <unk> statement of the national goal in section N which indeed calls for the elimination of <unk> by N no less | |||||
this <unk> statement was not taken seriously when enacted in N and should not now be confused with the <unk> provisions of the statute | |||||
thus you do the public a great <unk> when mr. <unk> suggests even <unk> that the clean water act prohibits the preparation of a <unk> and water your <unk> readers may be led to believe that nothing but chance or oversight protects them as they <unk> in the night with their <unk> and waters from the <unk> knock of the sierra club at their doors | |||||
robert j. <unk> | |||||
national geographic the <unk> u.s. magazine is attracting more readers than ever and offers the glossy <unk> pages that upscale advertisers love | |||||
so why did advertising pages plunge by almost N N and ad revenue by N N in the first half | |||||
to hear advertisers tell it the magazine just has n't kept up with the times | |||||
despite renewed interest by the public in such topics as the environment and the third world it has n't been able to shake its reputation as a magazine boys like to <unk> through in search of <unk> tribe women | |||||
worse it lagged behind competitors in offering <unk> <unk> from regional editions to discounts for frequent advertisers | |||||
but now the magazine is attempting to fight back with an ambitious plan including a revamped sales strategy and a surprisingly aggressive ad campaign | |||||
advertisers do n't think of the magazine first says joan <unk> who joined in april as national advertising director | |||||
what we want to do is take a more aggressive stance | |||||
people did n't believe we were in tune with the marketplace and in many ways we were n't | |||||
the <unk> magazine has never had to woo advertisers with quite so much <unk> before | |||||
it largely <unk> on its <unk> <unk> N million subscribers in the first half up from N million a year ago an average age of N for readers at the <unk> of their <unk> years loyalty to the tune of an N N average subscription renewal rate | |||||
the magazine had its best year yet in N when it <unk> its centennial and racked up a N N gain in ad pages to N | |||||
but this year when the <unk> surrounding its centennial died so too did some advertiser interest | |||||
the reason ad executives say is that the entire magazine business has been soft and national geographic has some <unk> that make it especially <unk> during a soft market | |||||
perhaps the biggest of those factors is its high ad prices $ N for a <unk> page vs. $ N for the <unk> a comparable publication with a far smaller circulation | |||||
when ad dollars are tight the high page cost is a major <unk> for advertisers who generally want to appear regularly in a publication or not at all | |||||
even though national geographic offers far more readers than does a magazine like <unk> the page costs you an arm and a leg to develop any frequency says harry glass new york media manager for bozell inc | |||||
to combat that problem national geographic like other magazines began offering regional editions allowing advertisers to appear in only a portion of its magazines for example ads can run only in the magazines sent to subscribers in the largest N markets | |||||
but the magazine was slower than its competitors to come up with its regional editions and until last year offered fewer of them than did competitors | |||||
time magazine for example has more than N separate editions going to different regions top management and other groups | |||||
another sticking point for advertisers was national geographic 's tradition of <unk> its ads together usually at the beginning or end of the magazine rather than spreading ads out among its articles as most magazines do | |||||
and national geographic 's <unk> size means extra production costs for advertisers | |||||
but ms. <unk> says the magazine is fighting back | |||||
it now offers N regional editions it very recently began running ads adjacent to articles and it has been <unk> up its sales force | |||||
and it just launched a promotional campaign to tell chief executives marketing directors and media executives just that | |||||
the centerpiece of the promotion is its new ad campaign into which the magazine will pour about $ N mostly in the next few weeks | |||||
the campaign created by <unk> group 's ddb needham agency takes advantage of the <unk> photography that national geographic is known for | |||||
in one ad a photo of the interior of the <unk> in paris is <unk> with the headline the only book more respected than <unk> does n't accept advertising | |||||
another ad pictures a tree <unk> magnified N times with the headline for impact far beyond your size consider our regional editions | |||||
ms. <unk> says she wants the campaign to help attract advertisers in N categories including corporate financial services consumer electronics insurance and food | |||||
her goal to top N ad pages in N up from about N this year | |||||
whether she can meet that ambitious goal is still far from certain | |||||
the ad campaign is meant to <unk> the thought of national geographic she says | |||||
we want it to be a <unk> kind of image | |||||
wcrs plans <unk> sale | |||||
wcrs group hopes to announce perhaps today an agreement to sell the majority of its ad unit to <unk> eurocom a european ad executive said | |||||
wcrs has been in discussions with eurocom for several months | |||||
however when negotiations <unk> down recently wcrs 's chief executive peter scott met in paris with another french firm <unk> <unk> <unk> <unk> or <unk> | |||||
according to the executive <unk> 's involvement prompted renewed <unk> in the <unk> talks and the two agencies were hoping to <unk> out details by today | |||||
executives of the two agencies could n't be reached last night | |||||
ad notes | |||||
new account procter & gamble co. cincinnati awarded the ad accounts for its line of professional <unk> <unk> <unk> and oil products to <unk> <unk> <unk> cincinnati | |||||
billings were n't disclosed | |||||
professional <unk> products are specially made for the <unk> industry | |||||
who 's news stephen <unk> N was named executive vice president deputy creative director at grey advertising new york | |||||
he was executive vice president director of broadcast production | |||||
the commodity futures trading commission plans to restrict dual trading on commodity exchanges a move almost certain to <unk> exchange officials and traders | |||||
the cftc said it will propose the restrictions after the release of a study that shows little economic benefit resulting from dual trading and cites problems associated with the practice | |||||
dual trading gives an exchange trader the right to trade both for his own account and for customers | |||||
the issue exploded this year after a federal bureau of investigation operation led to charges of widespread trading abuses at the chicago board of trade and chicago mercantile exchange | |||||
while not specifically mentioned in the fbi charges dual trading became a focus of attempts to tighten industry regulations | |||||
critics contend that traders were putting buying or selling for their own accounts ahead of other traders ' customer orders | |||||
traders are likely to oppose such restrictions because dual trading provides a way to make money in slower markets where there is a shortage of customer orders | |||||
the exchanges contend that dual trading improves liquidity in the markets because traders can buy or sell even when they do n't have a customer order in hand | |||||
the exchanges say liquidity becomes a severe problem for <unk> traded contracts such as those with a long time remaining before expiration | |||||
the cftc may take those arguments into account by allowing exceptions to its restrictions | |||||
the agency did n't cite specific situations where dual trading might be allowed but smaller exchanges or contracts that need additional liquidity are expected to be among them | |||||
wendy <unk> the agency 's chairman told the senate agriculture committee that she expects the study to be released within two weeks and the rule changes to be completed by <unk> | |||||
the study by the cftc 's division of economic analysis shows that a trade is a trade a member of the study team said | |||||
whether a trade is done on a dual or <unk> basis the member said does n't seem to have much economic impact | |||||
currently most traders on commodity exchanges specialize in trading either for customer accounts which makes them brokers or for their own accounts as <unk> <unk> | |||||
the tests indicate that dual and <unk> traders are similar in terms of the trade executions and liquidity they provide to the market mrs. <unk> told the senate panel | |||||
members of congress have proposed restricting dual trading in bills to <unk> cftc operations | |||||
the house 's bill would prohibit dual trading in markets with daily average volume of N contracts or more <unk> those considered too difficult to track without a sophisticated computer system | |||||
the senate bill would force the cftc to suspend dual trading if an exchange ca n't show that its oversight system can detect <unk> abuses | |||||
so far one test of restricting dual trading has worked well | |||||
the chicago merc banned dual trading in its standard & poor 's 500-stock index futures pit in N | |||||
under the rules traders decide before a session begins whether they will trade for their own account or for customers | |||||
traders who stand on the pit 's top step where most customer orders are executed ca n't trade for themselves | |||||
a merc spokesman said the plan has n't made much difference in liquidity in the pit | |||||
it 's too soon to tell but people do n't seem to be unhappy with it he said | |||||
he said he would n't comment on the cftc plan until the exchange has seen the full proposal | |||||
but at a meeting last week tom <unk> the board of trade 's president told commodity lawyers dual trading is definitely worth saving | |||||
it adds something to the market | |||||
japanese firms push <unk> car <unk> | |||||
japanese luxury-car makers are trying to set strict design standards for their dealerships | |||||
but some dealers are negotiating <unk> terms while others decline to deal at all | |||||
nissan motor co. 's infiniti division likes to insist that every dealer construct and <unk> a building in a japanese style | |||||
specifications include a <unk> <unk> <unk> at the center of each showroom and a <unk> bridge <unk> a stream that flows into the building from outside | |||||
infiniti has it down to the <unk> says jay <unk> a partner at <unk> power & associates an auto research firm | |||||
toyota motor corp. 's lexus division also provides specifications | |||||
but only two-thirds of lexus dealers are <unk> new buildings according to the lexus <unk> | |||||
some are even coming up with their own novel designs | |||||
in louisville ky. for example david peterson has built a lexus dealership with the showroom on the second floor | |||||
yet some dealers have turned down infiniti or lexus <unk> because they were unwilling or unable to meet the design requirements | |||||
lee seidman of cleveland says infiniti was a bear on <unk> but at least let him <unk> an existing building without the stream | |||||
mr. seidman says he turned down a lexus franchise in part because the building was <unk> but very expensive | |||||
to head off arguments infiniti offers dealers cash bonuses and <unk> construction loans | |||||
<unk> device 's <unk> plays back a lesson | |||||
products <unk> have to be first to be winners | |||||
that 's the lesson offered through one case study featured in a design exhibit | |||||
dictaphone corp. was caught off guard in N when its main competitor <unk> office products of japan introduced a <unk> <unk> recorder half the size of standard <unk> devices | |||||
blocked by patent protection from following suit dictaphone decided to go a step further and cut the <unk> in half again down to the length of a <unk> |
@@ -1,36 +0,0 @@ | |||||
## Introduction | |||||
This is a PyTorch implementation of the [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf) paper.
* The dataset consists of 600k documents extracted from [Yelp 2018](https://www.yelp.com/dataset) customer reviews
* [NLTK](http://www.nltk.org/) and [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/) are used to split documents into sentences and tokenize them
* Both CPU & GPU are supported
* The best accuracy is 71%, matching the performance reported in the paper
## Requirement | |||||
* python 3.6 | |||||
* pytorch = 0.3.0 | |||||
* numpy | |||||
* gensim | |||||
* nltk | |||||
* coreNLP | |||||
## Parameters | |||||
According to the paper and my experiments, the model parameters are set as follows (a short sketch after the two tables shows how they map onto the training code):
|word embedding dimension|GRU hidden size|GRU layer|word/sentence context vector dimension| | |||||
|---|---|---|---| | |||||
|200|50|1|100| | |||||
And the training parameters: | |||||
|Epoch|learning rate|momentum|batch size| | |||||
|---|---|---|---| | |||||
|3|0.01|0.9|64| | |||||
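For reference, here is a minimal sketch of how these hyper-parameters map onto the `HAN` constructor and the SGD optimizer used in this repo's train.py. The constructor signature and optimizer settings are copied from that script; `HAN` is assumed to be importable from this repo's model.py.

```python
import torch
from model import HAN  # HAN is defined in this repo's model.py

# word embedding dim 200, GRU hidden size 50, 1 GRU layer,
# word/sentence context vectors of dim 100, 5 output classes (Yelp stars)
net = HAN(input_size=200, output_size=5,
          word_hidden_size=50, word_num_layers=1, word_context_size=100,
          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)

# training parameters from the second table: lr 0.01, momentum 0.9
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# train.py then runs the epochs listed above with batch size 64
```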
## Run | |||||
1. Prepare the dataset. Download the [data set](https://www.yelp.com/dataset), unzip the customer reviews into a file, and use preprocess.py to transform it into the data set used as model input.
2. Train the model. The word embeddings of the training data are stored in 'yelp.word2vec'. The model will be trained and automatically saved to 'models.dict'.
``` | |||||
python train.py
``` | |||||
3. Test the model. | |||||
``` | |||||
python evaluate.py
``` |
@@ -1,45 +0,0 @@ | |||||
from model import * | |||||
from train import * | |||||
def evaluate(net, dataset, batch_size=64, use_cuda=False):
    # returns the number of correctly classified documents in `dataset`
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate, num_workers=0)
count = 0 | |||||
if use_cuda: | |||||
net.cuda() | |||||
for i, batch_samples in enumerate(dataloader): | |||||
x, y = batch_samples | |||||
doc_list = [] | |||||
for sample in x: | |||||
doc = [] | |||||
for sent_vec in sample: | |||||
if use_cuda: | |||||
sent_vec = sent_vec.cuda() | |||||
doc.append(Variable(sent_vec, volatile=True)) | |||||
doc_list.append(pack_sequence(doc)) | |||||
if use_cuda: | |||||
y = y.cuda() | |||||
predicts = net(doc_list) | |||||
p, idx = torch.max(predicts, dim=1) | |||||
idx = idx.data | |||||
count += torch.sum(torch.eq(idx, y)) | |||||
return count | |||||
if __name__ == '__main__': | |||||
''' | |||||
Evaluate the performance of models | |||||
''' | |||||
from gensim.models import Word2Vec | |||||
embed_model = Word2Vec.load('yelp.word2vec') | |||||
embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size) | |||||
del embed_model | |||||
net = HAN(input_size=200, output_size=5, | |||||
word_hidden_size=50, word_num_layers=1, word_context_size=100, | |||||
sent_hidden_size=50, sent_num_layers=1, sent_context_size=100) | |||||
net.load_state_dict(torch.load('models.dict')) | |||||
test_dataset = YelpDocSet('reviews', 199, 4, embedding) | |||||
    correct = evaluate(net, test_dataset, use_cuda=True)  # True was previously passed positionally as batch_size
print('accuracy {}'.format(correct / len(test_dataset))) |
@@ -1,50 +0,0 @@ | |||||
'''
Tokenize yelp dataset's documents using stanford core nlp | |||||
''' | |||||
import json | |||||
import os | |||||
import pickle | |||||
import nltk | |||||
from nltk.tokenize import stanford | |||||
input_filename = 'review.json' | |||||
# config for stanford core nlp | |||||
os.environ['JAVAHOME'] = 'D:\\java\\bin\\java.exe' | |||||
path_to_jar = 'E:\\College\\fudanNLP\\stanford-corenlp-full-2018-02-27\\stanford-corenlp-3.9.1.jar' | |||||
tokenizer = stanford.CoreNLPTokenizer() | |||||
in_dirname = 'review' | |||||
out_dirname = 'reviews' | |||||
f = open(input_filename, encoding='utf-8') | |||||
samples = [] | |||||
j = 0 | |||||
for i, line in enumerate(f.readlines()): | |||||
review = json.loads(line) | |||||
samples.append((review['stars'], review['text'])) | |||||
if (i + 1) % 5000 == 0: | |||||
print(i) | |||||
pickle.dump(samples, open(in_dirname + '/samples%d.pkl' % j, 'wb')) | |||||
j += 1 | |||||
samples = [] | |||||
pickle.dump(samples, open(in_dirname + '/samples%d.pkl' % j, 'wb')) | |||||
# samples = pickle.load(open(out_dirname + '/samples0.pkl', 'rb')) | |||||
# print(samples[0]) | |||||
for fn in os.listdir(in_dirname): | |||||
print(fn) | |||||
    processed = []
    for stars, text in pickle.load(open(os.path.join(in_dirname, fn), 'rb')):
        tokens = []
        sents = nltk.tokenize.sent_tokenize(text)
        for s in sents:
            tokens.append(tokenizer.tokenize(s))
        processed.append((stars, tokens))
        # print(tokens)
        if len(processed) % 100 == 0:
            print(len(processed))
    pickle.dump(processed, open(os.path.join(out_dirname, fn), 'wb'))
@@ -1,171 +0,0 @@ | |||||
import os | |||||
import pickle | |||||
import numpy as np | |||||
import torch | |||||
from model import * | |||||
class SentIter: | |||||
def __init__(self, dirname, count): | |||||
self.dirname = dirname | |||||
self.count = int(count) | |||||
def __iter__(self): | |||||
for f in os.listdir(self.dirname)[:self.count]: | |||||
with open(os.path.join(self.dirname, f), 'rb') as f: | |||||
for y, x in pickle.load(f): | |||||
for sent in x: | |||||
yield sent | |||||
def train_word_vec(): | |||||
# load data | |||||
dirname = 'reviews' | |||||
sents = SentIter(dirname, 238) | |||||
# define models and train | |||||
model = models.Word2Vec(size=200, sg=0, workers=4, min_count=5) | |||||
model.build_vocab(sents) | |||||
model.train(sents, total_examples=model.corpus_count, epochs=10) | |||||
model.save('yelp.word2vec') | |||||
print(model.wv.similarity('woman', 'man')) | |||||
print(model.wv.similarity('nice', 'awful')) | |||||
class Embedding_layer: | |||||
def __init__(self, wv, vector_size): | |||||
self.wv = wv | |||||
self.vector_size = vector_size | |||||
def get_vec(self, w): | |||||
try: | |||||
v = self.wv[w] | |||||
except KeyError as e: | |||||
v = np.random.randn(self.vector_size) | |||||
return v | |||||
from torch.utils.data import DataLoader, Dataset | |||||
class YelpDocSet(Dataset): | |||||
def __init__(self, dirname, start_file, num_files, embedding): | |||||
self.dirname = dirname | |||||
self.num_files = num_files | |||||
self._files = os.listdir(dirname)[start_file:start_file + num_files] | |||||
self.embedding = embedding | |||||
self._cache = [(-1, None) for i in range(5)] | |||||
def get_doc(self, n): | |||||
file_id = n // 5000 | |||||
idx = file_id % 5 | |||||
if self._cache[idx][0] != file_id: | |||||
with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f: | |||||
self._cache[idx] = (file_id, pickle.load(f)) | |||||
y, x = self._cache[idx][1][n % 5000] | |||||
sents = [] | |||||
for s_list in x: | |||||
sents.append(' '.join(s_list)) | |||||
x = '\n'.join(sents) | |||||
return x, y - 1 | |||||
def __len__(self): | |||||
return len(self._files) * 5000 | |||||
def __getitem__(self, n): | |||||
file_id = n // 5000 | |||||
idx = file_id % 5 | |||||
if self._cache[idx][0] != file_id: | |||||
print('load {} to {}'.format(file_id, idx)) | |||||
with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f: | |||||
self._cache[idx] = (file_id, pickle.load(f)) | |||||
y, x = self._cache[idx][1][n % 5000] | |||||
doc = [] | |||||
for sent in x: | |||||
if len(sent) == 0: | |||||
continue | |||||
sent_vec = [] | |||||
for word in sent: | |||||
vec = self.embedding.get_vec(word) | |||||
sent_vec.append(vec.tolist()) | |||||
sent_vec = torch.Tensor(sent_vec) | |||||
doc.append(sent_vec) | |||||
if len(doc) == 0: | |||||
doc = [torch.zeros(1, 200)] | |||||
return doc, y - 1 | |||||
def collate(iterable): | |||||
y_list = [] | |||||
x_list = [] | |||||
for x, y in iterable: | |||||
y_list.append(y) | |||||
x_list.append(x) | |||||
return x_list, torch.LongTensor(y_list) | |||||
def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False): | |||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9) | |||||
criterion = nn.NLLLoss() | |||||
dataloader = DataLoader(dataset, | |||||
batch_size=batch_size, | |||||
collate_fn=collate, | |||||
num_workers=0) | |||||
running_loss = 0.0 | |||||
if use_cuda: | |||||
net.cuda() | |||||
print('start training') | |||||
for epoch in range(num_epoch): | |||||
for i, batch_samples in enumerate(dataloader): | |||||
x, y = batch_samples | |||||
doc_list = [] | |||||
for sample in x: | |||||
doc = [] | |||||
for sent_vec in sample: | |||||
if use_cuda: | |||||
sent_vec = sent_vec.cuda() | |||||
doc.append(Variable(sent_vec)) | |||||
doc_list.append(pack_sequence(doc)) | |||||
if use_cuda: | |||||
y = y.cuda() | |||||
y = Variable(y) | |||||
predict = net(doc_list) | |||||
loss = criterion(predict, y) | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
running_loss += loss.data[0] | |||||
if i % print_size == print_size - 1: | |||||
print('{}, {}'.format(i + 1, running_loss / print_size)) | |||||
running_loss = 0.0 | |||||
torch.save(net.state_dict(), 'models.dict') | |||||
torch.save(net.state_dict(), 'models.dict') | |||||
if __name__ == '__main__': | |||||
''' | |||||
Train process | |||||
''' | |||||
from gensim.models import Word2Vec | |||||
from gensim import models | |||||
train_word_vec() | |||||
embed_model = Word2Vec.load('yelp.word2vec') | |||||
embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size) | |||||
del embed_model | |||||
start_file = 0 | |||||
dataset = YelpDocSet('reviews', start_file, 120 - start_file, embedding) | |||||
print('training data size {}'.format(len(dataset))) | |||||
net = HAN(input_size=200, output_size=5, | |||||
word_hidden_size=50, word_num_layers=1, word_context_size=100, | |||||
sent_hidden_size=50, sent_num_layers=1, sent_context_size=100) | |||||
try: | |||||
net.load_state_dict(torch.load('models.dict')) | |||||
        print("loaded model weights from the previous training run")
except Exception: | |||||
        print("cannot load saved model, training an initial model from scratch")
train(net, dataset, num_epoch=5, batch_size=64, use_cuda=True) |
@@ -2,28 +2,30 @@ | |||||
这里复现了在fastNLP中实现的模型,旨在达到与论文中相符的性能。 | 这里复现了在fastNLP中实现的模型,旨在达到与论文中相符的性能。 | ||||
复现的模型有: | 复现的模型有: | ||||
- [Star-Transformer](Star_transformer/) | |||||
- [Star-Transformer](Star_transformer) | |||||
- [Biaffine](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/biaffine_parser.py#L239) | |||||
- [CNNText](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/cnn_text_classification.py#L12) | |||||
- ... | - ... | ||||
# 任务复现 | # 任务复现 | ||||
## Text Classification (文本分类) | ## Text Classification (文本分类) | ||||
- still in progress | |||||
- [Text Classification 文本分类任务复现](text_classification) | |||||
## Matching (自然语言推理/句子匹配) | ## Matching (自然语言推理/句子匹配) | ||||
- [Matching 任务复现](matching/) | |||||
- [Matching 任务复现](matching) | |||||
## Sequence Labeling (序列标注) | ## Sequence Labeling (序列标注) | ||||
- still in progress | |||||
- [NER](seqence_labelling/ner) | |||||
## Coreference resolution (指代消解) | |||||
- still in progress | |||||
## Coreference Resolution (共指消解) | |||||
- [Coreference Resolution 共指消解任务复现](coreference_resolution) | |||||
## Summarization (摘要) | ## Summarization (摘要) | ||||
- still in progress | |||||
- [Summarization 摘要任务复现](Summarization)
## ... | ## ... |
@@ -6,29 +6,6 @@ paper: [Star-Transformer](https://arxiv.org/abs/1902.09113) | |||||
|Pos Tagging|CTB 9.0|-|ACC 92.31| | |Pos Tagging|CTB 9.0|-|ACC 92.31| | ||||
|Pos Tagging|CONLL 2012|-|ACC 96.51| | |Pos Tagging|CONLL 2012|-|ACC 96.51| | ||||
|Named Entity Recognition|CONLL 2012|-|F1 85.66| | |Named Entity Recognition|CONLL 2012|-|F1 85.66| | ||||
|Text Classification|SST|-|49.18| | |||||
|Text Classification|SST|-|51.2| | |||||
|Natural Language Inference|SNLI|-|83.76| | |Natural Language Inference|SNLI|-|83.76| | ||||
## Usage | |||||
``` python | |||||
# for sequence labeling(ner, pos tagging, etc) | |||||
from fastNLP.models.star_transformer import STSeqLabel | |||||
model = STSeqLabel( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for sequence classification | |||||
from fastNLP.models.star_transformer import STSeqCls | |||||
model = STSeqCls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for natural language inference | |||||
from fastNLP.models.star_transformer import STNLICls | |||||
model = STNLICls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
``` |
@@ -2,7 +2,7 @@ import torch | |||||
import json | import json | ||||
import os | import os | ||||
from fastNLP import Vocabulary | from fastNLP import Vocabulary | ||||
from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader | |||||
from fastNLP.io.data_loader import ConllLoader, SSTLoader, SNLILoader | |||||
from fastNLP.core import Const as C | from fastNLP.core import Const as C | ||||
import numpy as np | import numpy as np | ||||
@@ -50,13 +50,15 @@ def load_sst(path, files): | |||||
for sub in [True, False, False]] | for sub in [True, False, False]] | ||||
ds_list = [loader.load(os.path.join(path, fn)) | ds_list = [loader.load(os.path.join(path, fn)) | ||||
for fn, loader in zip(files, loaders)] | for fn, loader in zip(files, loaders)] | ||||
word_v = Vocabulary(min_freq=2) | |||||
word_v = Vocabulary(min_freq=0) | |||||
tag_v = Vocabulary(unknown=None, padding=None) | tag_v = Vocabulary(unknown=None, padding=None) | ||||
for ds in ds_list: | for ds in ds_list: | ||||
ds.apply(lambda x: [w.lower() | ds.apply(lambda x: [w.lower() | ||||
for w in x['words']], new_field_name='words') | for w in x['words']], new_field_name='words') | ||||
ds_list[0].drop(lambda x: len(x['words']) < 3) | |||||
#ds_list[0].drop(lambda x: len(x['words']) < 3) | |||||
update_v(word_v, ds_list[0], 'words') | update_v(word_v, ds_list[0], 'words') | ||||
update_v(word_v, ds_list[1], 'words') | |||||
update_v(word_v, ds_list[2], 'words') | |||||
ds_list[0].apply(lambda x: tag_v.add_word( | ds_list[0].apply(lambda x: tag_v.add_word( | ||||
x['target']), new_field_name=None) | x['target']), new_field_name=None) | ||||
@@ -151,7 +153,10 @@ class EmbedLoader: | |||||
# some words from vocab are missing in pre-trained embedding | # some words from vocab are missing in pre-trained embedding | ||||
# we normally sample each dimension | # we normally sample each dimension | ||||
vocab_embed = embedding_matrix[np.where(hit_flags)] | vocab_embed = embedding_matrix[np.where(hit_flags)] | ||||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||||
#sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||||
# size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||||
sampled_vectors = np.random.uniform(-0.01, 0.01, | |||||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | size=(len(vocab) - np.sum(hit_flags), emb_dim)) | ||||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | ||||
return embedding_matrix | return embedding_matrix |
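The hunk above changes how vectors are filled in for vocabulary words missing from the pre-trained embedding file: instead of sampling each dimension from a normal distribution matched to the found vectors, the new code draws small uniform noise. A self-contained sketch of the same idea, with illustrative toy shapes rather than the repo's actual data:

```python
import numpy as np

emb_dim = 300
hit_flags = np.array([True, False, True, False])        # which vocab words were found
embedding_matrix = np.zeros((len(hit_flags), emb_dim))  # one row per vocab word
# ...rows where hit_flags is True would be filled from the pre-trained file...

# rows for missing words get small uniform noise in [-0.01, 0.01]
n_missing = int(np.sum(~hit_flags))
embedding_matrix[~hit_flags] = np.random.uniform(-0.01, 0.01, size=(n_missing, emb_dim))
```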
@@ -1,5 +1,5 @@ | |||||
#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 & | #python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 & | ||||
#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 & | #python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 & | ||||
#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log & | |||||
python -u train.py --task cls --ds sst --mode train --gpu 0 --lr 1e-4 --w_decay 5e-5 --lr_decay 1.0 --drop 0.4 --ep 20 --bsz 64 > sst_cls.log & | |||||
#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log & | #python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log & | ||||
python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & | |||||
#python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & |
@@ -1,4 +1,6 @@ | |||||
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args | from util import get_argparser, set_gpu, set_rng_seeds, add_model_args | ||||
seed = set_rng_seeds(15360) | |||||
print('RNG SEED {}'.format(seed)) | |||||
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN | from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch | import torch | ||||
@@ -7,8 +9,9 @@ import fastNLP as FN | |||||
from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls | from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls | ||||
from fastNLP.core.const import Const as C | from fastNLP.core.const import Const as C | ||||
import sys | import sys | ||||
sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') | |||||
#sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') | |||||
import os | |||||
pre_dir = os.path.join(os.environ['HOME'], 'workdir/datasets/') | |||||
g_model_select = { | g_model_select = { | ||||
'pos': STSeqLabel, | 'pos': STSeqLabel, | ||||
@@ -17,8 +20,8 @@ g_model_select = { | |||||
'nli': STNLICls, | 'nli': STNLICls, | ||||
} | } | ||||
g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt', | |||||
'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'} | |||||
g_emb_file_path = {'en': pre_dir + 'word_vector/glove.840B.300d.txt', | |||||
'zh': pre_dir + 'cc.zh.300.vec'} | |||||
g_args = None | g_args = None | ||||
g_model_cfg = None | g_model_cfg = None | ||||
@@ -53,7 +56,7 @@ def get_conll2012_ner(): | |||||
def get_sst(): | def get_sst(): | ||||
path = '/remote-home/yfshao/workdir/datasets/SST' | |||||
path = pre_dir + 'SST' | |||||
files = ['train.txt', 'dev.txt', 'test.txt'] | files = ['train.txt', 'dev.txt', 'test.txt'] | ||||
return load_sst(path, files) | return load_sst(path, files) | ||||
@@ -94,6 +97,7 @@ class MyCallback(FN.core.callback.Callback): | |||||
nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0) | nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0) | ||||
def on_step_end(self): | def on_step_end(self): | ||||
return | |||||
warm_steps = 6000 | warm_steps = 6000 | ||||
# learning rate warm-up & decay | # learning rate warm-up & decay | ||||
if self.step <= warm_steps: | if self.step <= warm_steps: | ||||
@@ -108,12 +112,11 @@ class MyCallback(FN.core.callback.Callback): | |||||
def train(): | def train(): | ||||
seed = set_rng_seeds(1234) | |||||
print('RNG SEED {}'.format(seed)) | |||||
print('loading data') | print('loading data') | ||||
ds_list, word_v, tag_v = g_datasets['{}-{}'.format( | ds_list, word_v, tag_v = g_datasets['{}-{}'.format( | ||||
g_args.ds, g_args.task)]() | g_args.ds, g_args.task)]() | ||||
print(ds_list[0][:2]) | print(ds_list[0][:2]) | ||||
print(len(ds_list[0]), len(ds_list[1]), len(ds_list[2])) | |||||
embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en') | embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en') | ||||
g_model_cfg['num_cls'] = len(tag_v) | g_model_cfg['num_cls'] = len(tag_v) | ||||
print(g_model_cfg) | print(g_model_cfg) | ||||
@@ -123,11 +126,14 @@ def train(): | |||||
def init_model(model): | def init_model(model): | ||||
for p in model.parameters(): | for p in model.parameters(): | ||||
if p.size(0) != len(word_v): | if p.size(0) != len(word_v): | ||||
nn.init.normal_(p, 0.0, 0.05) | |||||
if len(p.size())<2: | |||||
nn.init.constant_(p, 0.0) | |||||
else: | |||||
nn.init.normal_(p, 0.0, 0.05) | |||||
init_model(model) | init_model(model) | ||||
train_data = ds_list[0] | train_data = ds_list[0] | ||||
dev_data = ds_list[2] | |||||
test_data = ds_list[1] | |||||
dev_data = ds_list[1] | |||||
test_data = ds_list[2] | |||||
print(tag_v.word2idx) | print(tag_v.word2idx) | ||||
if g_args.task in ['pos', 'ner']: | if g_args.task in ['pos', 'ner']: | ||||
@@ -145,19 +151,31 @@ def train(): | |||||
} | } | ||||
metric_key, metric = metrics[g_args.task] | metric_key, metric = metrics[g_args.task] | ||||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||||
ex_param = [x for x in model.parameters( | |||||
) if x.requires_grad and x.size(0) != len(word_v)] | |||||
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||||
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | |||||
trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss, | |||||
batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric, | |||||
metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False, | |||||
device=device, callbacks=[MyCallback()]) | |||||
trainer.train() | |||||
params = [(x,y) for x,y in list(model.named_parameters()) if y.requires_grad and y.size(0) != len(word_v)] | |||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] | |||||
print([n for n,p in params]) | |||||
optim_cfg = [ | |||||
#{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||||
{'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 1.0*g_args.w_decay}, | |||||
{'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 0.0*g_args.w_decay} | |||||
] | |||||
print(model) | |||||
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=loss, metrics=metric, metric_key=metric_key, | |||||
optimizer=torch.optim.Adam(optim_cfg), | |||||
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=100, validate_every=1000, | |||||
device=device, | |||||
use_tqdm=False, prefetch=False, | |||||
save_path=g_args.log, | |||||
sampler=FN.BucketSampler(100, g_args.bsz, C.INPUT_LEN), | |||||
callbacks=[MyCallback()]) | |||||
print(trainer.train()) | |||||
tester = FN.Tester(data=test_data, model=model, metrics=metric, | tester = FN.Tester(data=test_data, model=model, metrics=metric, | ||||
batch_size=128, device=device) | batch_size=128, device=device) | ||||
tester.test() | |||||
print(tester.test()) | |||||
def test(): | def test(): | ||||
@@ -195,12 +213,12 @@ def main(): | |||||
'init_embed': (None, 300), | 'init_embed': (None, 300), | ||||
'num_cls': None, | 'num_cls': None, | ||||
'hidden_size': g_args.hidden, | 'hidden_size': g_args.hidden, | ||||
'num_layers': 4, | |||||
'num_layers': 2, | |||||
'num_head': g_args.nhead, | 'num_head': g_args.nhead, | ||||
'head_dim': g_args.hdim, | 'head_dim': g_args.hdim, | ||||
'max_len': MAX_LEN, | 'max_len': MAX_LEN, | ||||
'cls_hidden_size': 600, | |||||
'emb_dropout': 0.3, | |||||
'cls_hidden_size': 200, | |||||
'emb_dropout': g_args.drop, | |||||
'dropout': g_args.drop, | 'dropout': g_args.drop, | ||||
} | } | ||||
run_select[g_args.mode.lower()]() | run_select[g_args.mode.lower()]() | ||||
@@ -0,0 +1,12 @@ | |||||
{ | |||||
"n_layers": 16, | |||||
"layer_sum": false, | |||||
"layer_cat": false, | |||||
"lstm_hidden_size": 300, | |||||
"ffn_inner_hidden_size": 2048, | |||||
"n_head": 6, | |||||
"recurrent_dropout_prob": 0.1, | |||||
"atten_dropout_prob": 0.1, | |||||
"ffn_dropout_prob": 0.1, | |||||
"fix_mask": true | |||||
} |
@@ -0,0 +1,3 @@ | |||||
{ | |||||
} |
@@ -0,0 +1,9 @@ | |||||
{ | |||||
"n_layers": 12, | |||||
"hidden_size": 512, | |||||
"ffn_inner_hidden_size": 2048, | |||||
"n_head": 8, | |||||
"recurrent_dropout_prob": 0.1, | |||||
"atten_dropout_prob": 0.1, | |||||
"ffn_dropout_prob": 0.1 | |||||
} |
@@ -0,0 +1,188 @@ | |||||
import pickle | |||||
import numpy as np | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.base_loader import DataBundle | |||||
from fastNLP.io.dataset_loader import JsonLoader | |||||
from fastNLP.core.const import Const | |||||
from tools.logger import * | |||||
WORD_PAD = "[PAD]" | |||||
WORD_UNK = "[UNK]" | |||||
DOMAIN_UNK = "X" | |||||
TAG_UNK = "X" | |||||
class SummarizationLoader(JsonLoader): | |||||
""" | |||||
读取summarization数据集,读取的DataSet包含fields:: | |||||
text: list(str),document | |||||
summary: list(str), summary | |||||
text_wd: list(list(str)),tokenized document | |||||
summary_wd: list(list(str)), tokenized summary | |||||
labels: list(int), | |||||
flatten_label: list(int), 0 or 1, flatten labels | |||||
domain: str, optional | |||||
tag: list(str), optional | |||||
数据来源: CNN_DailyMail Newsroom DUC | |||||
""" | |||||
def __init__(self): | |||||
super(SummarizationLoader, self).__init__() | |||||
def _load(self, path): | |||||
ds = super(SummarizationLoader, self)._load(path) | |||||
def _lower_text(text_list): | |||||
return [text.lower() for text in text_list] | |||||
def _split_list(text_list): | |||||
return [text.split() for text in text_list] | |||||
def _convert_label(label, sent_len): | |||||
np_label = np.zeros(sent_len, dtype=int) | |||||
if label != []: | |||||
np_label[np.array(label)] = 1 | |||||
return np_label.tolist() | |||||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||||
return ds | |||||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab=True): | |||||
""" | |||||
:param paths: dict, path for each dataset split
:param vocab_size: int, max size of the vocabulary
:param vocab_path: str, path to the vocab file
:param sent_max_len: int, max number of tokens in a sentence
:param doc_max_timesteps: int, max number of sentences in a document
:param domain: bool, build a vocab for the publication field, using 'X' for unknown
:param tag: bool, build a vocab for the tag field, using 'X' for unknown
:param load_vocab: bool, load an existing vocab from vocab_path (True) or build a new one from the training set (False)
:return: DataBundle
    datasets: dict, keys correspond to the keys of the paths dict
    vocabs: dict, keys: vocab (if "train" in paths), domain (if domain=True), tag (if tag=True)
    embeddings: optional
""" | |||||
def _pad_sent(text_wd): | |||||
pad_text_wd = [] | |||||
for sent_wd in text_wd: | |||||
if len(sent_wd) < sent_max_len: | |||||
pad_num = sent_max_len - len(sent_wd) | |||||
sent_wd.extend([WORD_PAD] * pad_num) | |||||
else: | |||||
sent_wd = sent_wd[:sent_max_len] | |||||
pad_text_wd.append(sent_wd) | |||||
return pad_text_wd | |||||
def _token_mask(text_wd): | |||||
token_mask_list = [] | |||||
for sent_wd in text_wd: | |||||
token_num = len(sent_wd) | |||||
if token_num < sent_max_len: | |||||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||||
else: | |||||
mask = [1] * sent_max_len | |||||
token_mask_list.append(mask) | |||||
return token_mask_list | |||||
def _pad_label(label): | |||||
text_len = len(label) | |||||
if text_len < doc_max_timesteps: | |||||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_label = label[:doc_max_timesteps] | |||||
return pad_label | |||||
def _pad_doc(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
padding = [WORD_PAD] * sent_max_len | |||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_text = text_wd[:doc_max_timesteps] | |||||
return pad_text | |||||
def _sent_mask(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
sent_mask = [1] * doc_max_timesteps | |||||
return sent_mask | |||||
datasets = {} | |||||
train_ds = None | |||||
for key, value in paths.items(): | |||||
ds = self.load(value) | |||||
# pad sent | |||||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||||
# pad document | |||||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||||
# rename field | |||||
ds.rename_field("pad_text", Const.INPUT) | |||||
ds.rename_field("seq_len", Const.INPUT_LEN) | |||||
ds.rename_field("pad_label", Const.TARGET) | |||||
# set input and target | |||||
ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
datasets[key] = ds | |||||
if "train" in key: | |||||
train_ds = datasets[key] | |||||
vocab_dict = {} | |||||
if not load_vocab:
    logger.info("[INFO] Build new vocab from training dataset!")
    if train_ds is None:
        raise ValueError("Lack train file to build vocabulary!")
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||||
vocab_dict["vocab"] = vocabs | |||||
else: | |||||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||||
word_list = [] | |||||
with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||||
cnt = 2 # pad and unk | |||||
for line in vocab_f: | |||||
pieces = line.split("\t") | |||||
word_list.append(pieces[0]) | |||||
cnt += 1 | |||||
if cnt > vocab_size: | |||||
break | |||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.add_word_lst(word_list) | |||||
vocabs.build_vocab() | |||||
vocab_dict["vocab"] = vocabs | |||||
if domain:
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||||
domaindict.from_dataset(train_ds, field_name="publication") | |||||
vocab_dict["domain"] = domaindict | |||||
if tag:
tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||||
tagdict.from_dataset(train_ds, field_name="tag") | |||||
vocab_dict["tag"] = tagdict | |||||
for ds in datasets.values(): | |||||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||||
return DataBundle(vocabs=vocab_dict, datasets=datasets) | |||||
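For reference, a minimal usage sketch of `SummarizationLoader.process`; the JSONL paths, vocab file and size limits below are placeholders rather than values taken from this repository, and the field layout of the JSONL files is assumed to follow the docstring above.

```python
# Hypothetical usage of SummarizationLoader (paths and size limits are placeholders).
loader = SummarizationLoader()
paths = {"train": "data/train.label.jsonl",
         "val": "data/val.label.jsonl",
         "test": "data/test.label.jsonl"}
data_bundle = loader.process(paths=paths,
                             vocab_size=100000,
                             vocab_path="data/vocab",
                             sent_max_len=100,
                             doc_max_timesteps=50,
                             load_vocab=False)   # build the vocab from the training split
train_set = data_bundle.datasets["train"]
vocab = data_bundle.vocabs["vocab"]
print(len(train_set), len(vocab))
```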
@@ -0,0 +1,136 @@ | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.init as init | |||||
import torch.nn.functional as F | |||||
from torch.autograd import Variable | |||||
from torch.distributions import Bernoulli | |||||
class DeepLSTM(nn.Module): | |||||
def __init__(self, input_size, hidden_size, num_layers, recurrent_dropout, use_orthnormal_init=True, fix_mask=True, use_cuda=True): | |||||
super(DeepLSTM, self).__init__() | |||||
self.fix_mask = fix_mask | |||||
self.use_cuda = use_cuda | |||||
self.input_size = input_size | |||||
self.num_layers = num_layers | |||||
self.hidden_size = hidden_size | |||||
self.recurrent_dropout = recurrent_dropout | |||||
self.lstms = nn.ModuleList([None] * self.num_layers) | |||||
self.highway_gate_input = nn.ModuleList([None] * self.num_layers) | |||||
self.highway_gate_state = nn.ModuleList([nn.Linear(hidden_size, hidden_size)] * self.num_layers) | |||||
self.highway_linear_input = nn.ModuleList([None] * self.num_layers) | |||||
# self._input_w = nn.Parameter(torch.Tensor(input_size, hidden_size)) | |||||
# init.xavier_normal_(self._input_w) | |||||
for l in range(self.num_layers): | |||||
input_dim = input_size if l == 0 else hidden_size | |||||
self.lstms[l] = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_size) | |||||
self.highway_gate_input[l] = nn.Linear(input_dim, hidden_size) | |||||
self.highway_linear_input[l] = nn.Linear(input_dim, hidden_size, bias=False) | |||||
# logger.info("[INFO] Initing W for LSTM .......") | |||||
for l in range(self.num_layers): | |||||
if use_orthnormal_init: | |||||
# logger.info("[INFO] Initing W using orthnormal init .......") | |||||
init.orthogonal_(self.lstms[l].weight_ih) | |||||
init.orthogonal_(self.lstms[l].weight_hh) | |||||
init.orthogonal_(self.highway_gate_input[l].weight.data) | |||||
init.orthogonal_(self.highway_gate_state[l].weight.data) | |||||
init.orthogonal_(self.highway_linear_input[l].weight.data) | |||||
else: | |||||
# logger.info("[INFO] Initing W using xavier_normal .......") | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(self.lstms[l].weight_ih, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.lstms[l].weight_hh, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.highway_gate_input[l].weight.data, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.highway_gate_state[l].weight.data, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.highway_linear_input[l].weight.data, gain=np.sqrt(init_weight_value)) | |||||
def init_hidden(self, batch_size, hidden_size): | |||||
# the first is the hidden h | |||||
# the second is the cell c | |||||
if self.use_cuda: | |||||
return (torch.zeros(batch_size, hidden_size).cuda(), | |||||
torch.zeros(batch_size, hidden_size).cuda()) | |||||
else: | |||||
return (torch.zeros(batch_size, hidden_size), | |||||
torch.zeros(batch_size, hidden_size)) | |||||
def forward(self, inputs, input_masks, Train): | |||||
''' | |||||
inputs: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list) | |||||
input_masks: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list) | |||||
''' | |||||
batch_size, seq_len = inputs[0].size(1), inputs[0].size(0) | |||||
# inputs[0] = torch.matmul(inputs[0], self._input_w) | |||||
# input_masks[0] = input_masks[0].unsqueeze(-1).expand(seq_len, batch_size, self.hidden_size) | |||||
self.inputs = inputs | |||||
self.input_masks = input_masks | |||||
if self.fix_mask: | |||||
self.output_dropout_layers = [None] * self.num_layers | |||||
for l in range(self.num_layers): | |||||
binary_mask = torch.rand((batch_size, self.hidden_size)) > self.recurrent_dropout | |||||
# This scaling ensures expected values and variances of the output of applying this mask and the original tensor are the same. | |||||
# from allennlp.nn.util.py | |||||
self.output_dropout_layers[l] = binary_mask.float().div(1.0 - self.recurrent_dropout) | |||||
if self.use_cuda: | |||||
self.output_dropout_layers[l] = self.output_dropout_layers[l].cuda() | |||||
for l in range(self.num_layers): | |||||
h, c = self.init_hidden(batch_size, self.hidden_size) | |||||
outputs_list = [] | |||||
for t in range(len(self.inputs[l])): | |||||
x = self.inputs[l][t] | |||||
m = self.input_masks[l][t].float() | |||||
h_temp, c_temp = self.lstms[l].forward(x, (h, c)) # [batch, hidden_size] | |||||
r = torch.sigmoid(self.highway_gate_input[l](x) + self.highway_gate_state[l](h)) | |||||
lx = self.highway_linear_input[l](x) # [batch, hidden_size] | |||||
h_temp = r * h_temp + (1 - r) * lx | |||||
if Train: | |||||
if self.fix_mask: | |||||
h_temp = self.output_dropout_layers[l] * h_temp | |||||
else: | |||||
h_temp = F.dropout(h_temp, p=self.recurrent_dropout) | |||||
h = m * h_temp + (1 - m) * h | |||||
c = m * c_temp + (1 - m) * c | |||||
outputs_list.append(h) | |||||
outputs = torch.stack(outputs_list, 0) # [seq_len, batch, hidden_size] | |||||
self.inputs[l + 1] = DeepLSTM.flip(outputs, 0) # reverse [seq_len, batch, hidden_size] | |||||
self.input_masks[l + 1] = DeepLSTM.flip(self.input_masks[l], 0) | |||||
self.output_state = self.inputs # num_layers * [seq_len, batch, hidden_size] | |||||
# flip -2 layer | |||||
# self.output_state[-2] = DeepLSTM.flip(self.output_state[-2], 0) | |||||
# concat last two layer | |||||
# self.output_state = torch.cat([self.output_state[-1], self.output_state[-2]], dim=-1).transpose(0, 1) | |||||
self.output_state = self.output_state[-1].transpose(0, 1) | |||||
assert self.output_state.size() == (batch_size, seq_len, self.hidden_size) | |||||
return self.output_state | |||||
@staticmethod | |||||
def flip(x, dim): | |||||
xsize = x.size() | |||||
dim = x.dim() + dim if dim < 0 else dim | |||||
x = x.contiguous() | |||||
x = x.view(-1, *xsize[dim:]).contiguous() | |||||
x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, | |||||
-1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] | |||||
return x.view(xsize) |
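As a sanity check of the calling convention described in the forward() docstring above, here is a small CPU-only sketch; all sizes are arbitrary and chosen only for illustration.

```python
# CPU-only sketch of the DeepLSTM calling convention (arbitrary sizes).
import torch

seq_len, batch, feat, hidden, n_layers = 7, 2, 16, 32, 3
lstm = DeepLSTM(input_size=feat, hidden_size=hidden, num_layers=n_layers,
                recurrent_dropout=0.1, use_orthnormal_init=True,
                fix_mask=True, use_cuda=False)

inputs = [None] * (n_layers + 1)
input_masks = [None] * (n_layers + 1)
inputs[0] = torch.randn(seq_len, batch, feat)    # [seq_len, batch, feature]
input_masks[0] = torch.ones(seq_len, batch, 1)   # 1 for real positions, 0 for padding

out = lstm(inputs, input_masks, Train=True)      # [batch, seq_len, hidden]
print(out.shape)
```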
@@ -0,0 +1,566 @@ | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import torch.nn.init as init | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from tools.logger import *  # logger is used by BertEncoder / ELMoEndoer below
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||||
WORD_PAD = "[PAD]" | |||||
class Encoder(nn.Module): | |||||
def __init__(self, hps, embed): | |||||
""" | |||||
:param hps: | |||||
word_emb_dim: word embedding dimension | |||||
sent_max_len: max token number in the sentence | |||||
output_channel: output channel for cnn | |||||
min_kernel_size: min kernel size for cnn | |||||
max_kernel_size: max kernel size for cnn | |||||
word_embedding: bool, use word embedding or not | |||||
embedding_path: word embedding path | |||||
embed_train: bool, whether to train word embedding | |||||
cuda: bool, use cuda or not | |||||
:param embed: word embedding module (e.g. torch.nn.Embedding)
""" | |||||
super(Encoder, self).__init__() | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self.embed = embed | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
print("[INFO] Initializing W for CNN.......")
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
print("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len] | |||||
batch_size, N, _ = input.size() | |||||
input = input.view(-1, input.size(2)) # [batch_size*N, L] | |||||
input_sent_len = ((input!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||||
enc_embed_input = self.embed(input) # [batch_size*N, L, D] | |||||
input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class DomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(DomainEncoder, self).__init__(hps, vocab) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
sent_embedding = super().forward(input) | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class MultiDomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(MultiDomainEncoder, self).__init__(hps, vocab) | |||||
self.domain_size = domaindict.size() | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size, domain_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
# logger.info(domain[:5, :]) | |||||
sent_embedding = super().forward(input) | |||||
domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1) | |||||
domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size] | |||||
enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D] | |||||
enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D] | |||||
# logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D] | |||||
enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, hps): | |||||
super(BertEncoder, self).__init__() | |||||
from pytorch_pretrained_bert.modeling import BertModel | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
self._cuda = hps.cuda | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16") | |||||
self._bert.eval() | |||||
for p in self._bert.parameters(): | |||||
p.requires_grad = False | |||||
self.word_embedding_proj = nn.Linear(4096, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initializing W for CNN.......")
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
""" | |||||
# Use Bert to get word embedding | |||||
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
for i in range(batch_size): | |||||
tokens_id = inputs[i] | |||||
input_mask = input_masks[i] | |||||
sent_len = enc_sent_len[i] | |||||
input_ids = tokens_id.unsqueeze(0) | |||||
input_mask = input_mask.unsqueeze(0) | |||||
out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask) | |||||
out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 1 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 511: | |||||
enc_sent_input.append(out[last_end: min(511, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D] | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class BertTagEncoder(BertEncoder): | |||||
def __init__(self, hps, domaindict): | |||||
super(BertTagEncoder, self).__init__(hps) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, inputs, input_masks, enc_sent_len, domain): | |||||
sent_embedding = super().forward(inputs, input_masks, enc_sent_len) | |||||
batch_size, N = enc_sent_len.size() | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class ELMoEndoer(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer, self).__init__() | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initializing W for CNN.......")
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len, character_len] | |||||
batch_size, N, seq_len, _ = input.size() | |||||
input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len] | |||||
input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||||
# logger.debug(input_sent_len.view(batch_size, -1)) | |||||
enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D] | |||||
enc_embed_input = self.embed_proj(enc_embed_input) | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in input_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class ELMoEndoer2(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer2, self).__init__() | |||||
self._hps = hps | |||||
self._cuda = hps.cuda | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initializing W for CNN.......")
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
sent_embedding: [batch, N, D] | |||||
""" | |||||
# Use ELMo to get word embeddings
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D] | |||||
elmo_output = elmo_output * input_masks.unsqueeze(-1).float() | |||||
# print("END elmo") | |||||
for i in range(batch_size): | |||||
sent_len = enc_sent_len[i] # [1, N] | |||||
out = elmo_output[i] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 0 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 512: | |||||
enc_sent_input.append(out[last_end : min(512, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state] | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D] | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding |
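The base CNN sentence Encoder at the top of this file can be exercised on its own; below is a minimal CPU sketch in which `hps` is a bare Namespace carrying only the fields the class reads, with illustrative values.

```python
# Minimal CPU sketch of the CNN sentence Encoder (hps fields and sizes are illustrative).
from argparse import Namespace
import torch
import torch.nn as nn

hps = Namespace(word_emb_dim=64, sent_max_len=20, output_channel=8,
                min_kernel_size=1, max_kernel_size=3, cuda=False)
embed = nn.Embedding(5000, hps.word_emb_dim, padding_idx=0)
encoder = Encoder(hps, embed)

words = torch.randint(1, 5000, (4, 10, hps.sent_max_len))  # [batch=4, N=10, seq_len=20]
sent_embedding = encoder(words)                            # [4, 10, output_channel * n_kernel_sizes]
print(sent_embedding.shape)                                # torch.Size([4, 10, 24])
```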
@@ -0,0 +1,103 @@ | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import * | |||||
from torch.distributions import * | |||||
from .Encoder import Encoder | |||||
from .DeepLSTM import DeepLSTM | |||||
from transformer.SubLayers import MultiHeadAttention,PositionwiseFeedForward | |||||
class SummarizationModel(nn.Module): | |||||
def __init__(self, hps, embed): | |||||
""" | |||||
:param hps: hyperparameters for the model | |||||
:param embed: word embedding module
""" | |||||
super(SummarizationModel, self).__init__() | |||||
self._hps = hps | |||||
# sentence encoder | |||||
self.encoder = Encoder(hps, embed) | |||||
# Multi-layer highway lstm | |||||
self.num_layers = hps.n_layers | |||||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||||
self.lstm_hidden_size = hps.lstm_hidden_size | |||||
self.recurrent_dropout = hps.recurrent_dropout_prob | |||||
self.deep_lstm = DeepLSTM(self.sent_embedding_size, self.lstm_hidden_size, self.num_layers, self.recurrent_dropout, | |||||
hps.use_orthnormal_init, hps.fix_mask, hps.cuda) | |||||
# Multi-head attention | |||||
self.n_head = hps.n_head | |||||
self.d_v = self.d_k = int(self.lstm_hidden_size / hps.n_head) | |||||
self.d_inner = hps.ffn_inner_hidden_size | |||||
self.slf_attn = MultiHeadAttention(hps.n_head, self.lstm_hidden_size , self.d_k, self.d_v, dropout=hps.atten_dropout_prob) | |||||
self.pos_ffn = PositionwiseFeedForward(self.d_v, self.d_inner, dropout = hps.ffn_dropout_prob) | |||||
self.wh = nn.Linear(self.d_v, 2) | |||||
def forward(self, input, input_len, Train): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], word idx long tensor | |||||
:param input_len: [batch_size, N], 1 for sentence and 0 for padding | |||||
:param Train: True for train and False for eval and test | |||||
:return: | |||||
p_sent: [batch_size, N, 2] | |||||
output_slf_attn: (option) [n_head, batch_size, N, N] | |||||
""" | |||||
# -- Sentence Encoder | |||||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||||
# -- Multi-layer highway lstm | |||||
input_len = input_len.float() # [batch, N] | |||||
self.inputs = [None] * (self.num_layers + 1) | |||||
self.input_masks = [None] * (self.num_layers + 1) | |||||
self.inputs[0] = self.sent_embedding.permute(1, 0, 2) # [N, batch, Co * kernel_sizes] | |||||
self.input_masks[0] = input_len.permute(1, 0).unsqueeze(2) | |||||
self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train) # [batch, N, hidden_size] | |||||
# -- Prepare masks | |||||
batch_size, N = input_len.size() | |||||
slf_attn_mask = input_len.eq(0.0) # [batch, N], 1 for padding | |||||
slf_attn_mask = slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||||
# -- Multi-head attention | |||||
self.atten_output, self.output_slf_attn = self.slf_attn(self.lstm_output_state, self.lstm_output_state, self.lstm_output_state, mask=slf_attn_mask) | |||||
self.atten_output *= input_len.unsqueeze(2) # [batch_size, N, lstm_hidden_size = (n_head * d_v)] | |||||
self.multi_atten_output = self.atten_output.view(batch_size, N, self.n_head, self.d_v) # [batch_size, N, n_head, d_v] | |||||
self.multi_atten_context = self.multi_atten_output[:, :, 0::2, :].sum(2) - self.multi_atten_output[:, :, 1::2, :].sum(2) # [batch_size, N, d_v] | |||||
# -- Position-wise Feed-Forward Networks | |||||
self.output_state = self.pos_ffn(self.multi_atten_context) | |||||
self.output_state = self.output_state * input_len.unsqueeze(2) # [batch_size, N, d_v] | |||||
p_sent = self.wh(self.output_state) # [batch, N, 2] | |||||
idx = None | |||||
if self._hps.m == 0: | |||||
prediction = p_sent.view(-1, 2).max(1)[1] | |||||
prediction = prediction.view(batch_size, -1) | |||||
else: | |||||
mask_output = torch.exp(p_sent[:, :, 1]) # # [batch, N] | |||||
mask_output = mask_output.masked_fill(input_len.eq(0), 0) | |||||
topk, idx = torch.topk(mask_output, self._hps.m) | |||||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||||
prediction = prediction.long().view(batch_size, -1) | |||||
if self._hps.cuda: | |||||
prediction = prediction.cuda() | |||||
return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx} |
@@ -0,0 +1,55 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from fastNLP.core.losses import LossBase | |||||
from tools.logger import * | |||||
class MyCrossEntropyLoss(LossBase): | |||||
    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='none'):
        # reduce defaults to 'none' so get_loss receives per-element losses, which are masked and averaged below
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, mask=mask) | |||||
self.padding_idx = padding_idx | |||||
self.reduce = reduce | |||||
def get_loss(self, pred, target, mask): | |||||
""" | |||||
:param pred: [batch, N, 2] | |||||
:param target: [batch, N] | |||||
:param mask: [batch, N], 1 for real sentences and 0 for padding
:return: | |||||
""" | |||||
# logger.debug(pred[0:5, :, :]) | |||||
# logger.debug(target[0:5, :]) | |||||
batch, N, _ = pred.size() | |||||
pred = pred.view(-1, 2) | |||||
target = target.view(-1) | |||||
loss = F.cross_entropy(input=pred, target=target, | |||||
ignore_index=self.padding_idx, reduction=self.reduce) | |||||
loss = loss.view(batch, -1) | |||||
loss = loss.masked_fill(mask.eq(0), 0) | |||||
loss = loss.sum(1).mean() | |||||
logger.debug("loss %f", loss) | |||||
return loss | |||||
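A toy check of the loss on hand-made tensors; `reduce='none'` is passed explicitly so that get_loss sees per-element losses, which is what the masking inside get_loss expects. The values are made up for illustration.

```python
# Toy check of MyCrossEntropyLoss.get_loss with the shapes from the docstring above.
import torch

loss_fn = MyCrossEntropyLoss(reduce='none')     # per-element losses, masked and averaged in get_loss
pred = torch.randn(2, 5, 2)                     # [batch=2, N=5, 2]
target = torch.randint(0, 2, (2, 5))            # [batch, N]
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])          # 1 for real sentences, 0 for padding
print(loss_fn.get_loss(pred, target, mask))
```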
@@ -0,0 +1,171 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
from __future__ import division | |||||
import torch | |||||
from rouge import Rouge | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.core.metrics import MetricBase | |||||
from tools.logger import * | |||||
from tools.utils import pyrouge_score_all, pyrouge_score_all_multi | |||||
class LabelFMetric(MetricBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
self.match = 0.0 | |||||
self.pred = 0.0 | |||||
self.true = 0.0 | |||||
self.match_true = 0.0 | |||||
self.total = 0.0 | |||||
def evaluate(self, pred, target): | |||||
""" | |||||
:param pred: [batch, N] int | |||||
:param target: [batch, N] int | |||||
:return: | |||||
""" | |||||
target = target.data | |||||
pred = pred.data | |||||
# logger.debug(pred.size()) | |||||
# logger.debug(pred[:5,:]) | |||||
batch, N = pred.size() | |||||
self.pred += pred.sum() | |||||
self.true += target.sum() | |||||
self.match += (pred == target).sum() | |||||
self.match_true += ((pred == target) & (pred == 1)).sum() | |||||
self.total += batch * N | |||||
def get_metric(self, reset=True): | |||||
self.match,self.pred, self.true, self.match_true, self.total = self.match.float(),self.pred.float(), self.true.float(), self.match_true.float(), self.total | |||||
logger.debug((self.match,self.pred, self.true, self.match_true, self.total)) | |||||
try: | |||||
accu = self.match / self.total | |||||
precision = self.match_true / self.pred | |||||
recall = self.match_true / self.true | |||||
F = 2 * precision * recall / (precision + recall) | |||||
except ZeroDivisionError: | |||||
F = 0.0 | |||||
logger.error("[Error] float division by zero") | |||||
if reset: | |||||
self.pred, self.true, self.match_true, self.match, self.total = 0, 0, 0, 0, 0 | |||||
ret = {"accu": accu.cpu(), "p":precision.cpu(), "r":recall.cpu(), "f": F.cpu()} | |||||
logger.info(ret) | |||||
return ret | |||||
class RougeMetric(MetricBase): | |||||
def __init__(self, hps, pred=None, text=None, refer=None): | |||||
super().__init__() | |||||
self._hps = hps | |||||
self._init_param_map(pred=pred, text=text, summary=refer) | |||||
self.hyps = [] | |||||
self.refers = [] | |||||
def evaluate(self, pred, text, summary): | |||||
""" | |||||
:param pred: [batch, N], predicted sentence labels
:param text: original article sentences for each example in the batch
:param summary: reference summary sentences for each example in the batch
:return: | |||||
""" | |||||
batch_size, N = pred.size() | |||||
for j in range(batch_size): | |||||
original_article_sents = text[j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(summary[j]) | |||||
hyps = "\n".join(original_article_sents[id] for id in range(len(pred[j])) if | |||||
pred[j][id] == 1 and id < sent_max_number) | |||||
if sent_max_number < self._hps.m and len(hyps) <= 1: | |||||
print("sent_max_number is too short %d, Skip!" % sent_max_number)
continue | |||||
if len(hyps) >= 1 and hyps != '.': | |||||
self.hyps.append(hyps) | |||||
self.refers.append(refer) | |||||
elif refer == "." or refer == "": | |||||
logger.error("Refer is None!") | |||||
logger.debug(refer) | |||||
elif hyps == "." or hyps == "": | |||||
logger.error("hyps is None!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug("pred:") | |||||
logger.debug(pred[j]) | |||||
logger.debug(hyps) | |||||
else: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug(refer) | |||||
continue | |||||
def get_metric(self, reset=True): | |||||
pass | |||||
class FastRougeMetric(RougeMetric): | |||||
def __init__(self, hps, pred=None, text=None, refer=None): | |||||
super().__init__(hps, pred, text, refer) | |||||
def get_metric(self, reset=True): | |||||
logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers)) | |||||
if len(self.hyps) == 0 or len(self.refers) == 0 : | |||||
logger.error("During testing, no hyps or refers is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(self.hyps, self.refers, avg=True) | |||||
if reset: | |||||
self.hyps = [] | |||||
self.refers = [] | |||||
logger.info(scores_all) | |||||
return scores_all | |||||
class PyRougeMetric(RougeMetric): | |||||
def __init__(self, hps, pred=None, text=None, refer=None): | |||||
super().__init__(hps, pred, text, refer) | |||||
def get_metric(self, reset=True): | |||||
logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers)) | |||||
if len(self.hyps) == 0 or len(self.refers) == 0: | |||||
logger.error("During testing, no hyps or refers is selected!") | |||||
return | |||||
if isinstance(self.refers[0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = pyrouge_score_all_multi(self.hyps, self.refers) | |||||
else: | |||||
scores_all = pyrouge_score_all(self.hyps, self.refers) | |||||
if reset: | |||||
self.hyps = [] | |||||
self.refers = [] | |||||
logger.info(scores_all) | |||||
return scores_all | |||||
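LabelFMetric can be checked in isolation on small integer tensors; the values below are made up purely to illustrate how counts accumulate across evaluate() calls before get_metric() is read out.

```python
# Toy check of LabelFMetric (values are made up for illustration).
import torch

metric = LabelFMetric()
pred = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 0]])    # [batch, N] predicted sentence labels
target = torch.tensor([[1, 0, 0, 0], [0, 1, 1, 0]])  # [batch, N] gold sentence labels
metric.evaluate(pred, target)
print(metric.get_metric())   # {'accu': ..., 'p': ..., 'r': ..., 'f': ...}
```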
@@ -0,0 +1,143 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from .Encoder import Encoder | |||||
# from tools.Encoder import Encoder | |||||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||||
from tools.logger import * | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from transformer.Layers import EncoderLayer | |||||
class TransformerModel(nn.Module): | |||||
def __init__(self, hps, embed): | |||||
""" | |||||
:param hps: | |||||
min_kernel_size: min kernel size for cnn encoder | |||||
max_kernel_size: max kernel size for cnn encoder | |||||
output_channel: output_channel number for cnn encoder | |||||
hidden_size: hidden size for transformer | |||||
n_layers: number of transformer encoder layers
n_head: multi head attention for transformer | |||||
ffn_inner_hidden_size: FFN hiddens size | |||||
atten_dropout_prob: dropout size | |||||
doc_max_timesteps: max sentence number of the document | |||||
:param embed: word embedding module
""" | |||||
super(TransformerModel, self).__init__() | |||||
self._hps = hps | |||||
self.encoder = Encoder(hps, embed) | |||||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||||
self.hidden_size = hps.hidden_size | |||||
self.n_head = hps.n_head | |||||
self.d_v = self.d_k = int(self.hidden_size / self.n_head) | |||||
self.d_inner = hps.ffn_inner_hidden_size | |||||
self.num_layers = hps.n_layers | |||||
self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size) | |||||
self.sent_pos_embed = nn.Embedding.from_pretrained( | |||||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||||
self.layer_stack = nn.ModuleList([ | |||||
EncoderLayer(self.hidden_size, self.d_inner, self.n_head, self.d_k, self.d_v, | |||||
dropout=hps.atten_dropout_prob) | |||||
for _ in range(self.num_layers)]) | |||||
self.wh = nn.Linear(self.hidden_size, 2) | |||||
def forward(self, words, seq_len): | |||||
""" | |||||
:param words: [batch_size, N, seq_len]
:param seq_len: [batch_size, N], 1 for a real sentence and 0 for padding
:return: | |||||
""" | |||||
# Sentence Encoder | |||||
input = words | |||||
input_len = seq_len | |||||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||||
input_len = input_len.float() # [batch, N] | |||||
# -- Prepare masks | |||||
batch_size, N = input_len.size() | |||||
self.slf_attn_mask = input_len.eq(0.0) # [batch, N] | |||||
self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||||
self.non_pad_mask = input_len.unsqueeze(-1) # [batch, N, 1] | |||||
input_doc_len = input_len.sum(dim=1).int() # [batch] | |||||
sent_pos = torch.Tensor( | |||||
[np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len]) | |||||
sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long() | |||||
enc_output_state = self.projection(self.sent_embedding) | |||||
enc_input = enc_output_state + self.sent_pos_embed(sent_pos) | |||||
# self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask | |||||
enc_input_list = [] | |||||
for enc_layer in self.layer_stack: | |||||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||||
# enc_slf_attn = [n_head * batch_size, N, N] | |||||
enc_input, enc_slf_atten = enc_layer(enc_input, non_pad_mask=self.non_pad_mask, | |||||
slf_attn_mask=self.slf_attn_mask) | |||||
enc_input_list += [enc_input] | |||||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4 * batch_size, N, hidden_size]
self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1) | |||||
self.dec_output_state = self.dec_output_state.sum(0) | |||||
p_sent = self.wh(self.dec_output_state) # [batch, N, 2] | |||||
idx = None | |||||
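# Two prediction modes: when hps.m == 0, take the per-sentence argmax over the
# 2-class scores; otherwise select the top-m sentences by their (exponentiated)
# positive-class score and mark them with 1 in a 0/1 prediction vector.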
if self._hps.m == 0: | |||||
prediction = p_sent.view(-1, 2).max(1)[1] | |||||
prediction = prediction.view(batch_size, -1) | |||||
else: | |||||
mask_output = torch.exp(p_sent[:, :, 1]) # [batch, N]
mask_output = mask_output.masked_fill(input_len.eq(0), 0) | |||||
topk, idx = torch.topk(mask_output, self._hps.m) | |||||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||||
prediction = prediction.long().view(batch_size, -1) | |||||
if self._hps.cuda: | |||||
prediction = prediction.cuda() | |||||
# logger.debug(((p_sent.size(), prediction.size(), idx.size()))) | |||||
return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx} | |||||
@@ -0,0 +1,138 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from .Encoder import Encoder | |||||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
class TransformerModel(nn.Module): | |||||
def __init__(self, hps, vocab): | |||||
""" | |||||
:param hps: | |||||
min_kernel_size: min kernel size for cnn encoder | |||||
max_kernel_size: max kernel size for cnn encoder | |||||
output_channel: number of output channels for the cnn encoder
hidden_size: hidden size for transformer
n_layers: number of transformer encoder layers
n_head: number of heads in multi-head attention
ffn_inner_hidden_size: hidden size of the position-wise FFN
atten_dropout_prob: attention dropout probability
doc_max_timesteps: max sentence number of the document | |||||
:param vocab: | |||||
""" | |||||
super(TransformerModel, self).__init__() | |||||
self._hps = hps | |||||
self._vocab = vocab | |||||
self.encoder = Encoder(hps, vocab) | |||||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||||
self.hidden_size = hps.hidden_size | |||||
self.n_head = hps.n_head | |||||
self.d_v = self.d_k = int(self.hidden_size / self.n_head) | |||||
self.d_inner = hps.ffn_inner_hidden_size | |||||
self.num_layers = hps.n_layers | |||||
self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size) | |||||
self.sent_pos_embed = nn.Embedding.from_pretrained( | |||||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||||
self.layer_stack = nn.ModuleList([ | |||||
TransformerEncoder.SubLayer(model_size=self.hidden_size, inner_size=self.d_inner, key_size=self.d_k, value_size=self.d_v,num_head=self.n_head, dropout=hps.atten_dropout_prob) | |||||
for _ in range(self.num_layers)]) | |||||
self.wh = nn.Linear(self.hidden_size, 2) | |||||
def forward(self, words, seq_len): | |||||
""" | |||||
:param words: [batch_size, N, seq_len]
:param seq_len: [batch_size, N]
:return: | |||||
""" | |||||
# Sentence Encoder | |||||
input = words | |||||
input_len = seq_len | |||||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||||
input_len = input_len.float() # [batch, N] | |||||
# -- Prepare masks | |||||
batch_size, N = input_len.size() | |||||
self.slf_attn_mask = input_len.eq(0.0) # [batch, N] | |||||
self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||||
self.non_pad_mask = input_len.unsqueeze(-1) # [batch, N, 1] | |||||
input_doc_len = input_len.sum(dim=1).int() # [batch] | |||||
sent_pos = torch.Tensor([np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len]) | |||||
sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long() | |||||
enc_output_state = self.projection(self.sent_embedding) | |||||
enc_input = enc_output_state + self.sent_pos_embed(sent_pos) | |||||
# self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask | |||||
enc_input_list = [] | |||||
for enc_layer in self.layer_stack: | |||||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||||
# enc_slf_attn = [n_head * batch_size, N, N] | |||||
enc_input = enc_layer(enc_input, seq_mask=self.non_pad_mask, atte_mask_out=self.slf_attn_mask) | |||||
enc_input_list += [enc_input] | |||||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4 * batch_size, N, hidden_size]
self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1) | |||||
self.dec_output_state = self.dec_output_state.sum(0) | |||||
p_sent = self.wh(self.dec_output_state) # [batch, N, 2] | |||||
idx = None | |||||
if self._hps.m == 0: | |||||
prediction = p_sent.view(-1, 2).max(1)[1] | |||||
prediction = prediction.view(batch_size, -1) | |||||
else: | |||||
mask_output = torch.exp(p_sent[:, :, 1]) # [batch, N]
mask_output = mask_output * input_len.float() | |||||
topk, idx = torch.topk(mask_output, self._hps.m) | |||||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||||
prediction = prediction.long().view(batch_size, -1) | |||||
if self._hps.cuda: | |||||
prediction = prediction.cuda() | |||||
# print((p_sent.size(), prediction.size(), idx.size())) | |||||
# [batch, N, 2], [batch, N], [batch, hps.m] | |||||
return {"pred": p_sent, "prediction": prediction, "pred_idx": idx} | |||||
@@ -0,0 +1,36 @@ | |||||
import unittest | |||||
import sys | |||||
sys.path.append('..') | |||||
from data.dataloader import SummarizationLoader | |||||
vocab_size = 100000 | |||||
vocab_path = "testdata/vocab" | |||||
sent_max_len = 100 | |||||
doc_max_timesteps = 50 | |||||
class TestSummarizationLoader(unittest.TestCase): | |||||
def test_case1(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train":"testdata/train.jsonl", "valid":"testdata/val.jsonl", "test":"testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps) | |||||
print(data.datasets) | |||||
def test_case2(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, domain=True) | |||||
print(data.datasets, data.vocabs) | |||||
def test_case3(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, tag=True) | |||||
print(data.datasets, data.vocabs) | |||||
@@ -0,0 +1,56 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import os | |||||
import sys | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/') | |||||
from fastNLP.core.const import Const | |||||
from data.dataloader import SummarizationLoader | |||||
from tools.data import ExampleSet, Vocab | |||||
vocab_size = 100000 | |||||
vocab_path = "test/testdata/vocab" | |||||
sent_max_len = 100 | |||||
doc_max_timesteps = 50 | |||||
# paths = {"train": "test/testdata/train.jsonl", "valid": "test/testdata/val.jsonl"} | |||||
paths = {"train": "/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl", "valid": "/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl"} | |||||
sum_loader = SummarizationLoader() | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, load_vocab_file=True) | |||||
trainset = dataInfo.datasets["train"] | |||||
vocab = Vocab(vocab_path, vocab_size) | |||||
dataset = ExampleSet(paths["train"], vocab, doc_max_timesteps, sent_max_len) | |||||
# print(trainset[0]["text"]) | |||||
# print(dataset.get_example(0).original_article_sents) | |||||
# print(trainset[0]["words"]) | |||||
# print(dataset[0][0].numpy().tolist()) | |||||
b_size = len(trainset) | |||||
for i in range(b_size): | |||||
if i <= 7327: | |||||
continue | |||||
print(trainset[i][Const.INPUT]) | |||||
print(dataset[i][0].numpy().tolist()) | |||||
assert trainset[i][Const.INPUT] == dataset[i][0].numpy().tolist(), i | |||||
assert trainset[i][Const.INPUT_LEN] == dataset[i][2].numpy().tolist(), i | |||||
assert trainset[i][Const.TARGET] == dataset[i][1].numpy().tolist(), i |
@@ -0,0 +1,135 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import os | |||||
import sys | |||||
import time | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.io.model_io import ModelSaver | |||||
from fastNLP.core.callback import Callback, EarlyStopError | |||||
from tools.logger import * | |||||
class TrainCallback(Callback): | |||||
def __init__(self, hps, patience=3, quit_all=True): | |||||
super().__init__() | |||||
self._hps = hps | |||||
self.patience = patience | |||||
self.wait = 0 | |||||
if not isinstance(quit_all, bool):
raise ValueError("quit_all argument must be a bool.")
self.quit_all = quit_all | |||||
def on_epoch_begin(self): | |||||
self.epoch_start_time = time.time() | |||||
# def on_loss_begin(self, batch_y, predict_y): | |||||
# """ | |||||
# | |||||
# :param batch_y: dict | |||||
# input_len: [batch, N] | |||||
# :param predict_y: dict | |||||
# p_sent: [batch, N, 2] | |||||
# :return: | |||||
# """ | |||||
# input_len = batch_y[Const.INPUT_LEN] | |||||
# batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100) | |||||
# # predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1) | |||||
# # logger.debug(predict_y["p_sent"][0:5,:,:]) | |||||
def on_backward_begin(self, loss): | |||||
""" | |||||
:param loss: [] | |||||
:return: | |||||
""" | |||||
if not np.isfinite(loss.item()):
logger.error("train Loss is not finite. Stopping.") | |||||
logger.info(loss) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.info(name) | |||||
logger.info(param.grad.data.sum()) | |||||
raise Exception("train Loss is not finite. Stopping.") | |||||
def on_backward_end(self): | |||||
if self._hps.grad_clip: | |||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm) | |||||
def on_epoch_end(self): | |||||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | ' | |||||
.format(self.epoch, (time.time() - self.epoch_start_time))) | |||||
def on_valid_begin(self): | |||||
self.valid_start_time = time.time() | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
logger.info(' | end of valid {:3d} | time: {:5.2f}s | ' | |||||
.format(self.epoch, (time.time() - self.valid_start_time))) | |||||
# early stop | |||||
if not is_better_eval: | |||||
if self.wait == self.patience: | |||||
train_dir = os.path.join(self._hps.save_root, "train") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(self.model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
raise EarlyStopError("Early stopping raised.") | |||||
else: | |||||
self.wait += 1 | |||||
else: | |||||
self.wait = 0 | |||||
# lr descent | |||||
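# harmonic decay with a floor: lr_t = max(5e-6, initial_lr / (epoch + 1))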
if self._hps.lr_descent: | |||||
new_lr = max(5e-6, self._hps.lr / (self.epoch + 1)) | |||||
for param_group in list(optimizer.param_groups): | |||||
param_group['lr'] = new_lr | |||||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||||
def on_exception(self, exception): | |||||
if isinstance(exception, KeyboardInterrupt): | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
train_dir = os.path.join(self._hps.save_root, "train") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(self.model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
if self.quit_all is True: | |||||
sys.exit(0) # exit the whole program
else: | |||||
pass | |||||
else: | |||||
raise exception # re-raise unexpected exceptions
@@ -0,0 +1,562 @@ | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from torch.autograd import * | |||||
import torch.nn.init as init | |||||
import data | |||||
from tools.logger import * | |||||
from transformer.Models import get_sinusoid_encoding_table | |||||
class Encoder(nn.Module): | |||||
def __init__(self, hps, vocab): | |||||
super(Encoder, self).__init__() | |||||
self._hps = hps | |||||
self._vocab = vocab | |||||
self.sent_max_len = hps.sent_max_len | |||||
vocab_size = len(vocab) | |||||
logger.info("[INFO] Vocabulary size is %d", vocab_size) | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=vocab.word2id('[PAD]')) | |||||
if hps.word_embedding: | |||||
word2vec = data.Word_Embedding(hps.embedding_path, vocab) | |||||
word_vecs = word2vec.load_my_vecs(embed_size) | |||||
# pretrained_weight = word2vec.add_unknown_words_by_zero(word_vecs, embed_size) | |||||
pretrained_weight = word2vec.add_unknown_words_by_avg(word_vecs, embed_size) | |||||
pretrained_weight = np.array(pretrained_weight) | |||||
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight)) | |||||
self.embed.weight.requires_grad = hps.embed_train | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len] | |||||
vocab = self._vocab | |||||
batch_size, N, _ = input.size() | |||||
input = input.view(-1, input.size(2)) # [batch_size*N, L] | |||||
input_sent_len = ((input!=vocab.word2id('[PAD]')).sum(dim=1)).int() # [batch_size*N, 1] | |||||
enc_embed_input = self.embed(input) # [batch_size*N, L, D] | |||||
input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class DomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(DomainEncoder, self).__init__(hps, vocab) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
sent_embedding = super().forward(input) | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class MultiDomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(MultiDomainEncoder, self).__init__(hps, vocab) | |||||
self.domain_size = domaindict.size() | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size, domain_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
# logger.info(domain[:5, :]) | |||||
sent_embedding = super().forward(input) | |||||
domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1) | |||||
domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size] | |||||
enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D] | |||||
enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D] | |||||
# logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D] | |||||
enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, hps): | |||||
super(BertEncoder, self).__init__() | |||||
from pytorch_pretrained_bert.modeling import BertModel | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
self._cuda = hps.cuda | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16") | |||||
self._bert.eval() | |||||
for p in self._bert.parameters(): | |||||
p.requires_grad = False | |||||
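# BERT is frozen and used only as a feature extractor; the forward pass below
# concatenates its last 4 hidden layers (4 x 1024 = 4096 dims per token), which is
# why the projection right after maps 4096 -> word_emb_dim.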
self.word_embedding_proj = nn.Linear(4096, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
""" | |||||
# Use Bert to get word embedding | |||||
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
for i in range(batch_size): | |||||
tokens_id = inputs[i] | |||||
input_mask = input_masks[i] | |||||
sent_len = enc_sent_len[i] | |||||
input_ids = tokens_id.unsqueeze(0) | |||||
input_mask = input_mask.unsqueeze(0) | |||||
out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask) | |||||
out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 1 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 511: | |||||
enc_sent_input.append(out[last_end: min(511, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D] | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class BertTagEncoder(BertEncoder): | |||||
def __init__(self, hps, domaindict): | |||||
super(BertTagEncoder, self).__init__(hps) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, inputs, input_masks, enc_sent_len, domain): | |||||
sent_embedding = super().forward(inputs, input_masks, enc_sent_len) | |||||
batch_size, N = enc_sent_len.size() | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class ELMoEndoer(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer, self).__init__() | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len, character_len] | |||||
batch_size, N, seq_len, _ = input.size() | |||||
input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len] | |||||
input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||||
logger.debug(input_sent_len.view(batch_size, -1)) | |||||
enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D] | |||||
enc_embed_input = self.embed_proj(enc_embed_input) | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in input_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class ELMoEndoer2(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer2, self).__init__() | |||||
self._hps = hps | |||||
self._cuda = hps.cuda | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
sent_embedding: [batch, N, D] | |||||
""" | |||||
# Use ELMo to get word embeddings
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D] | |||||
elmo_output = elmo_output * input_masks.unsqueeze(-1).float() | |||||
# print("END elmo") | |||||
for i in range(batch_size): | |||||
sent_len = enc_sent_len[i] # [1, N] | |||||
out = elmo_output[i] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 0 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 512: | |||||
enc_sent_input.append(out[last_end : min(512, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, sent_max_len, elmo_dim]
input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state] | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D] | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding |
@@ -0,0 +1,41 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import torch | |||||
import numpy as np | |||||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||||
''' Sinusoid position encoding table ''' | |||||
def cal_angle(position, hid_idx): | |||||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||||
def get_posi_angle_vec(position): | |||||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||||
if padding_idx is not None: | |||||
# zero vector for padding dimension | |||||
sinusoid_table[padding_idx] = 0. | |||||
return torch.FloatTensor(sinusoid_table) |
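# A minimal usage sketch (illustrative only): build a table for sentences of up to
# 100 tokens with 512-dim embeddings, reserve index 0 for padding, and wrap it in a
# frozen nn.Embedding so positions can be looked up by index, as the encoders in this PR do.
#
#   pos_table = get_sinusoid_encoding_table(100 + 1, 512, padding_idx=0)
#   pos_embed = torch.nn.Embedding.from_pretrained(pos_table, freeze=True)
#   vec = pos_embed(torch.tensor([1, 2, 3, 0]))  # positions 1..3 plus one padding slot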
@@ -0,0 +1 @@ | |||||
@@ -0,0 +1,479 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it""" | |||||
import os | |||||
import re | |||||
import glob | |||||
import copy | |||||
import random | |||||
import json | |||||
import collections | |||||
from itertools import combinations | |||||
import numpy as np | |||||
from random import shuffle | |||||
import torch.utils.data | |||||
import time | |||||
import pickle | |||||
from nltk.tokenize import sent_tokenize | |||||
import utils | |||||
from logger import * | |||||
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. | |||||
SENTENCE_START = '<s>' | |||||
SENTENCE_END = '</s>' | |||||
PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence | |||||
UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words | |||||
START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence | |||||
STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences | |||||
# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file. | |||||
class Vocab(object): | |||||
"""Vocabulary class for mapping between words and ids (integers)""" | |||||
def __init__(self, vocab_file, max_size): | |||||
""" | |||||
Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file. | |||||
:param vocab_file: string; path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though. | |||||
:param max_size: int; The maximum size of the resulting Vocabulary. | |||||
""" | |||||
self._word_to_id = {} | |||||
self._id_to_word = {} | |||||
self._count = 0 # keeps track of total number of words in the Vocab | |||||
# [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3. | |||||
for w in [PAD_TOKEN, UNKNOWN_TOKEN, START_DECODING, STOP_DECODING]: | |||||
self._word_to_id[w] = self._count | |||||
self._id_to_word[self._count] = w | |||||
self._count += 1 | |||||
# Read the vocab file and add words up to max_size | |||||
with open(vocab_file, 'r', encoding='utf8') as vocab_f: # open with utf-8 encoding to avoid decode errors
cnt = 0 | |||||
for line in vocab_f: | |||||
cnt += 1 | |||||
pieces = line.split("\t") | |||||
# pieces = line.split() | |||||
w = pieces[0] | |||||
# print(w) | |||||
if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]: | |||||
raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w) | |||||
if w in self._word_to_id: | |||||
logger.error('Duplicated word in vocabulary file Line %d : %s' % (cnt, w)) | |||||
continue | |||||
self._word_to_id[w] = self._count | |||||
self._id_to_word[self._count] = w | |||||
self._count += 1 | |||||
if max_size != 0 and self._count >= max_size: | |||||
logger.info("[INFO] max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)) | |||||
break | |||||
logger.info("[INFO] Finished constructing vocabulary of %i total words. Last word added: %s", self._count, self._id_to_word[self._count-1]) | |||||
def word2id(self, word): | |||||
"""Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV.""" | |||||
if word not in self._word_to_id: | |||||
return self._word_to_id[UNKNOWN_TOKEN] | |||||
return self._word_to_id[word] | |||||
def id2word(self, word_id): | |||||
"""Returns the word (string) corresponding to an id (integer).""" | |||||
if word_id not in self._id_to_word: | |||||
raise ValueError('Id not found in vocab: %d' % word_id) | |||||
return self._id_to_word[word_id] | |||||
def size(self): | |||||
"""Returns the total size of the vocabulary""" | |||||
return self._count | |||||
def word_list(self): | |||||
"""Return the word list of the vocabulary""" | |||||
return self._word_to_id.keys() | |||||
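# A minimal usage sketch (illustrative; the vocab path reuses the test fixture from this PR):
# the vocab file holds one "<word>\t<frequency>" entry per line, most frequent first, and
# out-of-vocabulary words fall back to the [UNK] id.
#
#   vocab = Vocab("testdata/vocab", max_size=100000)
#   pad_id = vocab.word2id('[PAD]')           # 0
#   unk_id = vocab.word2id('some-oov-word')   # falls back to the [UNK] id (1)
#   assert vocab.id2word(pad_id) == '[PAD]'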
class Word_Embedding(object): | |||||
def __init__(self, path, vocab): | |||||
""" | |||||
:param path: string; the path of word embedding | |||||
:param vocab: object; | |||||
""" | |||||
logger.info("[INFO] Loading external word embedding...") | |||||
self._path = path | |||||
self._vocablist = vocab.word_list() | |||||
self._vocab = vocab | |||||
def load_my_vecs(self, k=200): | |||||
"""Load word embedding""" | |||||
word_vecs = {} | |||||
with open(self._path, encoding="utf-8") as f: | |||||
count = 0 | |||||
lines = f.readlines()[1:] | |||||
for line in lines: | |||||
values = line.split(" ") | |||||
word = values[0] | |||||
count += 1 | |||||
if word in self._vocablist: # only keep words that appear in the vocab
vector = [] | |||||
for idx, val in enumerate(values): # values[0] is the word itself; keep the first k floats
if idx == 0:
continue
if idx <= k:
vector.append(float(val)) | |||||
word_vecs[word] = vector | |||||
return word_vecs | |||||
def add_unknown_words_by_zero(self, word_vecs, k=200): | |||||
"""Solve unknown by zeros""" | |||||
zero = [0.0] * k | |||||
list_word2vec = [] | |||||
oov = 0 | |||||
iov = 0 | |||||
for i in range(self._vocab.size()): | |||||
word = self._vocab.id2word(i) | |||||
if word not in word_vecs: | |||||
oov += 1 | |||||
word_vecs[word] = zero | |||||
list_word2vec.append(word_vecs[word]) | |||||
else: | |||||
iov += 1 | |||||
list_word2vec.append(word_vecs[word]) | |||||
logger.info("[INFO] oov count %d, iov count %d", oov, iov) | |||||
return list_word2vec | |||||
def add_unknown_words_by_avg(self, word_vecs, k=200): | |||||
"""Solve unknown by avg word embedding""" | |||||
# replace unknown words with the average of the in-vocab embeddings
word_vecs_numpy = [] | |||||
for word in self._vocablist: | |||||
if word in word_vecs: | |||||
word_vecs_numpy.append(word_vecs[word]) | |||||
col = [] | |||||
for i in range(k):
total = 0.0
for j in range(len(word_vecs_numpy)):
total += word_vecs_numpy[j][i]
total = round(total, 6)
col.append(total)
zero = [] | |||||
for m in range(k): | |||||
avg = col[m] / int(len(word_vecs_numpy)) | |||||
avg = round(avg, 6) | |||||
zero.append(float(avg)) | |||||
list_word2vec = [] | |||||
oov = 0 | |||||
iov = 0 | |||||
for i in range(self._vocab.size()): | |||||
word = self._vocab.id2word(i) | |||||
if word not in word_vecs: | |||||
oov += 1 | |||||
word_vecs[word] = zero | |||||
list_word2vec.append(word_vecs[word]) | |||||
else: | |||||
iov += 1 | |||||
list_word2vec.append(word_vecs[word]) | |||||
logger.info("[INFO] External Word Embedding iov count: %d, oov count: %d", iov, oov) | |||||
return list_word2vec | |||||
def add_unknown_words_by_uniform(self, word_vecs, uniform=0.25, k=200): | |||||
"""Solve unknown word by uniform(-0.25,0.25)""" | |||||
list_word2vec = [] | |||||
oov = 0 | |||||
iov = 0 | |||||
for i in range(self._vocab.size()): | |||||
word = self._vocab.id2word(i) | |||||
if word not in word_vecs: | |||||
oov += 1 | |||||
word_vecs[word] = np.random.uniform(-1 * uniform, uniform, k).round(6).tolist() | |||||
list_word2vec.append(word_vecs[word]) | |||||
else: | |||||
iov += 1 | |||||
list_word2vec.append(word_vecs[word]) | |||||
logger.info("[INFO] oov count %d, iov count %d", oov, iov) | |||||
return list_word2vec | |||||
# load word embedding | |||||
def load_my_vecs_freq1(self, freqs, pro): | |||||
word_vecs = {} | |||||
with open(self._path, encoding="utf-8") as f: | |||||
freq = 0 | |||||
lines = f.readlines()[1:] | |||||
for line in lines: | |||||
values = line.split(" ") | |||||
word = values[0] | |||||
if word in self._vocablist: # only keep words that appear in the vocab
if freqs[word] == 1: | |||||
a = np.random.uniform(0, 1, 1).round(2) | |||||
if pro < a: | |||||
continue | |||||
vector = [] | |||||
for count, val in enumerate(values): | |||||
if count == 0: | |||||
continue | |||||
vector.append(float(val)) | |||||
word_vecs[word] = vector | |||||
return word_vecs | |||||
class DomainDict(object): | |||||
"""Domain embedding for Newsroom""" | |||||
def __init__(self, path): | |||||
self.domain_list = self.readDomainlist(path) | |||||
# self.domain_list = ["foxnews.com", "cnn.com", "mashable.com", "nytimes.com", "washingtonpost.com"] | |||||
self.domain_number = len(self.domain_list) | |||||
self._domain_to_id = {} | |||||
self._id_to_domain = {} | |||||
self._cnt = 0 | |||||
self._domain_to_id["X"] = self._cnt | |||||
self._id_to_domain[self._cnt] = "X" | |||||
self._cnt += 1 | |||||
for i in range(self.domain_number): | |||||
domain = self.domain_list[i] | |||||
self._domain_to_id[domain] = self._cnt | |||||
self._id_to_domain[self._cnt] = domain | |||||
self._cnt += 1 | |||||
def readDomainlist(self, path): | |||||
domain_list = [] | |||||
with open(path) as f: | |||||
for line in f: | |||||
domain_list.append(line.split("\t")[0].strip()) | |||||
logger.info(domain_list) | |||||
return domain_list | |||||
def domain2id(self, domain): | |||||
""" Returns the id (integer) of a domain (string). Returns "X" for unknow domain. | |||||
:param domain: string | |||||
:return: id; int | |||||
""" | |||||
if domain in self.domain_list: | |||||
return self._domain_to_id[domain] | |||||
else: | |||||
logger.info(domain) | |||||
return self._domain_to_id["X"] | |||||
def id2domain(self, domain_id): | |||||
""" Returns the domain (string) corresponding to an id (integer). | |||||
:param domain_id: int;
:return: domain: string | |||||
""" | |||||
if domain_id not in self._id_to_domain: | |||||
raise ValueError('Id not found in DomainDict: %d' % domain_id) | |||||
return self._id_to_domain[domain_id]
def size(self): | |||||
return self._cnt | |||||
class Example(object): | |||||
"""Class representing a train/val/test example for text summarization.""" | |||||
def __init__(self, article_sents, abstract_sents, vocab, sent_max_len, label, domainid=None): | |||||
""" Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self. | |||||
:param article_sents: list of strings; one per article sentence. each token is separated by a single space. | |||||
:param abstract_sents: list of strings; one per abstract sentence. In each sentence, each token is separated by a single space. | |||||
:param domainid: int; publication of the example | |||||
:param vocab: Vocabulary object | |||||
:param sent_max_len: int; the maximum length of each sentence, padding all sentences to this length | |||||
:param label: list of int; the index of selected sentences | |||||
""" | |||||
self.sent_max_len = sent_max_len | |||||
self.enc_sent_len = [] | |||||
self.enc_sent_input = [] | |||||
self.enc_sent_input_pad = [] | |||||
# origin_cnt = len(article_sents) | |||||
# article_sents = [re.sub(r"\n+\t+", " ", sent) for sent in article_sents] | |||||
# assert origin_cnt == len(article_sents) | |||||
# Process the article | |||||
for sent in article_sents: | |||||
article_words = sent.split() | |||||
self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding | |||||
self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
self._pad_encoder_input(vocab.word2id('[PAD]')) | |||||
# Store the original strings | |||||
self.original_article = " ".join(article_sents) | |||||
self.original_article_sents = article_sents | |||||
if isinstance(abstract_sents[0], list): | |||||
logger.debug("[INFO] Multi Reference summaries!") | |||||
self.original_abstract_sents = [] | |||||
self.original_abstract = [] | |||||
for summary in abstract_sents: | |||||
self.original_abstract_sents.append([sent.strip() for sent in summary]) | |||||
self.original_abstract.append("\n".join([sent.replace("\n", "") for sent in summary])) | |||||
else: | |||||
self.original_abstract_sents = [sent.replace("\n", "") for sent in abstract_sents] | |||||
self.original_abstract = "\n".join(self.original_abstract_sents) | |||||
# Store the label | |||||
self.label = np.zeros(len(article_sents), dtype=int) | |||||
if label != []: | |||||
self.label[np.array(label)] = 1 | |||||
self.label = list(self.label) | |||||
# Store the publication | |||||
if domainid is not None:
if domainid == 0: | |||||
logger.debug("domain id = 0!") | |||||
self.domain = domainid | |||||
def _pad_encoder_input(self, pad_id): | |||||
""" | |||||
:param pad_id: int; token pad id | |||||
:return: | |||||
""" | |||||
max_len = self.sent_max_len | |||||
for i in range(len(self.enc_sent_input)): | |||||
article_words = self.enc_sent_input[i] | |||||
if len(article_words) > max_len: | |||||
article_words = article_words[:max_len] | |||||
while len(article_words) < max_len: | |||||
article_words.append(pad_id) | |||||
self.enc_sent_input_pad.append(article_words) | |||||
class ExampleSet(torch.utils.data.Dataset): | |||||
""" Constructor: Dataset of example(object) """ | |||||
def __init__(self, data_path, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False): | |||||
""" Initializes the ExampleSet with the path of data | |||||
:param data_path: string; the path of data | |||||
:param vocab: object; | |||||
:param doc_max_timesteps: int; the maximum sentence number of a document, each example should pad sentences to this length | |||||
:param sent_max_len: int; the maximum token number of a sentence, each sentence should pad tokens to this length | |||||
:param domaindict: object; the domain dict to embed domain | |||||
""" | |||||
self.domaindict = domaindict | |||||
if domaindict: | |||||
logger.info("[INFO] Use domain information in the dateset!") | |||||
if randomX==True: | |||||
logger.info("[INFO] Random some example to unknow domain X!") | |||||
self.randomP = 0.1 | |||||
logger.info("[INFO] Start reading ExampleSet") | |||||
start = time.time() | |||||
self.example_list = [] | |||||
self.doc_max_timesteps = doc_max_timesteps | |||||
cnt = 0 | |||||
with open(data_path, 'r') as reader: | |||||
for line in reader: | |||||
try: | |||||
e = json.loads(line) | |||||
article_sent = e['text'] | |||||
tag = e["tag"][0] if usetag else e['publication'] | |||||
# logger.info(tag) | |||||
if "duc" in data_path: | |||||
abstract_sent = e["summaryList"] if "summaryList" in e.keys() else [e['summary']] | |||||
else: | |||||
abstract_sent = e['summary'] | |||||
if domaindict: | |||||
if randomX:
p = np.random.rand() | |||||
if p <= self.randomP: | |||||
domainid = domaindict.domain2id("X") | |||||
else: | |||||
domainid = domaindict.domain2id(tag) | |||||
else: | |||||
domainid = domaindict.domain2id(tag) | |||||
else: | |||||
domainid = None | |||||
logger.debug((tag, domainid)) | |||||
except (ValueError, EOFError) as err:
logger.debug(err)
break | |||||
else: | |||||
example = Example(article_sent, abstract_sent, vocab, sent_max_len, e["label"], domainid) # Process into an Example. | |||||
self.example_list.append(example) | |||||
cnt += 1 | |||||
# print(cnt) | |||||
logger.info("[INFO] Finish reading ExampleSet. Total time is %f, Total size is %d", time.time() - start, len(self.example_list)) | |||||
self.size = len(self.example_list) | |||||
# self.example_list.sort(key=lambda ex: ex.domain) | |||||
def get_example(self, index): | |||||
return self.example_list[index] | |||||
def __getitem__(self, index): | |||||
""" | |||||
:param index: int; the index of the example | |||||
:return | |||||
input_pad: [N, seq_len] | |||||
label: [N] | |||||
input_mask: [N] | |||||
domain: [1] | |||||
""" | |||||
item = self.example_list[index] | |||||
input = np.array(item.enc_sent_input_pad) | |||||
label = np.array(item.label, dtype=int) | |||||
# pad input to doc_max_timesteps | |||||
if len(input) < self.doc_max_timesteps: | |||||
pad_number = self.doc_max_timesteps - len(input) | |||||
pad_matrix = np.zeros((pad_number, len(input[0]))) | |||||
input_pad = np.vstack((input, pad_matrix)) | |||||
label = np.append(label, np.zeros(pad_number, dtype=int)) | |||||
input_mask = np.append(np.ones(len(input)), np.zeros(pad_number)) | |||||
else: | |||||
input_pad = input[:self.doc_max_timesteps] | |||||
label = label[:self.doc_max_timesteps] | |||||
input_mask = np.ones(self.doc_max_timesteps) | |||||
if self.domaindict: | |||||
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long(), item.domain | |||||
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long() | |||||
def __len__(self): | |||||
return self.size | |||||
class MultiExampleSet(): | |||||
def __init__(self, data_dir, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False): | |||||
self.datasets = [None] * (domaindict.size() - 1) | |||||
data_path_list = [os.path.join(data_dir, s) for s in os.listdir(data_dir) if s.endswith("label.jsonl")] | |||||
for data_path in data_path_list: | |||||
fname = data_path.split("/")[-1]  # e.g. cnn.com.label.jsonl
dataname = ".".join(fname.split(".")[:-2]) | |||||
domainid = domaindict.domain2id(dataname) | |||||
logger.info("[INFO] domain name: %s, domain id: %d" % (dataname, domainid)) | |||||
self.datasets[domainid - 1] = ExampleSet(data_path, vocab, doc_max_timesteps, sent_max_len, domaindict, randomX, usetag) | |||||
def get(self, id): | |||||
return self.datasets[id] | |||||
from torch.utils.data.dataloader import default_collate | |||||
def my_collate_fn(batch): | |||||
''' | |||||
:param batch: (input_pad, label, input_mask, domain) | |||||
:return: | |||||
''' | |||||
start_domain = batch[0][-1] | |||||
# for i in range(len(batch)): | |||||
# print(batch[i][-1], end=',') | |||||
batch = list(filter(lambda x: x[-1] == start_domain, batch)) | |||||
print("start_domain %d" % start_domain) | |||||
print("batch_len %d" % len(batch)) | |||||
if len(batch) == 0: return torch.Tensor() | |||||
return default_collate(batch)  # collate the filtered single-domain batch with the default collate function
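# --- Illustrative sketch (not in the original code): what my_collate_fn does to a
# mixed-domain batch, assuming the (input_pad, label, input_mask, domain) tuples
# produced by ExampleSet. Only items sharing the first item's domain are kept, so
# every collated batch is single-domain (and may be smaller than batch_size).
if __name__ == "__main__":
    _toy_batch = [
        (torch.zeros(3, 5).long(), torch.zeros(3).long(), torch.ones(3).long(), 1),
        (torch.zeros(3, 5).long(), torch.zeros(3).long(), torch.ones(3).long(), 1),
        (torch.zeros(3, 5).long(), torch.zeros(3).long(), torch.ones(3).long(), 2),
    ]
    _collated = my_collate_fn(_toy_batch)
    print([t.size() for t in _collated[:3]], _collated[3])  # batch of 2, domains tensor([1, 1])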
@@ -0,0 +1,27 @@ | |||||
# -*- coding: utf-8 -*- | |||||
import logging | |||||
import sys | |||||
# Get a logger instance; calling getLogger() with no name would return the root logger
logger = logging.getLogger("Summarization logger")
# logger = logging.getLogger()
# Specify the log output format
formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s')
# # File log handler
# file_handler = logging.FileHandler("test.log")
# file_handler.setFormatter(formatter)  # the output format can be set via setFormatter
# Console log handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.formatter = formatter  # the formatter attribute can also be assigned directly
console_handler.setLevel(logging.INFO)
# Attach the handlers to the logger
# logger.addHandler(file_handler)
logger.addHandler(console_handler)
# Set the logger's minimum output level (the default is WARNING)
logger.setLevel(logging.DEBUG)
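# --- Illustrative sketch (not in the original code): the training scripts below
# reuse this module-level logger and attach an extra per-run FileHandler; the log
# file name here is just a placeholder.
if __name__ == "__main__":
    file_handler = logging.FileHandler("example_run.log")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info("console + file logging configured")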
@@ -0,0 +1,297 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
import re | |||||
import os | |||||
import shutil | |||||
import copy | |||||
import datetime | |||||
import numpy as np | |||||
from rouge import Rouge | |||||
from .logger import * | |||||
# from data import * | |||||
import sys | |||||
sys.setrecursionlimit(10000) | |||||
REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", | |||||
"-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} | |||||
def clean(x): | |||||
return re.sub( | |||||
r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", | |||||
lambda m: REMAP.get(m.group()), x) | |||||
def rouge_eval(hyps, refer): | |||||
rouge = Rouge() | |||||
# print(hyps) | |||||
# print(refer) | |||||
# print(rouge.get_scores(hyps, refer)) | |||||
try: | |||||
score = rouge.get_scores(hyps, refer)[0] | |||||
mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]]) | |||||
except Exception:  # e.g. an empty hypothesis or reference makes the rouge package fail
mean_score = 0.0 | |||||
return mean_score | |||||
def rouge_all(hyps, refer): | |||||
rouge = Rouge() | |||||
score = rouge.get_scores(hyps, refer)[0] | |||||
# mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]]) | |||||
return score | |||||
def eval_label(match_true, pred, true, total, match):
match_true, pred, true, match = match_true.float(), pred.float(), true.float(), match.float()
# initialize so that a ZeroDivisionError still leaves all four return values defined
accu, precision, recall, F = 0.0, 0.0, 0.0, 0.0
try:
accu = match / total
precision = match_true / pred
recall = match_true / true
F = 2 * precision * recall / (precision + recall)
except ZeroDivisionError:
F = 0.0
logger.error("[Error] float division by zero")
return accu, precision, recall, F
def pyrouge_score(hyps, refer, remap = True): | |||||
from pyrouge import Rouge155 | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold') | |||||
MODEL_PATH = os.path.join(PYROUGE_ROOT,'system') | |||||
if os.path.exists(SYSTEM_PATH): | |||||
shutil.rmtree(SYSTEM_PATH) | |||||
os.makedirs(SYSTEM_PATH) | |||||
if os.path.exists(MODEL_PATH): | |||||
shutil.rmtree(MODEL_PATH) | |||||
os.makedirs(MODEL_PATH) | |||||
if remap:
refer = clean(refer) | |||||
hyps = clean(hyps) | |||||
system_file = os.path.join(SYSTEM_PATH, 'Reference.0.txt') | |||||
model_file = os.path.join(MODEL_PATH, 'Model.A.0.txt') | |||||
with open(system_file, 'wb') as f: | |||||
f.write(refer.encode('utf-8')) | |||||
with open(model_file, 'wb') as f: | |||||
f.write(hyps.encode('utf-8')) | |||||
r = Rouge155('/home/dqwang/ROUGE/RELEASE-1.5.5') | |||||
r.system_dir = SYSTEM_PATH | |||||
r.model_dir = MODEL_PATH | |||||
r.system_filename_pattern = r'Reference.(\d+).txt'
r.model_filename_pattern = 'Model.[A-Z].#ID#.txt' | |||||
output = r.convert_and_evaluate(rouge_args="-e /home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||||
output_dict = r.output_to_dict(output) | |||||
shutil.rmtree(PYROUGE_ROOT) | |||||
scores = {} | |||||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||||
return scores | |||||
def pyrouge_score_all(hyps_list, refer_list, remap = True): | |||||
from pyrouge import Rouge155 | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold') | |||||
MODEL_PATH = os.path.join(PYROUGE_ROOT,'system') | |||||
if os.path.exists(SYSTEM_PATH): | |||||
shutil.rmtree(SYSTEM_PATH) | |||||
os.makedirs(SYSTEM_PATH) | |||||
if os.path.exists(MODEL_PATH): | |||||
shutil.rmtree(MODEL_PATH) | |||||
os.makedirs(MODEL_PATH) | |||||
assert len(hyps_list) == len(refer_list) | |||||
for i in range(len(hyps_list)): | |||||
system_file = os.path.join(SYSTEM_PATH, 'Reference.%d.txt' % i) | |||||
model_file = os.path.join(MODEL_PATH, 'Model.A.%d.txt' % i) | |||||
refer = clean(refer_list[i]) if remap else refer_list[i] | |||||
hyps = clean(hyps_list[i]) if remap else hyps_list[i] | |||||
with open(system_file, 'wb') as f: | |||||
f.write(refer.encode('utf-8')) | |||||
with open(model_file, 'wb') as f: | |||||
f.write(hyps.encode('utf-8')) | |||||
r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5') | |||||
r.system_dir = SYSTEM_PATH | |||||
r.model_dir = MODEL_PATH | |||||
r.system_filename_pattern = r'Reference.(\d+).txt'
r.model_filename_pattern = 'Model.[A-Z].#ID#.txt' | |||||
output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||||
output_dict = r.output_to_dict(output) | |||||
shutil.rmtree(PYROUGE_ROOT) | |||||
scores = {} | |||||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||||
return scores | |||||
def pyrouge_score_all_multi(hyps_list, refer_list, remap = True): | |||||
from pyrouge import Rouge155 | |||||
nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT, 'system') | |||||
MODEL_PATH = os.path.join(PYROUGE_ROOT, 'gold') | |||||
if os.path.exists(SYSTEM_PATH): | |||||
shutil.rmtree(SYSTEM_PATH) | |||||
os.makedirs(SYSTEM_PATH) | |||||
if os.path.exists(MODEL_PATH): | |||||
shutil.rmtree(MODEL_PATH) | |||||
os.makedirs(MODEL_PATH) | |||||
assert len(hyps_list) == len(refer_list) | |||||
for i in range(len(hyps_list)): | |||||
system_file = os.path.join(SYSTEM_PATH, 'Model.%d.txt' % i) | |||||
# model_file = os.path.join(MODEL_PATH, 'Reference.A.%d.txt' % i) | |||||
hyps = clean(hyps_list[i]) if remap else hyps_list[i] | |||||
with open(system_file, 'wb') as f: | |||||
f.write(hyps.encode('utf-8')) | |||||
referType = ["A", "B", "C", "D", "E", "F", "G"] | |||||
for j in range(len(refer_list[i])): | |||||
model_file = os.path.join(MODEL_PATH, "Reference.%s.%d.txt" % (referType[j], i)) | |||||
refer = clean(refer_list[i][j]) if remap else refer_list[i][j] | |||||
with open(model_file, 'wb') as f: | |||||
f.write(refer.encode('utf-8')) | |||||
r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5') | |||||
r.system_dir = SYSTEM_PATH | |||||
r.model_dir = MODEL_PATH | |||||
r.system_filename_pattern = r'Model.(\d+).txt'
r.model_filename_pattern = 'Reference.[A-Z].#ID#.txt' | |||||
output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||||
output_dict = r.output_to_dict(output) | |||||
shutil.rmtree(PYROUGE_ROOT) | |||||
scores = {} | |||||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||||
return scores | |||||
def cal_label(article, abstract): | |||||
hyps_list = article | |||||
refer = abstract | |||||
scores = [] | |||||
for hyps in hyps_list: | |||||
mean_score = rouge_eval(hyps, refer) | |||||
scores.append(mean_score) | |||||
selected = [] | |||||
selected.append(int(np.argmax(scores))) | |||||
selected_sent_cnt = 1 | |||||
best_rouge = np.max(scores) | |||||
while selected_sent_cnt < len(hyps_list): | |||||
cur_max_rouge = 0.0 | |||||
cur_max_idx = -1 | |||||
for i in range(len(hyps_list)): | |||||
if i not in selected: | |||||
temp = copy.deepcopy(selected) | |||||
temp.append(i) | |||||
hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)]) | |||||
cur_rouge = rouge_eval(hyps, refer) | |||||
if cur_rouge > cur_max_rouge: | |||||
cur_max_rouge = cur_rouge | |||||
cur_max_idx = i | |||||
if cur_max_rouge != 0.0 and cur_max_rouge >= best_rouge: | |||||
selected.append(cur_max_idx) | |||||
selected_sent_cnt += 1 | |||||
best_rouge = cur_max_rouge | |||||
else: | |||||
break | |||||
# label = np.zeros(len(hyps_list), dtype=int) | |||||
# label[np.array(selected)] = 1 | |||||
# return list(label) | |||||
return selected | |||||
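# --- Illustrative sketch (not in the original code): a toy run of the greedy
# oracle labelling above. The sentences and summary are made-up; the exact
# indices returned depend on the ROUGE scores.
if __name__ == "__main__":
    _article = [
        "the cat sat on the mat .",
        "stock markets fell sharply on monday .",
        "the cat looked very happy on the mat .",
    ]
    _abstract = "the happy cat sat on the mat ."
    print(rouge_eval(_article[0], _abstract))  # mean of ROUGE-1/2/L F scores for one sentence
    print(cal_label(_article, _abstract))      # greedily selected sentence indices, e.g. [0, 2]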
def cal_label_limited3(article, abstract): | |||||
hyps_list = article | |||||
refer = abstract | |||||
scores = [] | |||||
for hyps in hyps_list: | |||||
try: | |||||
mean_score = rouge_eval(hyps, refer) | |||||
scores.append(mean_score) | |||||
except ValueError: | |||||
scores.append(0.0) | |||||
selected = [] | |||||
selected.append(int(np.argmax(scores)))
selected_sent_cnt = 1 | |||||
best_rouge = np.max(scores) | |||||
while selected_sent_cnt < len(hyps_list) and selected_sent_cnt < 3: | |||||
cur_max_rouge = 0.0 | |||||
cur_max_idx = -1 | |||||
for i in range(len(hyps_list)): | |||||
if i not in selected: | |||||
temp = copy.deepcopy(selected) | |||||
temp.append(i) | |||||
hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)]) | |||||
cur_rouge = rouge_eval(hyps, refer) | |||||
if cur_rouge > cur_max_rouge: | |||||
cur_max_rouge = cur_rouge | |||||
cur_max_idx = i | |||||
selected.append(cur_max_idx) | |||||
selected_sent_cnt += 1 | |||||
best_rouge = cur_max_rouge | |||||
# logger.info(selected) | |||||
# label = np.zeros(len(hyps_list), dtype=int) | |||||
# label[np.array(selected)] = 1 | |||||
# return list(label) | |||||
return selected | |||||
import torch | |||||
def flip(x, dim): | |||||
xsize = x.size() | |||||
dim = x.dim() + dim if dim < 0 else dim | |||||
x = x.contiguous() | |||||
x = x.view(-1, *xsize[dim:]).contiguous() | |||||
x = x.view(x.size(0), x.size(1), -1)[:, torch.arange(x.size(1) - 1, -1, -1, device=x.device).long(), :]
return x.view(xsize) | |||||
def get_attn_key_pad_mask(seq_k, seq_q): | |||||
''' For masking out the padding part of key sequence. ''' | |||||
# Expand to fit the shape of key query attention matrix. | |||||
len_q = seq_q.size(1) | |||||
padding_mask = seq_k.eq(0.0) | |||||
padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk | |||||
return padding_mask | |||||
def get_non_pad_mask(seq): | |||||
assert seq.dim() == 2 | |||||
return seq.ne(0.0).type(torch.float).unsqueeze(-1) |
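# --- Illustrative sketch (not in the original code): how the two mask helpers
# behave on a toy padded batch, where token id 0 is treated as padding.
if __name__ == "__main__":
    _seq = torch.tensor([[4, 7, 0], [5, 0, 0]])
    print(get_non_pad_mask(_seq).squeeze(-1))       # [[1., 1., 0.], [1., 0., 0.]]
    print(get_attn_key_pad_mask(_seq, _seq).shape)  # torch.Size([2, 3, 3]); True where the key is padding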
@@ -0,0 +1,263 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""Train Model1: baseline model""" | |||||
import os | |||||
import sys | |||||
import json | |||||
import argparse | |||||
import datetime | |||||
import torch | |||||
import torch.nn | |||||
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.core.trainer import Trainer, Tester | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from tools.logger import * | |||||
from data.dataloader import SummarizationLoader | |||||
# from model.TransformerModel import TransformerModel | |||||
from model.TForiginal import TransformerModel | |||||
from model.Metric import LabelFMetric, FastRougeMetric, PyRougeMetric | |||||
from model.Loss import MyCrossEntropyLoss | |||||
from tools.Callback import TrainCallback | |||||
def setup_training(model, train_loader, valid_loader, hps): | |||||
"""Does setup before starting training (run_training)""" | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
if hps.restore_model != 'None': | |||||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||||
loader = ModelLoader() | |||||
loader.load_pytorch(model, bestmodel_file) | |||||
else: | |||||
logger.info("[INFO] Create new model for training...") | |||||
try: | |||||
run_training(model, train_loader, valid_loader, hps) # trains until n_epochs is reached, early stopping triggers, or the run is interrupted
except KeyboardInterrupt: | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
def run_training(model, train_loader, valid_loader, hps): | |||||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||||
logger.info("[INFO] Starting run_training") | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||||
lr = hps.lr | |||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||||
criterion = MyCrossEntropyLoss(pred = "p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none') | |||||
# criterion = torch.nn.CrossEntropyLoss(reduce="none") | |||||
trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion, | |||||
n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")], | |||||
metric_key="f", validate_every=-1, save_path=eval_dir, | |||||
callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False) | |||||
train_info = trainer.train(load_best_model=True) | |||||
logger.info(' | end of Train | time: {:5.2f}s | '.format(train_info["seconds"])) | |||||
logger.info('[INFO] best eval model in epoch %d and iter %d', train_info["best_epoch"], train_info["best_step"]) | |||||
logger.info(train_info["best_eval"]) | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path) | |||||
def run_test(model, loader, hps, limited=False): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||||
eval_dir = os.path.join(hps.save_root, "eval") | |||||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||||
if not os.path.exists(eval_dir) : | |||||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||||
if hps.test_model == "evalbestmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
elif hps.test_model == "earlystop": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl') | |||||
else: | |||||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||||
modelloader = ModelLoader() | |||||
modelloader.load_pytorch(model, bestmodel_load_path) | |||||
if hps.use_pyrouge: | |||||
logger.info("[INFO] Use PyRougeMetric for testing") | |||||
tester = Tester(data=loader, model=model, | |||||
metrics=[LabelFMetric(pred="prediction"), PyRougeMetric(hps, pred="prediction")], | |||||
batch_size=hps.batch_size) | |||||
else: | |||||
logger.info("[INFO] Use FastRougeMetric for testing") | |||||
tester = Tester(data=loader, model=model, | |||||
metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")], | |||||
batch_size=hps.batch_size) | |||||
test_info = tester.test() | |||||
logger.info(test_info) | |||||
def main(): | |||||
parser = argparse.ArgumentParser(description='Summarization Model') | |||||
# Where to find data | |||||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||||
# Important settings | |||||
parser.add_argument('--mode', choices=['train', 'test'], default='train', help='must be one of train/test') | |||||
parser.add_argument('--embedding', type=str, default='glove', choices=['word2vec', 'glove', 'elmo', 'bert'], help='must be one of word2vec/glove/elmo/bert') | |||||
parser.add_argument('--sentence_encoder', type=str, default='transformer', choices=['bilstm', 'deeplstm', 'transformer'], help='must be one of bilstm/deeplstm/transformer')
parser.add_argument('--sentence_decoder', type=str, default='SeqLab', choices=['PN', 'SeqLab'], help='must be one of PN/SeqLab') | |||||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||||
# Where to save output | |||||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all models.')
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||||
# Hyperparameters | |||||
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: 0]')
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]')
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 300]')
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthonormal init for lstm [default: True]')
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||||
parser.add_argument('--save_label', action='store_true', default=False, help='whether to save the predicted sentence labels during testing')
# Training | |||||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||||
parser.add_argument('--max_grad_norm', type=float, default=10, help='for gradient clipping max gradient normalization') | |||||
# test | |||||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||||
args = parser.parse_args() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||||
torch.set_printoptions(threshold=50000) | |||||
# File paths | |||||
DATA_FILE = args.data_path | |||||
VALID_FILE = args.valid_path | |||||
VOCAB_FILE = args.vocab_path
LOG_PATH = args.log_root | |||||
# train_log setting | |||||
if not os.path.exists(LOG_PATH): | |||||
if args.mode == "train": | |||||
os.makedirs(LOG_PATH) | |||||
else: | |||||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
log_path = os.path.join(LOG_PATH, args.mode + "_" + nowTime) | |||||
file_handler = logging.FileHandler(log_path) | |||||
file_handler.setFormatter(formatter) | |||||
logger.addHandler(file_handler) | |||||
logger.info("Pytorch %s", torch.__version__) | |||||
sum_loader = SummarizationLoader() | |||||
hps = args | |||||
if hps.mode == 'test': | |||||
paths = {"test": DATA_FILE} | |||||
hps.recurrent_dropout_prob = 0.0 | |||||
hps.atten_dropout_prob = 0.0 | |||||
hps.ffn_dropout_prob = 0.0 | |||||
logger.info(hps) | |||||
else: | |||||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAB_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAB_FILE))
if args.embedding == "glove": | |||||
vocab = dataInfo.vocabs["vocab"] | |||||
embed = torch.nn.Embedding(len(vocab), hps.word_emb_dim) | |||||
if hps.word_embedding: | |||||
embed_loader = EmbedLoader() | |||||
pretrained_weight = embed_loader.load_with_vocab(hps.embedding_path, vocab)  # words not found in the pretrained file keep a random initialization
embed.weight.data.copy_(torch.from_numpy(pretrained_weight)) | |||||
embed.weight.requires_grad = hps.embed_train | |||||
else: | |||||
logger.error("[ERROR] embedding To Be Continued!") | |||||
sys.exit(1) | |||||
if args.sentence_encoder == "transformer" and args.sentence_decoder == "SeqLab": | |||||
model_param = json.load(open("config/transformer.config", "rb")) | |||||
hps.__dict__.update(model_param) | |||||
model = TransformerModel(hps, embed) | |||||
else: | |||||
logger.error("[ERROR] Model To Be Continued!") | |||||
sys.exit(1) | |||||
logger.info(hps) | |||||
if hps.cuda: | |||||
model = model.cuda() | |||||
logger.info("[INFO] Use cuda") | |||||
if hps.mode == 'train': | |||||
dataInfo.datasets["valid"].set_target("text", "summary") | |||||
setup_training(model, dataInfo.datasets["train"], dataInfo.datasets["valid"], hps) | |||||
elif hps.mode == 'test': | |||||
logger.info("[INFO] Decoding...") | |||||
dataInfo.datasets["test"].set_target("text", "summary") | |||||
run_test(model, dataInfo.datasets["test"], hps, limited=hps.limited) | |||||
else: | |||||
logger.error("The 'mode' flag must be one of train/eval/test") | |||||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,706 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""Train Model1: baseline model""" | |||||
import os | |||||
import sys | |||||
import time | |||||
import copy | |||||
import pickle | |||||
import datetime | |||||
import argparse | |||||
import logging | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import Variable | |||||
from rouge import Rouge | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
from fastNLP.core.batch import DataSetIter | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.sampler import BucketSampler | |||||
from tools import utils | |||||
from tools.logger import * | |||||
from data.dataloader import SummarizationLoader | |||||
from model.TForiginal import TransformerModel | |||||
def setup_training(model, train_loader, valid_loader, hps): | |||||
"""Does setup before starting training (run_training)""" | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
if hps.restore_model != 'None': | |||||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||||
loader = ModelLoader() | |||||
loader.load_pytorch(model, bestmodel_file) | |||||
else: | |||||
logger.info("[INFO] Create new model for training...") | |||||
try: | |||||
run_training(model, train_loader, valid_loader, hps) # trains until n_epochs is reached, early stopping triggers, or the run is interrupted
except KeyboardInterrupt: | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
def run_training(model, train_loader, valid_loader, hps): | |||||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||||
logger.info("[INFO] Starting run_training") | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
lr = hps.lr | |||||
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98), | |||||
# eps=1e-09) | |||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
best_train_loss = None | |||||
best_train_F= None | |||||
best_loss = None | |||||
best_F = None | |||||
step_num = 0 | |||||
non_descent_cnt = 0 | |||||
for epoch in range(1, hps.n_epochs + 1): | |||||
epoch_loss = 0.0 | |||||
train_loss = 0.0 | |||||
total_example_num = 0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
epoch_start_time = time.time() | |||||
for i, (batch_x, batch_y) in enumerate(train_loader): | |||||
# if i > 10: | |||||
# break | |||||
model.train() | |||||
iter_start_time=time.time() | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
# logger.info(batch_x["text"][0]) | |||||
# logger.info(input[0,:,:]) | |||||
# logger.info(input_len[0:5,:]) | |||||
# logger.info(batch_y["summary"][0:5]) | |||||
# logger.info(label[0:5,:]) | |||||
# logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0))) | |||||
batch_size, N, seq_len = input.size() | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
input = Variable(input) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
outputs = model_outputs["p_sent"].view(-1, 2) | |||||
label = label.view(-1) | |||||
loss = criterion(outputs, label) # [batch_size, doc_max_timesteps] | |||||
# input_len = input_len.float().view(-1) | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.masked_fill(input_len.eq(0), 0) | |||||
loss = loss.sum(1).mean() | |||||
logger.debug("loss %f", loss) | |||||
if not np.isfinite(loss.item()):
logger.error("train Loss is not finite. Stopping.") | |||||
logger.info(loss) | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.info(name) | |||||
logger.info(param.grad.data.sum()) | |||||
raise Exception("train Loss is not finite. Stopping.") | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
if hps.grad_clip: | |||||
torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm) | |||||
optimizer.step() | |||||
step_num += 1 | |||||
train_loss += float(loss.data) | |||||
epoch_loss += float(loss.data) | |||||
if i % 100 == 0: | |||||
# start debugger | |||||
# import pdb; pdb.set_trace() | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.debug(name) | |||||
logger.debug(param.grad.data.sum()) | |||||
logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | ' | |||||
.format(i, (time.time() - iter_start_time), | |||||
float(train_loss / 100))) | |||||
train_loss = 0.0 | |||||
# calculate the precision, recall and F | |||||
prediction = outputs.max(1)[1] | |||||
prediction = prediction.data | |||||
label = label.data | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += int(batch_size * N) | |||||
if hps.lr_descent: | |||||
# new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5), | |||||
# step_num * pow(hps.warmup_steps, -1.5)) | |||||
new_lr = max(5e-6, lr / (epoch + 1)) | |||||
for param_group in list(optimizer.param_groups): | |||||
param_group['lr'] = new_lr | |||||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||||
epoch_avg_loss = epoch_loss / len(train_loader) | |||||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | ' | |||||
.format(epoch, (time.time() - epoch_start_time), | |||||
float(epoch_avg_loss))) | |||||
logger.info("[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
if not best_train_loss or epoch_avg_loss < best_train_loss: | |||||
save_file = os.path.join(train_dir, "bestmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_loss = epoch_avg_loss | |||||
elif epoch_avg_loss > best_train_loss: | |||||
logger.error("[Error] training loss does not descent. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
if not best_train_F or F > best_train_F: | |||||
save_file = os.path.join(train_dir, "bestFmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f F score. Saving to %s', float(F), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_F = F | |||||
best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt) | |||||
if non_descent_cnt >= 3: | |||||
logger.error("[Error] val loss does not descent for three times. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
logger.info("[INFO] Starting eval for this model ...") | |||||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||||
model.eval() | |||||
running_loss = 0.0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
total_example_num = 0 | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
iter_start_time = time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
# if i > 10: | |||||
# break | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input, requires_grad=False) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||||
outputs = model_outputs["p_sent"] | |||||
prediction = model_outputs["prediction"] | |||||
outputs = outputs.view(-1, 2) # [batch * N, 2] | |||||
label = label.view(-1) # [batch * N] | |||||
loss = criterion(outputs, label) | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.masked_fill(input_len.eq(0), 0) | |||||
loss = loss.sum(1).mean() | |||||
logger.debug("loss %f", loss) | |||||
running_loss += float(loss.data) | |||||
label = label.data.view(batch_size, -1) | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
# rouge | |||||
prediction = prediction.view(batch_size, -1) | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id] for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if sent_max_number < hps.m and len(hyps) <= 1: | |||||
logger.error("sent_max_number is too short %d, Skip!" , sent_max_number) | |||||
continue | |||||
if len(hyps) >= 1 and hyps != '.': | |||||
# logger.debug(prediction[j]) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
elif refer == "." or refer == "": | |||||
logger.error("Refer is None!") | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
logger.debug(refer) | |||||
elif hyps == "." or hyps == "": | |||||
logger.error("hyps is None!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug("prediction:") | |||||
logger.debug(prediction[j]) | |||||
logger.debug(hyps) | |||||
else: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
running_avg_loss = running_loss / len(loader) | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
logging.getLogger('global').setLevel(logging.WARNING) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0 : | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# try: | |||||
# scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# scores_all = [] | |||||
# for idx in range(len(pairs["hyps"])): | |||||
# try: | |||||
# scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0] | |||||
# scores_all.append(scores) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# logger.debug("HYPS:\t%s", pairs["hyps"][idx]) | |||||
# logger.debug("REFER:\t%s", pairs["refer"][idx]) | |||||
# finally: | |||||
# logger.error("During testing, some errors happen!") | |||||
# logger.error(len(scores_all)) | |||||
# exit(1) | |||||
logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | ' | |||||
.format((time.time() - iter_start_time), | |||||
float(running_avg_loss))) | |||||
logger.info("[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", | |||||
total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
logger.info(res) | |||||
# If running_avg_loss is best so far, save this checkpoint (early stopping). | |||||
# These checkpoints will appear as bestmodel-<iteration_number> in the eval dir | |||||
if best_loss is None or running_avg_loss < best_loss: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_loss is not None: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s', float(running_avg_loss), float(best_loss), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s', float(running_avg_loss), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_loss = running_avg_loss | |||||
non_descent_cnt = 0 | |||||
else: | |||||
non_descent_cnt += 1 | |||||
if best_F is None or best_F < F: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_F is not None: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F), float(best_F), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original loss is None, Saving to %s', float(F), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_F = F | |||||
return best_loss, best_F, non_descent_cnt | |||||
def run_test(model, loader, hps, limited=False): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||||
eval_dir = os.path.join(hps.save_root, "eval") | |||||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||||
if not os.path.exists(eval_dir) : | |||||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||||
if hps.test_model == "evalbestmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
elif hps.test_model == "evalbestFmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "trainbestmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl') | |||||
elif hps.test_model == "trainbestFmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "earlystop": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
else: | |||||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||||
modelloader = ModelLoader() | |||||
modelloader.load_pytorch(model, bestmodel_load_path) | |||||
nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # current timestamp
if hps.save_label: | |||||
log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1]) | |||||
resfile = open(log_dir, "w") | |||||
else: | |||||
log_dir = os.path.join(test_dir, nowTime) | |||||
resfile = open(log_dir, "wb") | |||||
logger.info("[INFO] Write the Evaluation into %s", log_dir) | |||||
model.eval() | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
total_example_num = 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
pred_list = [] | |||||
iter_start_time=time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
prediction = model_outputs["prediction"] | |||||
if hps.save_label: | |||||
pred_list.extend(model_outputs["pred_idx"].data.cpu().view(-1).tolist()) | |||||
continue | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id].replace("\n", "") for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if limited: | |||||
k = len(refer.split()) | |||||
hyps = " ".join(hyps.split()[:k]) | |||||
logger.info((len(refer.split()),len(hyps.split()))) | |||||
resfile.write(b"Original_article:") | |||||
resfile.write("\n".join(batch_x["text"][j]).encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"Reference:") | |||||
if isinstance(refer, list): | |||||
for ref in refer: | |||||
resfile.write(ref.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b'*' * 40) | |||||
resfile.write(b"\n") | |||||
else: | |||||
resfile.write(refer.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"hypothesis:") | |||||
resfile.write(hyps.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
else: | |||||
try: | |||||
scores = utils.rouge_all(hyps, refer) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
except ValueError: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
# single example res writer | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']) | |||||
resfile.write(res.encode('utf-8')) | |||||
resfile.write(b'-' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.save_label: | |||||
import json | |||||
json.dump(pred_list, resfile) | |||||
logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time))) | |||||
return | |||||
resfile.write(b"\n") | |||||
resfile.write(b'=' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# the whole model res writer | |||||
resfile.write(b"The total testset is:") | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
resfile.write(res.encode("utf-8")) | |||||
logger.info(res) | |||||
logger.info(' | end of test | time: {:5.2f}s | ' | |||||
.format((time.time() - iter_start_time))) | |||||
# label prediction | |||||
logger.info("match_true %d, pred %d, true %d, total %d, match %d", match, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
resfile.write(res.encode('utf-8')) | |||||
logger.info("The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", len(loader), accu, precision, recall, F) | |||||
def main(): | |||||
parser = argparse.ArgumentParser(description='Transformer Model') | |||||
# Where to find data | |||||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||||
# Important settings | |||||
parser.add_argument('--mode', type=str, default='train', help='must be one of train/test') | |||||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||||
# Where to save output | |||||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||||
# Hyperparameters | |||||
parser.add_argument('--gpu', type=str, default='0', help="GPU ID(s) to use, e.g. '0' or '0,1'. For cpu, set -1 [default: 0]")
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]')
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 300]')
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||||
parser.add_argument('--n_layers', type=int, default=12, help='Number of deeplstm layers') | |||||
parser.add_argument('--hidden_size', type=int, default=512, help='hidden size [default: 512]') | |||||
parser.add_argument('--ffn_inner_hidden_size', type=int, default=2048, help='PositionwiseFeedForward inner hidden size [default: 2048]') | |||||
parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]') | |||||
parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]') | |||||
parser.add_argument('--atten_dropout_prob', type=float, default=0.1,help='attention dropout prob [default: 0.1]') | |||||
parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]') | |||||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthonormal init for lstm [default: true]')
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||||
parser.add_argument('--save_label', action='store_true', default=False, help='save the prediction labels to a file instead of computing ROUGE [default: False]')
# Training | |||||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||||
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization') | |||||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||||
args = parser.parse_args() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||||
torch.set_printoptions(threshold=50000) | |||||
hps = args | |||||
# File paths | |||||
DATA_FILE = args.data_path | |||||
VALID_FILE = args.valid_path | |||||
VOCAL_FILE = args.vocab_path | |||||
LOG_PATH = args.log_root | |||||
# train_log setting | |||||
if not os.path.exists(LOG_PATH): | |||||
if hps.mode == "train": | |||||
os.makedirs(LOG_PATH) | |||||
else: | |||||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
log_path = os.path.join(LOG_PATH, hps.mode + "_" + nowTime) | |||||
file_handler = logging.FileHandler(log_path) | |||||
file_handler.setFormatter(formatter) | |||||
logger.addHandler(file_handler) | |||||
logger.info("Pytorch %s", torch.__version__) | |||||
logger.info(args) | |||||
sum_loader = SummarizationLoader() | |||||
if hps.mode == 'test': | |||||
paths = {"test": DATA_FILE} | |||||
hps.recurrent_dropout_prob = 0.0 | |||||
hps.atten_dropout_prob = 0.0 | |||||
hps.ffn_dropout_prob = 0.0 | |||||
logger.info(hps) | |||||
else: | |||||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||||
vocab = dataInfo.vocabs["vocab"] | |||||
model = TransformerModel(hps, vocab) | |||||
if len(hps.gpu) > 1: | |||||
gpuid = hps.gpu.split(',') | |||||
gpuid = [int(s) for s in gpuid] | |||||
model = nn.DataParallel(model,device_ids=gpuid) | |||||
logger.info("[INFO] Use Multi-gpu: %s", hps.gpu) | |||||
if hps.cuda: | |||||
model = model.cuda() | |||||
logger.info("[INFO] Use cuda") | |||||
if hps.mode == 'train': | |||||
trainset = dataInfo.datasets["train"] | |||||
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | |||||
train_batch = DataSetIter(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler) | |||||
validset = dataInfo.datasets["valid"] | |||||
validset.set_input("text", "summary") | |||||
valid_batch = DataSetIter(batch_size=hps.batch_size, dataset=validset) | |||||
setup_training(model, train_batch, valid_batch, hps) | |||||
elif hps.mode == 'test': | |||||
logger.info("[INFO] Decoding...") | |||||
testset = dataInfo.datasets["test"] | |||||
testset.set_input("text", "summary") | |||||
test_batch = DataSetIter(batch_size=hps.batch_size, dataset=testset) | |||||
run_test(model, test_batch, hps, limited=hps.limited) | |||||
else: | |||||
logger.error("The 'mode' flag must be one of train/eval/test") | |||||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,705 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""Train Model1: baseline model""" | |||||
import os | |||||
import sys | |||||
import time | |||||
import copy | |||||
import pickle | |||||
import datetime | |||||
import argparse | |||||
import logging | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import Variable | |||||
from rouge import Rouge | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.sampler import BucketSampler | |||||
from tools import utils | |||||
from tools.logger import * | |||||
from data.dataloader import SummarizationLoader | |||||
from model.TransformerModel import TransformerModel | |||||
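# Overall flow, as implemented below: main() parses the hyperparameters, loads the
# labelled CNN/DailyMail jsonl files through SummarizationLoader, builds a
# TransformerModel, and then runs setup_training()/run_training() (with run_eval()
# after every epoch) in train mode, or run_test() in test mode.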
def setup_training(model, train_loader, valid_loader, hps): | |||||
"""Does setup before starting training (run_training)""" | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
if hps.restore_model != 'None': | |||||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||||
loader = ModelLoader() | |||||
loader.load_pytorch(model, bestmodel_file) | |||||
else: | |||||
logger.info("[INFO] Create new model for training...") | |||||
try: | |||||
run_training(model, train_loader, valid_loader, hps) # this is an infinite loop until interrupted | |||||
except KeyboardInterrupt: | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
def run_training(model, train_loader, valid_loader, hps): | |||||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||||
logger.info("[INFO] Starting run_training") | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
lr = hps.lr | |||||
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98), | |||||
# eps=1e-09) | |||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
best_train_loss = None | |||||
best_train_F= None | |||||
best_loss = None | |||||
best_F = None | |||||
step_num = 0 | |||||
non_descent_cnt = 0 | |||||
for epoch in range(1, hps.n_epochs + 1): | |||||
epoch_loss = 0.0 | |||||
train_loss = 0.0 | |||||
total_example_num = 0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
epoch_start_time = time.time() | |||||
for i, (batch_x, batch_y) in enumerate(train_loader): | |||||
# if i > 10: | |||||
# break | |||||
model.train() | |||||
iter_start_time=time.time() | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
# logger.info(batch_x["text"][0]) | |||||
# logger.info(input[0,:,:]) | |||||
# logger.info(input_len[0:5,:]) | |||||
# logger.info(batch_y["summary"][0:5]) | |||||
# logger.info(label[0:5,:]) | |||||
# logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0))) | |||||
batch_size, N, seq_len = input.size() | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
input = Variable(input) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
outputs = model_outputs[Const.OUTPUT].view(-1, 2) | |||||
label = label.view(-1) | |||||
loss = criterion(outputs, label) # [batch_size, doc_max_timesteps] | |||||
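# input_len acts as a per-sentence mask here: losses at padded sentence positions
# are zeroed out, then summed per document and averaged over the batch.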
input_len = input_len.float().view(-1) | |||||
loss = loss * input_len | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.sum(1).mean() | |||||
if not np.isfinite(loss.item()):
logger.error("train Loss is not finite. Stopping.") | |||||
logger.info(loss) | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.info(name) | |||||
logger.info(param.grad.data.sum()) | |||||
raise Exception("train Loss is not finite. Stopping.") | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
if hps.grad_clip: | |||||
torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm) | |||||
optimizer.step() | |||||
step_num += 1 | |||||
train_loss += float(loss.data) | |||||
epoch_loss += float(loss.data) | |||||
if i % 100 == 0: | |||||
# start debugger | |||||
# import pdb; pdb.set_trace() | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.debug(name) | |||||
logger.debug(param.grad.data.sum()) | |||||
logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | ' | |||||
.format(i, (time.time() - iter_start_time), | |||||
float(train_loss / 100))) | |||||
train_loss = 0.0 | |||||
# calculate the precision, recall and F | |||||
prediction = outputs.max(1)[1] | |||||
prediction = prediction.data | |||||
label = label.data | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += int(batch_size * N) | |||||
if hps.lr_descent: | |||||
# new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5), | |||||
# step_num * pow(hps.warmup_steps, -1.5)) | |||||
new_lr = max(5e-6, lr / (epoch + 1)) | |||||
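# Simple per-epoch decay used instead of the Noam-style warmup above:
# lr / (epoch + 1), floored at 5e-6.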
for param_group in list(optimizer.param_groups): | |||||
param_group['lr'] = new_lr | |||||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||||
epoch_avg_loss = epoch_loss / len(train_loader) | |||||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | ' | |||||
.format(epoch, (time.time() - epoch_start_time), | |||||
float(epoch_avg_loss))) | |||||
logger.info("[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
if not best_train_loss or epoch_avg_loss < best_train_loss: | |||||
save_file = os.path.join(train_dir, "bestmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_loss = epoch_avg_loss | |||||
elif epoch_avg_loss > best_train_loss: | |||||
logger.error("[Error] training loss does not descent. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
if not best_train_F or F > best_train_F: | |||||
save_file = os.path.join(train_dir, "bestFmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f F score. Saving to %s', float(F), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_F = F | |||||
best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt) | |||||
if non_descent_cnt >= 3: | |||||
logger.error("[Error] val loss does not descent for three times. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
logger.info("[INFO] Starting eval for this model ...") | |||||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||||
model.eval() | |||||
running_loss = 0.0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
total_example_num = 0 | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
iter_start_time = time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
# if i > 10: | |||||
# break | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input, requires_grad=False) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||||
outputs = model_outputs[Const.OUTPUTS] | |||||
prediction = model_outputs["prediction"] | |||||
outputs = outputs.view(-1, 2) # [batch * N, 2] | |||||
label = label.view(-1) # [batch * N] | |||||
loss = criterion(outputs, label) | |||||
input_len = input_len.float().view(-1) | |||||
loss = loss * input_len | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.sum(1).mean() | |||||
running_loss += float(loss.data) | |||||
label = label.data | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
# rouge | |||||
prediction = prediction.view(batch_size, -1) | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
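# Extractive hypothesis: join the sentences predicted as 1 (capped at the number
# of sentences actually present in the article); the reference is the gold summary.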
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id] for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if sent_max_number < hps.m and len(hyps) <= 1: | |||||
logger.error("sent_max_number is too short %d, Skip!" , sent_max_number) | |||||
continue | |||||
if len(hyps) >= 1 and hyps != '.': | |||||
# logger.debug(prediction[j]) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
elif refer == "." or refer == "": | |||||
logger.error("Refer is None!") | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
logger.debug(refer) | |||||
elif hyps == "." or hyps == "": | |||||
logger.error("hyps is None!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug("prediction:") | |||||
logger.debug(prediction[j]) | |||||
logger.debug(hyps) | |||||
else: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
running_avg_loss = running_loss / len(loader) | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
logging.getLogger('global').setLevel(logging.WARNING) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0 : | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# try: | |||||
# scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# scores_all = [] | |||||
# for idx in range(len(pairs["hyps"])): | |||||
# try: | |||||
# scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0] | |||||
# scores_all.append(scores) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# logger.debug("HYPS:\t%s", pairs["hyps"][idx]) | |||||
# logger.debug("REFER:\t%s", pairs["refer"][idx]) | |||||
# finally: | |||||
# logger.error("During testing, some errors happen!") | |||||
# logger.error(len(scores_all)) | |||||
# exit(1) | |||||
logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | ' | |||||
.format((time.time() - iter_start_time), | |||||
float(running_avg_loss))) | |||||
logger.info("[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", | |||||
total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
logger.info(res) | |||||
# If running_avg_loss is best so far, save this checkpoint (early stopping). | |||||
# These checkpoints will appear as bestmodel-<iteration_number> in the eval dir | |||||
if best_loss is None or running_avg_loss < best_loss: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_loss is not None: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s', float(running_avg_loss), float(best_loss), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s', float(running_avg_loss), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_loss = running_avg_loss | |||||
non_descent_cnt = 0 | |||||
else: | |||||
non_descent_cnt += 1 | |||||
if best_F is None or best_F < F: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_F is not None: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F), float(best_F), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original loss is None, Saving to %s', float(F), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_F = F | |||||
return best_loss, best_F, non_descent_cnt | |||||
def run_test(model, loader, hps, limited=False): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||||
eval_dir = os.path.join(hps.save_root, "eval") | |||||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||||
if not os.path.exists(eval_dir) : | |||||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||||
if hps.test_model == "evalbestmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
elif hps.test_model == "evalbestFmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "trainbestmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl') | |||||
elif hps.test_model == "trainbestFmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "earlystop": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
else: | |||||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||||
modelloader = ModelLoader() | |||||
modelloader.load_pytorch(model, bestmodel_load_path) | |||||
import datetime | |||||
nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # current timestamp
if hps.save_label: | |||||
log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1]) | |||||
resfile = open(log_dir, "w") | |||||
else: | |||||
log_dir = os.path.join(test_dir, nowTime) | |||||
resfile = open(log_dir, "wb") | |||||
logger.info("[INFO] Write the Evaluation into %s", log_dir) | |||||
model.eval() | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
total_example_num = 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
pred_list = [] | |||||
iter_start_time=time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
prediction = model_outputs["pred"] | |||||
if hps.save_label: | |||||
pred_list.extend(model_outputs["pred_idx"].data.cpu().view(-1).tolist()) | |||||
continue | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id].replace("\n", "") for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if limited: | |||||
k = len(refer.split()) | |||||
hyps = " ".join(hyps.split()[:k]) | |||||
logger.info((len(refer.split()),len(hyps.split()))) | |||||
resfile.write(b"Original_article:") | |||||
resfile.write("\n".join(batch_x["text"][j]).encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"Reference:") | |||||
if isinstance(refer, list): | |||||
for ref in refer: | |||||
resfile.write(ref.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b'*' * 40) | |||||
resfile.write(b"\n") | |||||
else: | |||||
resfile.write(refer.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"hypothesis:") | |||||
resfile.write(hyps.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
else: | |||||
try: | |||||
scores = utils.rouge_all(hyps, refer) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
except ValueError: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
# single example res writer | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']) | |||||
resfile.write(res.encode('utf-8')) | |||||
resfile.write(b'-' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.save_label: | |||||
import json | |||||
json.dump(pred_list, resfile) | |||||
logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time))) | |||||
return | |||||
resfile.write(b"\n") | |||||
resfile.write(b'=' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# the whole model res writer | |||||
resfile.write(b"The total testset is:") | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
resfile.write(res.encode("utf-8")) | |||||
logger.info(res) | |||||
logger.info(' | end of test | time: {:5.2f}s | ' | |||||
.format((time.time() - iter_start_time))) | |||||
# label prediction | |||||
logger.info("match_true %d, pred %d, true %d, total %d, match %d", match, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
resfile.write(res.encode('utf-8')) | |||||
logger.info("The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", len(loader), accu, precision, recall, F) | |||||
def main(): | |||||
parser = argparse.ArgumentParser(description='Transformer Model') | |||||
# Where to find data | |||||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||||
# Important settings | |||||
parser.add_argument('--mode', type=str, default='train', help='must be one of train/test') | |||||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||||
# Where to save output | |||||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||||
# Hyperparameters | |||||
parser.add_argument('--gpu', type=str, default='0', help="GPU ID(s) to use, e.g. '0' or '0,1'. For cpu, set -1 [default: 0]")
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]')
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 300]')
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||||
parser.add_argument('--n_layers', type=int, default=12, help='Number of deeplstm layers') | |||||
parser.add_argument('--hidden_size', type=int, default=512, help='hidden size [default: 512]') | |||||
parser.add_argument('--ffn_inner_hidden_size', type=int, default=2048, help='PositionwiseFeedForward inner hidden size [default: 2048]') | |||||
parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]') | |||||
parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]') | |||||
parser.add_argument('--atten_dropout_prob', type=float, default=0.1,help='attention dropout prob [default: 0.1]') | |||||
parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]') | |||||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthonormal init for lstm [default: true]')
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||||
parser.add_argument('--save_label', action='store_true', default=False, help='save the prediction labels to a file instead of computing ROUGE [default: False]')
# Training | |||||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||||
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization') | |||||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||||
args = parser.parse_args() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||||
torch.set_printoptions(threshold=50000) | |||||
hps = args | |||||
# File paths | |||||
DATA_FILE = args.data_path | |||||
VALID_FILE = args.valid_path | |||||
VOCAL_FILE = args.vocab_path | |||||
LOG_PATH = args.log_root | |||||
# train_log setting | |||||
if not os.path.exists(LOG_PATH): | |||||
if hps.mode == "train": | |||||
os.makedirs(LOG_PATH) | |||||
else: | |||||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
log_path = os.path.join(LOG_PATH, hps.mode + "_" + nowTime) | |||||
file_handler = logging.FileHandler(log_path) | |||||
file_handler.setFormatter(formatter) | |||||
logger.addHandler(file_handler) | |||||
logger.info("Pytorch %s", torch.__version__) | |||||
logger.info(args) | |||||
sum_loader = SummarizationLoader() | |||||
if hps.mode == 'test': | |||||
paths = {"test": DATA_FILE} | |||||
hps.recurrent_dropout_prob = 0.0 | |||||
hps.atten_dropout_prob = 0.0 | |||||
hps.ffn_dropout_prob = 0.0 | |||||
logger.info(hps) | |||||
else: | |||||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||||
vocab = dataInfo.vocabs["vocab"] | |||||
model = TransformerModel(hps, vocab) | |||||
if len(hps.gpu) > 1: | |||||
gpuid = hps.gpu.split(',') | |||||
gpuid = [int(s) for s in gpuid] | |||||
model = nn.DataParallel(model,device_ids=gpuid) | |||||
logger.info("[INFO] Use Multi-gpu: %s", hps.gpu) | |||||
if hps.cuda: | |||||
model = model.cuda() | |||||
logger.info("[INFO] Use cuda") | |||||
if hps.mode == 'train': | |||||
trainset = dataInfo.datasets["train"] | |||||
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | |||||
train_batch = Batch(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler) | |||||
validset = dataInfo.datasets["valid"] | |||||
validset.set_input("text", "summary") | |||||
valid_batch = Batch(batch_size=hps.batch_size, dataset=validset) | |||||
setup_training(model, train_batch, valid_batch, hps) | |||||
elif hps.mode == 'test': | |||||
logger.info("[INFO] Decoding...") | |||||
testset = dataInfo.datasets["test"] | |||||
testset.set_input("text", "summary") | |||||
test_batch = Batch(batch_size=hps.batch_size, dataset=testset) | |||||
run_test(model, test_batch, hps, limited=hps.limited) | |||||
else: | |||||
logger.error("The 'mode' flag must be one of train/eval/test") | |||||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,103 @@ | |||||
""" Manage beam search info structure. | |||||
Heavily borrowed from OpenNMT-py. | |||||
For code in OpenNMT-py, please check the following link: | |||||
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py | |||||
""" | |||||
import torch | |||||
import numpy as np | |||||
import transformer.Constants as Constants | |||||
class Beam(): | |||||
''' Beam search ''' | |||||
def __init__(self, size, device=False): | |||||
self.size = size | |||||
self._done = False | |||||
# The score for each translation on the beam. | |||||
self.scores = torch.zeros((size,), dtype=torch.float, device=device) | |||||
self.all_scores = [] | |||||
# The backpointers at each time-step. | |||||
self.prev_ks = [] | |||||
# The outputs at each time-step. | |||||
self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] | |||||
self.next_ys[0][0] = Constants.BOS | |||||
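# i.e. every beam is seeded with BOS in slot 0 and PAD everywhere else.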
def get_current_state(self): | |||||
"Get the outputs for the current timestep." | |||||
return self.get_tentative_hypothesis() | |||||
def get_current_origin(self): | |||||
"Get the backpointers for the current timestep." | |||||
return self.prev_ks[-1] | |||||
@property | |||||
def done(self): | |||||
return self._done | |||||
def advance(self, word_prob): | |||||
"Update beam status and check if finished or not." | |||||
num_words = word_prob.size(1) | |||||
# Sum the previous scores. | |||||
if len(self.prev_ks) > 0: | |||||
beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) | |||||
else: | |||||
beam_lk = word_prob[0] | |||||
flat_beam_lk = beam_lk.view(-1) | |||||
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)
self.all_scores.append(self.scores) | |||||
self.scores = best_scores | |||||
# bestScoresId is flattened as a (beam x word) array, | |||||
# so we need to calculate which word and beam each score came from | |||||
prev_k = best_scores_id // num_words
self.prev_ks.append(prev_k) | |||||
self.next_ys.append(best_scores_id - prev_k * num_words) | |||||
# End condition is when top-of-beam is EOS. | |||||
if self.next_ys[-1][0].item() == Constants.EOS: | |||||
self._done = True | |||||
self.all_scores.append(self.scores) | |||||
return self._done | |||||
def sort_scores(self): | |||||
"Sort the scores." | |||||
return torch.sort(self.scores, 0, True) | |||||
def get_the_best_score_and_idx(self): | |||||
"Get the score of the best in the beam." | |||||
scores, ids = self.sort_scores() | |||||
return scores[1], ids[1] | |||||
def get_tentative_hypothesis(self): | |||||
"Get the decoded sequence for the current timestep." | |||||
if len(self.next_ys) == 1: | |||||
dec_seq = self.next_ys[0].unsqueeze(1) | |||||
else: | |||||
_, keys = self.sort_scores() | |||||
hyps = [self.get_hypothesis(k) for k in keys] | |||||
hyps = [[Constants.BOS] + h for h in hyps] | |||||
dec_seq = torch.LongTensor(hyps) | |||||
return dec_seq | |||||
def get_hypothesis(self, k): | |||||
""" Walk back to construct the full hypothesis. """ | |||||
hyp = [] | |||||
for j in range(len(self.prev_ks) - 1, -1, -1): | |||||
hyp.append(self.next_ys[j+1][k]) | |||||
k = self.prev_ks[j][k] | |||||
return list(map(lambda x: x.item(), hyp[::-1])) |
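# A minimal usage sketch (illustrative only): `decoder_step` below is a
# hypothetical per-step decoder call returning log-probabilities of shape
# [beam_size, vocab_size]; it is not part of this file.
#
#     beam = Beam(size=5, device='cpu')
#     while not beam.done:
#         word_prob = decoder_step(beam.get_current_state())
#         beam.advance(word_prob)
#     _, ids = beam.sort_scores()
#     best_hypothesis = beam.get_hypothesis(ids[0])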
@@ -0,0 +1,10 @@ | |||||
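# Special token indices and their surface forms. Beam above relies on
# Constants.PAD, Constants.BOS and Constants.EOS; the indices are expected to
# match the vocabulary used by the decoder.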
PAD = 0 | |||||
UNK = 1 | |||||
BOS = 2 | |||||
EOS = 3 | |||||
PAD_WORD = '<blank>' | |||||
UNK_WORD = '<unk>' | |||||
BOS_WORD = '<s>' | |||||
EOS_WORD = '</s>' |