@@ -0,0 +1,57 @@ | |||
# Summarization | |||
## Extractive Summarization | |||
### Models | |||
FastNLP中实现的模型包括: | |||
1. Get To The Point: Summarization with Pointer-Generator Networks (See et al. 2017) | |||
2. Extractive Summarization with SWAP-NET : Sentences and Words from Alternating Pointer Networks (Jadhav et al. 2018) | |||
3. Searching for Effective Neural Extractive Summarization What Works and What's Next (Zhong et al. 2019) | |||
### Dataset | |||
这里提供的摘要任务数据集包括: | |||
- CNN/DailyMail | |||
- Newsroom | |||
- The New York Times Annotated Corpus | |||
- NYT | |||
- NYT50 | |||
- DUC | |||
- 2002 Task4 | |||
- 2003/2004 Task1 | |||
- arXiv | |||
- PubMed | |||
其中公开数据集(CNN/DailyMail, Newsroom, arXiv, PubMed)预处理之后的下载地址: | |||
- [百度云盘](https://pan.baidu.com) | |||
- [Google Drive](https://drive.google.com) | |||
未公开数据集(NYT, NYT50, DUC)数据处理部分脚本放置于data文件夹 | |||
### Dataset_loader | |||
- SummarizationLoader: 用于读取处理好的jsonl格式数据集,返回以下field | |||
- text: 文章正文 | |||
- summary: 摘要 | |||
- domain: 可选,文章发布网站 | |||
- tag: 可选,文章内容标签 | |||
- labels: 抽取式句子标签 | |||
### Performance and Hyperparameters | |||
## Abstractive Summarization | |||
Still in Progress... |
@@ -0,0 +1,12 @@ | |||
{ | |||
"n_layers": 16, | |||
"layer_sum": false, | |||
"layer_cat": false, | |||
"lstm_hidden_size": 300, | |||
"ffn_inner_hidden_size": 2048, | |||
"n_head": 6, | |||
"recurrent_dropout_prob": 0.1, | |||
"atten_dropout_prob": 0.1, | |||
"ffn_dropout_prob": 0.1, | |||
"fix_mask": true | |||
} |
@@ -0,0 +1,3 @@ | |||
{ | |||
} |
@@ -0,0 +1,9 @@ | |||
{ | |||
"n_layers": 12, | |||
"hidden_size": 512, | |||
"ffn_inner_hidden_size": 2048, | |||
"n_head": 8, | |||
"recurrent_dropout_prob": 0.1, | |||
"atten_dropout_prob": 0.1, | |||
"ffn_dropout_prob": 0.1 | |||
} |
@@ -0,0 +1,188 @@ | |||
import pickle | |||
import numpy as np | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.base_loader import DataInfo | |||
from fastNLP.io.dataset_loader import JsonLoader | |||
from fastNLP.core.const import Const | |||
from tools.logger import * | |||
WORD_PAD = "[PAD]" | |||
WORD_UNK = "[UNK]" | |||
DOMAIN_UNK = "X" | |||
TAG_UNK = "X" | |||
class SummarizationLoader(JsonLoader): | |||
""" | |||
读取summarization数据集,读取的DataSet包含fields:: | |||
text: list(str),document | |||
summary: list(str), summary | |||
text_wd: list(list(str)),tokenized document | |||
summary_wd: list(list(str)), tokenized summary | |||
labels: list(int), | |||
flatten_label: list(int), 0 or 1, flatten labels | |||
domain: str, optional | |||
tag: list(str), optional | |||
数据来源: CNN_DailyMail Newsroom DUC | |||
""" | |||
def __init__(self): | |||
super(SummarizationLoader, self).__init__() | |||
def _load(self, path): | |||
ds = super(SummarizationLoader, self)._load(path) | |||
def _lower_text(text_list): | |||
return [text.lower() for text in text_list] | |||
def _split_list(text_list): | |||
return [text.split() for text in text_list] | |||
def _convert_label(label, sent_len): | |||
np_label = np.zeros(sent_len, dtype=int) | |||
if label != []: | |||
np_label[np.array(label)] = 1 | |||
return np_label.tolist() | |||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||
return ds | |||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab=True): | |||
""" | |||
:param paths: dict path for each dataset | |||
:param vocab_size: int max_size for vocab | |||
:param vocab_path: str vocab path | |||
:param sent_max_len: int max token number of the sentence | |||
:param doc_max_timesteps: int max sentence number of the document | |||
:param domain: bool build vocab for publication, use 'X' for unknown | |||
:param tag: bool build vocab for tag, use 'X' for unknown | |||
:param load_vocab: bool build vocab (False) or load vocab (True) | |||
:return: DataInfo | |||
datasets: dict keys correspond to the paths dict | |||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||
embeddings: optional | |||
""" | |||
def _pad_sent(text_wd): | |||
pad_text_wd = [] | |||
for sent_wd in text_wd: | |||
if len(sent_wd) < sent_max_len: | |||
pad_num = sent_max_len - len(sent_wd) | |||
sent_wd.extend([WORD_PAD] * pad_num) | |||
else: | |||
sent_wd = sent_wd[:sent_max_len] | |||
pad_text_wd.append(sent_wd) | |||
return pad_text_wd | |||
def _token_mask(text_wd): | |||
token_mask_list = [] | |||
for sent_wd in text_wd: | |||
token_num = len(sent_wd) | |||
if token_num < sent_max_len: | |||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
else: | |||
mask = [1] * sent_max_len | |||
token_mask_list.append(mask) | |||
return token_mask_list | |||
def _pad_label(label): | |||
text_len = len(label) | |||
if text_len < doc_max_timesteps: | |||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_label = label[:doc_max_timesteps] | |||
return pad_label | |||
def _pad_doc(text_wd): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
padding = [WORD_PAD] * sent_max_len | |||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_text = text_wd[:doc_max_timesteps] | |||
return pad_text | |||
def _sent_mask(text_wd): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
sent_mask = [1] * doc_max_timesteps | |||
return sent_mask | |||
datasets = {} | |||
train_ds = None | |||
for key, value in paths.items(): | |||
ds = self.load(value) | |||
# pad sent | |||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||
# pad document | |||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||
# rename field | |||
ds.rename_field("pad_text", Const.INPUT) | |||
ds.rename_field("seq_len", Const.INPUT_LEN) | |||
ds.rename_field("pad_label", Const.TARGET) | |||
# set input and target | |||
ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||
ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||
datasets[key] = ds | |||
if "train" in key: | |||
train_ds = datasets[key] | |||
vocab_dict = {} | |||
if load_vocab == False: | |||
logger.info("[INFO] Build new vocab from training dataset!") | |||
if train_ds == None: | |||
raise ValueError("Lack train file to build vocabulary!") | |||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||
vocab_dict["vocab"] = vocabs | |||
else: | |||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||
word_list = [] | |||
with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||
cnt = 0 | |||
for line in vocab_f: | |||
cnt += 1 | |||
pieces = line.split("\t") | |||
word_list.append(pieces[0]) | |||
if cnt > vocab_size: | |||
break | |||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.add_word_lst(word_list) | |||
vocabs.build_vocab() | |||
vocab_dict["vocab"] = vocabs | |||
if domain == True: | |||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
domaindict.from_dataset(train_ds, field_name="publication") | |||
vocab_dict["domain"] = domaindict | |||
if tag == True: | |||
tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||
tagdict.from_dataset(train_ds, field_name="tag") | |||
vocab_dict["tag"] = tagdict | |||
for ds in datasets.values(): | |||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
return DataInfo(vocabs=vocab_dict, datasets=datasets) | |||
@@ -0,0 +1,566 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import print_function | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.nn.init as init | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.embed_loader import EmbedLoader | |||
# from tools.logger import * | |||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||
WORD_PAD = "[PAD]" | |||
class Encoder(nn.Module): | |||
def __init__(self, hps, embed): | |||
""" | |||
:param hps: | |||
word_emb_dim: word embedding dimension | |||
sent_max_len: max token number in the sentence | |||
output_channel: output channel for cnn | |||
min_kernel_size: min kernel size for cnn | |||
max_kernel_size: max kernel size for cnn | |||
word_embedding: bool, use word embedding or not | |||
embedding_path: word embedding path | |||
embed_train: bool, whether to train word embedding | |||
cuda: bool, use cuda or not | |||
:param vocab: FastNLP.Vocabulary | |||
""" | |||
super(Encoder, self).__init__() | |||
self._hps = hps | |||
self.sent_max_len = hps.sent_max_len | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# word embedding | |||
self.embed = embed | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
print("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
print("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def forward(self, input): | |||
# input: a batch of Example object [batch_size, N, seq_len] | |||
batch_size, N, _ = input.size() | |||
input = input.view(-1, input.size(2)) # [batch_size*N, L] | |||
input_sent_len = ((input!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||
enc_embed_input = self.embed(input) # [batch_size*N, L, D] | |||
input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding | |||
class DomainEncoder(Encoder): | |||
def __init__(self, hps, vocab, domaindict): | |||
super(DomainEncoder, self).__init__(hps, vocab) | |||
# domain embedding | |||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||
self.domain_embedding.weight.requires_grad = True | |||
def forward(self, input, domain): | |||
""" | |||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||
:param domain: [batch_size] | |||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||
""" | |||
batch_size, N, _ = input.size() | |||
sent_embedding = super().forward(input) | |||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||
return sent_embedding | |||
class MultiDomainEncoder(Encoder): | |||
def __init__(self, hps, vocab, domaindict): | |||
super(MultiDomainEncoder, self).__init__(hps, vocab) | |||
self.domain_size = domaindict.size() | |||
# domain embedding | |||
self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim) | |||
self.domain_embedding.weight.requires_grad = True | |||
def forward(self, input, domain): | |||
""" | |||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||
:param domain: [batch_size, domain_size] | |||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||
""" | |||
batch_size, N, _ = input.size() | |||
# logger.info(domain[:5, :]) | |||
sent_embedding = super().forward(input) | |||
domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1) | |||
domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size] | |||
enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D] | |||
enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D] | |||
# logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D] | |||
enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D] | |||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||
return sent_embedding | |||
class BertEncoder(nn.Module): | |||
def __init__(self, hps): | |||
super(BertEncoder, self).__init__() | |||
from pytorch_pretrained_bert.modeling import BertModel | |||
self._hps = hps | |||
self.sent_max_len = hps.sent_max_len | |||
self._cuda = hps.cuda | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# word embedding | |||
self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16") | |||
self._bert.eval() | |||
for p in self._bert.parameters(): | |||
p.requires_grad = False | |||
self.word_embedding_proj = nn.Linear(4096, embed_size) | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
logger.info("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def pad_encoder_input(self, input_list): | |||
""" | |||
:param input_list: N [seq_len, hidden_state] | |||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||
""" | |||
max_len = self.sent_max_len | |||
enc_sent_input_pad = [] | |||
_, hidden_size = input_list[0].size() | |||
for i in range(len(input_list)): | |||
article_words = input_list[i] # [seq_len, hidden_size] | |||
seq_len = article_words.size(0) | |||
if seq_len > max_len: | |||
pad_words = article_words[:max_len, :] | |||
else: | |||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||
enc_sent_input_pad.append(pad_words) | |||
return enc_sent_input_pad | |||
def forward(self, inputs, input_masks, enc_sent_len): | |||
""" | |||
:param inputs: a batch of Example object [batch_size, doc_len=512] | |||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||
:param enc_sent_len: sentence original length [batch, N] | |||
:return: | |||
""" | |||
# Use Bert to get word embedding | |||
batch_size, N = enc_sent_len.size() | |||
input_pad_list = [] | |||
for i in range(batch_size): | |||
tokens_id = inputs[i] | |||
input_mask = input_masks[i] | |||
sent_len = enc_sent_len[i] | |||
input_ids = tokens_id.unsqueeze(0) | |||
input_mask = input_mask.unsqueeze(0) | |||
out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask) | |||
out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096] | |||
_, hidden_size = out.size() | |||
# restore the sentence | |||
last_end = 1 | |||
enc_sent_input = [] | |||
for length in sent_len: | |||
if length != 0 and last_end < 511: | |||
enc_sent_input.append(out[last_end: min(511, last_end + length), :]) | |||
last_end += length | |||
else: | |||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||
enc_sent_input.append(pad_tensor) | |||
# pad the sentence | |||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||
input_pad_list.append(torch.stack(enc_sent_input_pad)) | |||
input_pad = torch.stack(input_pad_list) | |||
input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1) | |||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||
enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D] | |||
sent_pos_list = [] | |||
for sentlen in enc_sent_len: | |||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||
for k in range(self.sent_max_len - sentlen): | |||
sent_pos.append(0) | |||
sent_pos_list.append(sent_pos) | |||
input_pos = torch.Tensor(sent_pos_list).long() | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding | |||
class BertTagEncoder(BertEncoder): | |||
def __init__(self, hps, domaindict): | |||
super(BertTagEncoder, self).__init__(hps) | |||
# domain embedding | |||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||
self.domain_embedding.weight.requires_grad = True | |||
def forward(self, inputs, input_masks, enc_sent_len, domain): | |||
sent_embedding = super().forward(inputs, input_masks, enc_sent_len) | |||
batch_size, N = enc_sent_len.size() | |||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||
return sent_embedding | |||
class ELMoEndoer(nn.Module): | |||
def __init__(self, hps): | |||
super(ELMoEndoer, self).__init__() | |||
self._hps = hps | |||
self.sent_max_len = hps.sent_max_len | |||
from allennlp.modules.elmo import Elmo | |||
elmo_dim = 1024 | |||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||
# elmo_dim = 512 | |||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# elmo embedding | |||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
logger.info("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def forward(self, input): | |||
# input: a batch of Example object [batch_size, N, seq_len, character_len] | |||
batch_size, N, seq_len, _ = input.size() | |||
input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len] | |||
input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||
# logger.debug(input_sent_len.view(batch_size, -1)) | |||
enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D] | |||
enc_embed_input = self.embed_proj(enc_embed_input) | |||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||
sent_pos_list = [] | |||
for sentlen in input_sent_len: | |||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||
for k in range(self.sent_max_len - sentlen): | |||
sent_pos.append(0) | |||
sent_pos_list.append(sent_pos) | |||
input_pos = torch.Tensor(sent_pos_list).long() | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding | |||
class ELMoEndoer2(nn.Module): | |||
def __init__(self, hps): | |||
super(ELMoEndoer2, self).__init__() | |||
self._hps = hps | |||
self._cuda = hps.cuda | |||
self.sent_max_len = hps.sent_max_len | |||
from allennlp.modules.elmo import Elmo | |||
elmo_dim = 1024 | |||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||
# elmo_dim = 512 | |||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# elmo embedding | |||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
logger.info("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def pad_encoder_input(self, input_list): | |||
""" | |||
:param input_list: N [seq_len, hidden_state] | |||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||
""" | |||
max_len = self.sent_max_len | |||
enc_sent_input_pad = [] | |||
_, hidden_size = input_list[0].size() | |||
for i in range(len(input_list)): | |||
article_words = input_list[i] # [seq_len, hidden_size] | |||
seq_len = article_words.size(0) | |||
if seq_len > max_len: | |||
pad_words = article_words[:max_len, :] | |||
else: | |||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||
enc_sent_input_pad.append(pad_words) | |||
return enc_sent_input_pad | |||
def forward(self, inputs, input_masks, enc_sent_len): | |||
""" | |||
:param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50] | |||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||
:param enc_sent_len: sentence original length [batch, N] | |||
:return: | |||
sent_embedding: [batch, N, D] | |||
""" | |||
# Use Bert to get word embedding | |||
batch_size, N = enc_sent_len.size() | |||
input_pad_list = [] | |||
elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D] | |||
elmo_output = elmo_output * input_masks.unsqueeze(-1).float() | |||
# print("END elmo") | |||
for i in range(batch_size): | |||
sent_len = enc_sent_len[i] # [1, N] | |||
out = elmo_output[i] | |||
_, hidden_size = out.size() | |||
# restore the sentence | |||
last_end = 0 | |||
enc_sent_input = [] | |||
for length in sent_len: | |||
if length != 0 and last_end < 512: | |||
enc_sent_input.append(out[last_end : min(512, last_end + length), :]) | |||
last_end += length | |||
else: | |||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||
enc_sent_input.append(pad_tensor) | |||
# pad the sentence | |||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||
input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state] | |||
input_pad = torch.stack(input_pad_list) | |||
input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1) | |||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||
enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D] | |||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||
sent_pos_list = [] | |||
for sentlen in enc_sent_len: | |||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||
for k in range(self.sent_max_len - sentlen): | |||
sent_pos.append(0) | |||
sent_pos_list.append(sent_pos) | |||
input_pos = torch.Tensor(sent_pos_list).long() | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding |
@@ -0,0 +1,55 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
import torch | |||
import torch.nn.functional as F | |||
from fastNLP.core.losses import LossBase | |||
from tools.logger import * | |||
class MyCrossEntropyLoss(LossBase): | |||
def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'): | |||
super().__init__() | |||
self._init_param_map(pred=pred, target=target, mask=mask) | |||
self.padding_idx = padding_idx | |||
self.reduce = reduce | |||
def get_loss(self, pred, target, mask): | |||
""" | |||
:param pred: [batch, N, 2] | |||
:param target: [batch, N] | |||
:param input_mask: [batch, N] | |||
:return: | |||
""" | |||
# logger.debug(pred[0:5, :, :]) | |||
# logger.debug(target[0:5, :]) | |||
batch, N, _ = pred.size() | |||
pred = pred.view(-1, 2) | |||
target = target.view(-1) | |||
loss = F.cross_entropy(input=pred, target=target, | |||
ignore_index=self.padding_idx, reduction=self.reduce) | |||
loss = loss.view(batch, -1) | |||
loss = loss.masked_fill(mask.eq(0), 0) | |||
loss = loss.sum(1).mean() | |||
logger.debug("loss %f", loss) | |||
return loss | |||
@@ -0,0 +1,171 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
from __future__ import division | |||
import torch | |||
from rouge import Rouge | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.metrics import MetricBase | |||
from tools.logger import * | |||
from tools.utils import pyrouge_score_all, pyrouge_score_all_multi | |||
class LabelFMetric(MetricBase): | |||
def __init__(self, pred=None, target=None): | |||
super().__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
self.match = 0.0 | |||
self.pred = 0.0 | |||
self.true = 0.0 | |||
self.match_true = 0.0 | |||
self.total = 0.0 | |||
def evaluate(self, pred, target): | |||
""" | |||
:param pred: [batch, N] int | |||
:param target: [batch, N] int | |||
:return: | |||
""" | |||
target = target.data | |||
pred = pred.data | |||
logger.debug(pred.size()) | |||
logger.debug(pred[:5,:]) | |||
batch, N = pred.size() | |||
self.pred += pred.sum() | |||
self.true += target.sum() | |||
self.match += (pred == target).sum() | |||
self.match_true += ((pred == target) & (pred == 1)).sum() | |||
self.total += batch * N | |||
def get_metric(self, reset=True): | |||
self.match,self.pred, self.true, self.match_true, self.total = self.match.float(),self.pred.float(), self.true.float(), self.match_true.float(), self.total | |||
logger.debug((self.match,self.pred, self.true, self.match_true, self.total)) | |||
try: | |||
accu = self.match / self.total | |||
precision = self.match_true / self.pred | |||
recall = self.match_true / self.true | |||
F = 2 * precision * recall / (precision + recall) | |||
except ZeroDivisionError: | |||
F = 0.0 | |||
logger.error("[Error] float division by zero") | |||
if reset: | |||
self.pred, self.true, self.match_true, self.match, self.total = 0, 0, 0, 0, 0 | |||
ret = {"accu": accu.cpu(), "p":precision.cpu(), "r":recall.cpu(), "f": F.cpu()} | |||
logger.info(ret) | |||
return ret | |||
class RougeMetric(MetricBase): | |||
def __init__(self, hps, pred=None, text=None, refer=None): | |||
super().__init__() | |||
self._hps = hps | |||
self._init_param_map(pred=pred, text=text, summary=refer) | |||
self.hyps = [] | |||
self.refers = [] | |||
def evaluate(self, pred, text, summary): | |||
""" | |||
:param prediction: [batch, N] | |||
:param text: [batch, N] | |||
:param summary: [batch, N] | |||
:return: | |||
""" | |||
batch_size, N = pred.size() | |||
for j in range(batch_size): | |||
original_article_sents = text[j] | |||
sent_max_number = len(original_article_sents) | |||
refer = "\n".join(summary[j]) | |||
hyps = "\n".join(original_article_sents[id] for id in range(len(pred[j])) if | |||
pred[j][id] == 1 and id < sent_max_number) | |||
if sent_max_number < self._hps.m and len(hyps) <= 1: | |||
print("sent_max_number is too short %d, Skip!", sent_max_number) | |||
continue | |||
if len(hyps) >= 1 and hyps != '.': | |||
self.hyps.append(hyps) | |||
self.refers.append(refer) | |||
elif refer == "." or refer == "": | |||
logger.error("Refer is None!") | |||
logger.debug(refer) | |||
elif hyps == "." or hyps == "": | |||
logger.error("hyps is None!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug("pred:") | |||
logger.debug(pred[j]) | |||
logger.debug(hyps) | |||
else: | |||
logger.error("Do not select any sentences!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug(original_article_sents) | |||
logger.debug(refer) | |||
continue | |||
def get_metric(self, reset=True): | |||
pass | |||
class FastRougeMetric(RougeMetric): | |||
def __init__(self, hps, pred=None, text=None, refer=None): | |||
super().__init__(hps, pred, text, refer) | |||
def get_metric(self, reset=True): | |||
logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers)) | |||
if len(self.hyps) == 0 or len(self.refers) == 0 : | |||
logger.error("During testing, no hyps or refers is selected!") | |||
return | |||
rouge = Rouge() | |||
scores_all = rouge.get_scores(self.hyps, self.refers, avg=True) | |||
if reset: | |||
self.hyps = [] | |||
self.refers = [] | |||
logger.info(scores_all) | |||
return scores_all | |||
class PyRougeMetric(RougeMetric): | |||
def __init__(self, hps, pred=None, text=None, refer=None): | |||
super().__init__(hps, pred, text, refer) | |||
def get_metric(self, reset=True): | |||
logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers)) | |||
if len(self.hyps) == 0 or len(self.refers) == 0: | |||
logger.error("During testing, no hyps or refers is selected!") | |||
return | |||
if isinstance(self.refers[0], list): | |||
logger.info("Multi Reference summaries!") | |||
scores_all = pyrouge_score_all_multi(self.hyps, self.refers) | |||
else: | |||
scores_all = pyrouge_score_all(self.hyps, self.refers) | |||
if reset: | |||
self.hyps = [] | |||
self.refers = [] | |||
logger.info(scores_all) | |||
return scores_all | |||
@@ -0,0 +1,144 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import print_function | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from .Encoder import Encoder | |||
# from tools.Encoder import Encoder | |||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||
from tools.logger import * | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
from transformer.Layers import EncoderLayer | |||
class TransformerModel(nn.Module): | |||
def __init__(self, hps, embed): | |||
""" | |||
:param hps: | |||
min_kernel_size: min kernel size for cnn encoder | |||
max_kernel_size: max kernel size for cnn encoder | |||
output_channel: output_channel number for cnn encoder | |||
hidden_size: hidden size for transformer | |||
n_layers: transfromer encoder layer | |||
n_head: multi head attention for transformer | |||
ffn_inner_hidden_size: FFN hiddens size | |||
atten_dropout_prob: dropout size | |||
doc_max_timesteps: max sentence number of the document | |||
:param vocab: | |||
""" | |||
super(TransformerModel, self).__init__() | |||
self._hps = hps | |||
self.encoder = Encoder(hps, embed) | |||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||
self.hidden_size = hps.hidden_size | |||
self.n_head = hps.n_head | |||
self.d_v = self.d_k = int(self.hidden_size / self.n_head) | |||
self.d_inner = hps.ffn_inner_hidden_size | |||
self.num_layers = hps.n_layers | |||
self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size) | |||
self.sent_pos_embed = nn.Embedding.from_pretrained( | |||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||
self.layer_stack = nn.ModuleList([ | |||
EncoderLayer(self.hidden_size, self.d_inner, self.n_head, self.d_k, self.d_v, | |||
dropout=hps.atten_dropout_prob) | |||
for _ in range(self.num_layers)]) | |||
self.wh = nn.Linear(self.hidden_size, 2) | |||
def forward(self, words, seq_len): | |||
""" | |||
:param input: [batch_size, N, seq_len] | |||
:param input_len: [batch_size, N] | |||
:param return_atten: bool | |||
:return: | |||
""" | |||
# Sentence Encoder | |||
input = words | |||
input_len = seq_len | |||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||
input_len = input_len.float() # [batch, N] | |||
# -- Prepare masks | |||
batch_size, N = input_len.size() | |||
self.slf_attn_mask = input_len.eq(0.0) # [batch, N] | |||
self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||
self.non_pad_mask = input_len.unsqueeze(-1) # [batch, N, 1] | |||
input_doc_len = input_len.sum(dim=1).int() # [batch] | |||
sent_pos = torch.Tensor( | |||
[np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len]) | |||
sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long() | |||
enc_output_state = self.projection(self.sent_embedding) | |||
enc_input = enc_output_state + self.sent_pos_embed(sent_pos) | |||
# self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask | |||
enc_input_list = [] | |||
for enc_layer in self.layer_stack: | |||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||
# enc_slf_attn = [n_head * batch_size, N, N] | |||
enc_input, enc_slf_atten = enc_layer(enc_input, non_pad_mask=self.non_pad_mask, | |||
slf_attn_mask=self.slf_attn_mask) | |||
enc_input_list += [enc_input] | |||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | |||
self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1) | |||
self.dec_output_state = self.dec_output_state.sum(0) | |||
p_sent = self.wh(self.dec_output_state) # [batch, N, 2] | |||
idx = None | |||
if self._hps == 0: | |||
prediction = p_sent.view(-1, 2).max(1)[1] | |||
prediction = prediction.view(batch_size, -1) | |||
else: | |||
mask_output = torch.exp(p_sent[:, :, 1]) # # [batch, N] | |||
mask_output = mask_output * input_len.float() | |||
topk, idx = torch.topk(mask_output, self._hps.m) | |||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||
prediction = prediction.long().view(batch_size, -1) | |||
if self._hps.cuda: | |||
prediction = prediction.cuda() | |||
# logger.debug(((p_sent.size(), prediction.size(), idx.size()))) | |||
return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx} | |||
@@ -0,0 +1,138 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import print_function | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from .Encoder import Encoder | |||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
class TransformerModel(nn.Module): | |||
def __init__(self, hps, vocab): | |||
""" | |||
:param hps: | |||
min_kernel_size: min kernel size for cnn encoder | |||
max_kernel_size: max kernel size for cnn encoder | |||
output_channel: output_channel number for cnn encoder | |||
hidden_size: hidden size for transformer | |||
n_layers: transfromer encoder layer | |||
n_head: multi head attention for transformer | |||
ffn_inner_hidden_size: FFN hiddens size | |||
atten_dropout_prob: dropout size | |||
doc_max_timesteps: max sentence number of the document | |||
:param vocab: | |||
""" | |||
super(TransformerModel, self).__init__() | |||
self._hps = hps | |||
self._vocab = vocab | |||
self.encoder = Encoder(hps, vocab) | |||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||
self.hidden_size = hps.hidden_size | |||
self.n_head = hps.n_head | |||
self.d_v = self.d_k = int(self.hidden_size / self.n_head) | |||
self.d_inner = hps.ffn_inner_hidden_size | |||
self.num_layers = hps.n_layers | |||
self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size) | |||
self.sent_pos_embed = nn.Embedding.from_pretrained( | |||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||
self.layer_stack = nn.ModuleList([ | |||
TransformerEncoder.SubLayer(model_size=self.hidden_size, inner_size=self.d_inner, key_size=self.d_k, value_size=self.d_v,num_head=self.n_head, dropout=hps.atten_dropout_prob) | |||
for _ in range(self.num_layers)]) | |||
self.wh = nn.Linear(self.hidden_size, 2) | |||
def forward(self, words, seq_len): | |||
""" | |||
:param input: [batch_size, N, seq_len] | |||
:param input_len: [batch_size, N] | |||
:param return_atten: bool | |||
:return: | |||
""" | |||
# Sentence Encoder | |||
input = words | |||
input_len = seq_len | |||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||
input_len = input_len.float() # [batch, N] | |||
# -- Prepare masks | |||
batch_size, N = input_len.size() | |||
self.slf_attn_mask = input_len.eq(0.0) # [batch, N] | |||
self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||
self.non_pad_mask = input_len.unsqueeze(-1) # [batch, N, 1] | |||
input_doc_len = input_len.sum(dim=1).int() # [batch] | |||
sent_pos = torch.Tensor([np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len]) | |||
sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long() | |||
enc_output_state = self.projection(self.sent_embedding) | |||
enc_input = enc_output_state + self.sent_pos_embed(sent_pos) | |||
# self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask | |||
enc_input_list = [] | |||
for enc_layer in self.layer_stack: | |||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||
# enc_slf_attn = [n_head * batch_size, N, N] | |||
enc_input = enc_layer(enc_input, seq_mask=self.non_pad_mask, atte_mask_out=self.slf_attn_mask) | |||
enc_input_list += [enc_input] | |||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | |||
self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1) | |||
self.dec_output_state = self.dec_output_state.sum(0) | |||
p_sent = self.wh(self.dec_output_state) # [batch, N, 2] | |||
idx = None | |||
if self._hps.m == 0: | |||
prediction = p_sent.view(-1, 2).max(1)[1] | |||
prediction = prediction.view(batch_size, -1) | |||
else: | |||
mask_output = torch.exp(p_sent[:, :, 1]) # # [batch, N] | |||
mask_output = mask_output * input_len.float() | |||
topk, idx = torch.topk(mask_output, self._hps.m) | |||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||
prediction = prediction.long().view(batch_size, -1) | |||
if self._hps.cuda: | |||
prediction = prediction.cuda() | |||
# print((p_sent.size(), prediction.size(), idx.size())) | |||
# [batch, N, 2], [batch, N], [batch, hps.m] | |||
return {"pred": p_sent, "prediction": prediction, "pred_idx": idx} | |||
@@ -0,0 +1,24 @@ | |||
import unittest | |||
from ..data.dataloader import SummarizationLoader | |||
class TestSummarizationLoader(unittest.TestCase): | |||
def test_case1(self): | |||
sum_loader = SummarizationLoader() | |||
paths = {"train":"testdata/train.jsonl", "valid":"testdata/val.jsonl", "test":"testdata/test.jsonl"} | |||
data = sum_loader.process(paths=paths) | |||
print(data.datasets) | |||
def test_case2(self): | |||
sum_loader = SummarizationLoader() | |||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||
data = sum_loader.process(paths=paths, domain=True) | |||
print(data.datasets, data.vocabs) | |||
def test_case3(self): | |||
sum_loader = SummarizationLoader() | |||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||
data = sum_loader.process(paths=paths, tag=True) | |||
print(data.datasets, data.vocabs) |
@@ -0,0 +1,135 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
import os | |||
import sys | |||
import time | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.const import Const | |||
from fastNLP.io.model_io import ModelSaver | |||
from fastNLP.core.callback import Callback, EarlyStopError | |||
from tools.logger import * | |||
class TrainCallback(Callback): | |||
def __init__(self, hps, patience=3, quit_all=True): | |||
super().__init__() | |||
self._hps = hps | |||
self.patience = patience | |||
self.wait = 0 | |||
if type(quit_all) != bool: | |||
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") | |||
self.quit_all = quit_all | |||
def on_epoch_begin(self): | |||
self.epoch_start_time = time.time() | |||
# def on_loss_begin(self, batch_y, predict_y): | |||
# """ | |||
# | |||
# :param batch_y: dict | |||
# input_len: [batch, N] | |||
# :param predict_y: dict | |||
# p_sent: [batch, N, 2] | |||
# :return: | |||
# """ | |||
# input_len = batch_y[Const.INPUT_LEN] | |||
# batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100) | |||
# # predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1) | |||
# # logger.debug(predict_y["p_sent"][0:5,:,:]) | |||
def on_backward_begin(self, loss): | |||
""" | |||
:param loss: [] | |||
:return: | |||
""" | |||
if not (np.isfinite(loss.data)).numpy(): | |||
logger.error("train Loss is not finite. Stopping.") | |||
logger.info(loss) | |||
for name, param in self.model.named_parameters(): | |||
if param.requires_grad: | |||
logger.info(name) | |||
logger.info(param.grad.data.sum()) | |||
raise Exception("train Loss is not finite. Stopping.") | |||
def on_backward_end(self): | |||
if self._hps.grad_clip: | |||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm) | |||
def on_epoch_end(self): | |||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | ' | |||
.format(self.epoch, (time.time() - self.epoch_start_time))) | |||
def on_valid_begin(self): | |||
self.valid_start_time = time.time() | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
logger.info(' | end of valid {:3d} | time: {:5.2f}s | ' | |||
.format(self.epoch, (time.time() - self.valid_start_time))) | |||
# early stop | |||
if not is_better_eval: | |||
if self.wait == self.patience: | |||
train_dir = os.path.join(self._hps.save_root, "train") | |||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(self.model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
raise EarlyStopError("Early stopping raised.") | |||
else: | |||
self.wait += 1 | |||
else: | |||
self.wait = 0 | |||
# lr descent | |||
if self._hps.lr_descent: | |||
new_lr = max(5e-6, self._hps.lr / (self.epoch + 1)) | |||
for param_group in list(optimizer.param_groups): | |||
param_group['lr'] = new_lr | |||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||
def on_exception(self, exception): | |||
if isinstance(exception, KeyboardInterrupt): | |||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||
train_dir = os.path.join(self._hps.save_root, "train") | |||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(self.model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
if self.quit_all is True: | |||
sys.exit(0) # 直接退出程序 | |||
else: | |||
pass | |||
else: | |||
raise exception # 抛出陌生Error | |||
@@ -0,0 +1,562 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import print_function | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.autograd import * | |||
import torch.nn.init as init | |||
import data | |||
from tools.logger import * | |||
from transformer.Models import get_sinusoid_encoding_table | |||
class Encoder(nn.Module): | |||
def __init__(self, hps, vocab): | |||
super(Encoder, self).__init__() | |||
self._hps = hps | |||
self._vocab = vocab | |||
self.sent_max_len = hps.sent_max_len | |||
vocab_size = len(vocab) | |||
logger.info("[INFO] Vocabulary size is %d", vocab_size) | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# word embedding | |||
self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=vocab.word2id('[PAD]')) | |||
if hps.word_embedding: | |||
word2vec = data.Word_Embedding(hps.embedding_path, vocab) | |||
word_vecs = word2vec.load_my_vecs(embed_size) | |||
# pretrained_weight = word2vec.add_unknown_words_by_zero(word_vecs, embed_size) | |||
pretrained_weight = word2vec.add_unknown_words_by_avg(word_vecs, embed_size) | |||
pretrained_weight = np.array(pretrained_weight) | |||
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight)) | |||
self.embed.weight.requires_grad = hps.embed_train | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
logger.info("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def forward(self, input): | |||
# input: a batch of Example object [batch_size, N, seq_len] | |||
vocab = self._vocab | |||
batch_size, N, _ = input.size() | |||
input = input.view(-1, input.size(2)) # [batch_size*N, L] | |||
input_sent_len = ((input!=vocab.word2id('[PAD]')).sum(dim=1)).int() # [batch_size*N, 1] | |||
enc_embed_input = self.embed(input) # [batch_size*N, L, D] | |||
input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding | |||
class DomainEncoder(Encoder): | |||
def __init__(self, hps, vocab, domaindict): | |||
super(DomainEncoder, self).__init__(hps, vocab) | |||
# domain embedding | |||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||
self.domain_embedding.weight.requires_grad = True | |||
def forward(self, input, domain): | |||
""" | |||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||
:param domain: [batch_size] | |||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||
""" | |||
batch_size, N, _ = input.size() | |||
sent_embedding = super().forward(input) | |||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||
return sent_embedding | |||
class MultiDomainEncoder(Encoder): | |||
def __init__(self, hps, vocab, domaindict): | |||
super(MultiDomainEncoder, self).__init__(hps, vocab) | |||
self.domain_size = domaindict.size() | |||
# domain embedding | |||
self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim) | |||
self.domain_embedding.weight.requires_grad = True | |||
def forward(self, input, domain): | |||
""" | |||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||
:param domain: [batch_size, domain_size] | |||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||
""" | |||
batch_size, N, _ = input.size() | |||
# logger.info(domain[:5, :]) | |||
sent_embedding = super().forward(input) | |||
domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1) | |||
domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size] | |||
enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D] | |||
enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D] | |||
# logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D] | |||
enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D] | |||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||
return sent_embedding | |||
class BertEncoder(nn.Module): | |||
def __init__(self, hps): | |||
super(BertEncoder, self).__init__() | |||
from pytorch_pretrained_bert.modeling import BertModel | |||
self._hps = hps | |||
self.sent_max_len = hps.sent_max_len | |||
self._cuda = hps.cuda | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# word embedding | |||
self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16") | |||
self._bert.eval() | |||
for p in self._bert.parameters(): | |||
p.requires_grad = False | |||
self.word_embedding_proj = nn.Linear(4096, embed_size) | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
logger.info("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def pad_encoder_input(self, input_list): | |||
""" | |||
:param input_list: N [seq_len, hidden_state] | |||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||
""" | |||
max_len = self.sent_max_len | |||
enc_sent_input_pad = [] | |||
_, hidden_size = input_list[0].size() | |||
for i in range(len(input_list)): | |||
article_words = input_list[i] # [seq_len, hidden_size] | |||
seq_len = article_words.size(0) | |||
if seq_len > max_len: | |||
pad_words = article_words[:max_len, :] | |||
else: | |||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||
enc_sent_input_pad.append(pad_words) | |||
return enc_sent_input_pad | |||
def forward(self, inputs, input_masks, enc_sent_len): | |||
""" | |||
:param inputs: a batch of Example object [batch_size, doc_len=512] | |||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||
:param enc_sent_len: sentence original length [batch, N] | |||
:return: | |||
""" | |||
# Use Bert to get word embedding | |||
batch_size, N = enc_sent_len.size() | |||
input_pad_list = [] | |||
for i in range(batch_size): | |||
tokens_id = inputs[i] | |||
input_mask = input_masks[i] | |||
sent_len = enc_sent_len[i] | |||
input_ids = tokens_id.unsqueeze(0) | |||
input_mask = input_mask.unsqueeze(0) | |||
out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask) | |||
out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096] | |||
_, hidden_size = out.size() | |||
# restore the sentence | |||
last_end = 1 | |||
enc_sent_input = [] | |||
for length in sent_len: | |||
if length != 0 and last_end < 511: | |||
enc_sent_input.append(out[last_end: min(511, last_end + length), :]) | |||
last_end += length | |||
else: | |||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||
enc_sent_input.append(pad_tensor) | |||
# pad the sentence | |||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||
input_pad_list.append(torch.stack(enc_sent_input_pad)) | |||
input_pad = torch.stack(input_pad_list) | |||
input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1) | |||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||
enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D] | |||
sent_pos_list = [] | |||
for sentlen in enc_sent_len: | |||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||
for k in range(self.sent_max_len - sentlen): | |||
sent_pos.append(0) | |||
sent_pos_list.append(sent_pos) | |||
input_pos = torch.Tensor(sent_pos_list).long() | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding | |||
class BertTagEncoder(BertEncoder): | |||
def __init__(self, hps, domaindict): | |||
super(BertTagEncoder, self).__init__(hps) | |||
# domain embedding | |||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||
self.domain_embedding.weight.requires_grad = True | |||
def forward(self, inputs, input_masks, enc_sent_len, domain): | |||
sent_embedding = super().forward(inputs, input_masks, enc_sent_len) | |||
batch_size, N = enc_sent_len.size() | |||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||
return sent_embedding | |||
class ELMoEndoer(nn.Module): | |||
def __init__(self, hps): | |||
super(ELMoEndoer, self).__init__() | |||
self._hps = hps | |||
self.sent_max_len = hps.sent_max_len | |||
from allennlp.modules.elmo import Elmo | |||
elmo_dim = 1024 | |||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||
# elmo_dim = 512 | |||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# elmo embedding | |||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
logger.info("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def forward(self, input): | |||
# input: a batch of Example object [batch_size, N, seq_len, character_len] | |||
batch_size, N, seq_len, _ = input.size() | |||
input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len] | |||
input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||
logger.debug(input_sent_len.view(batch_size, -1)) | |||
enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D] | |||
enc_embed_input = self.embed_proj(enc_embed_input) | |||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||
sent_pos_list = [] | |||
for sentlen in input_sent_len: | |||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||
for k in range(self.sent_max_len - sentlen): | |||
sent_pos.append(0) | |||
sent_pos_list.append(sent_pos) | |||
input_pos = torch.Tensor(sent_pos_list).long() | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding | |||
class ELMoEndoer2(nn.Module): | |||
def __init__(self, hps): | |||
super(ELMoEndoer2, self).__init__() | |||
self._hps = hps | |||
self._cuda = hps.cuda | |||
self.sent_max_len = hps.sent_max_len | |||
from allennlp.modules.elmo import Elmo | |||
elmo_dim = 1024 | |||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||
# elmo_dim = 512 | |||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||
embed_size = hps.word_emb_dim | |||
sent_max_len = hps.sent_max_len | |||
input_channels = 1 | |||
out_channels = hps.output_channel | |||
min_kernel_size = hps.min_kernel_size | |||
max_kernel_size = hps.max_kernel_size | |||
width = embed_size | |||
# elmo embedding | |||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||
# position embedding | |||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||
# cnn | |||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||
logger.info("[INFO] Initing W for CNN.......") | |||
for conv in self.convs: | |||
init_weight_value = 6.0 | |||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||
def calculate_fan_in_and_fan_out(tensor): | |||
dimensions = tensor.ndimension() | |||
if dimensions < 2: | |||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||
if dimensions == 2: # Linear | |||
fan_in = tensor.size(1) | |||
fan_out = tensor.size(0) | |||
else: | |||
num_input_fmaps = tensor.size(1) | |||
num_output_fmaps = tensor.size(0) | |||
receptive_field_size = 1 | |||
if tensor.dim() > 2: | |||
receptive_field_size = tensor[0][0].numel() | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def pad_encoder_input(self, input_list): | |||
""" | |||
:param input_list: N [seq_len, hidden_state] | |||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||
""" | |||
max_len = self.sent_max_len | |||
enc_sent_input_pad = [] | |||
_, hidden_size = input_list[0].size() | |||
for i in range(len(input_list)): | |||
article_words = input_list[i] # [seq_len, hidden_size] | |||
seq_len = article_words.size(0) | |||
if seq_len > max_len: | |||
pad_words = article_words[:max_len, :] | |||
else: | |||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||
enc_sent_input_pad.append(pad_words) | |||
return enc_sent_input_pad | |||
def forward(self, inputs, input_masks, enc_sent_len): | |||
""" | |||
:param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50] | |||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||
:param enc_sent_len: sentence original length [batch, N] | |||
:return: | |||
sent_embedding: [batch, N, D] | |||
""" | |||
# Use Bert to get word embedding | |||
batch_size, N = enc_sent_len.size() | |||
input_pad_list = [] | |||
elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D] | |||
elmo_output = elmo_output * input_masks.unsqueeze(-1).float() | |||
# print("END elmo") | |||
for i in range(batch_size): | |||
sent_len = enc_sent_len[i] # [1, N] | |||
out = elmo_output[i] | |||
_, hidden_size = out.size() | |||
# restore the sentence | |||
last_end = 0 | |||
enc_sent_input = [] | |||
for length in sent_len: | |||
if length != 0 and last_end < 512: | |||
enc_sent_input.append(out[last_end : min(512, last_end + length), :]) | |||
last_end += length | |||
else: | |||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||
enc_sent_input.append(pad_tensor) | |||
# pad the sentence | |||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||
input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state] | |||
input_pad = torch.stack(input_pad_list) | |||
input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1) | |||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||
enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D] | |||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||
sent_pos_list = [] | |||
for sentlen in enc_sent_len: | |||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||
for k in range(self.sent_max_len - sentlen): | |||
sent_pos.append(0) | |||
sent_pos_list.append(sent_pos) | |||
input_pos = torch.Tensor(sent_pos_list).long() | |||
if self._hps.cuda: | |||
input_pos = input_pos.cuda() | |||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||
return sent_embedding |
@@ -0,0 +1,41 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
import torch | |||
import numpy as np | |||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
''' Sinusoid position encoding table ''' | |||
def cal_angle(position, hid_idx): | |||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||
def get_posi_angle_vec(position): | |||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||
if padding_idx is not None: | |||
# zero vector for padding dimension | |||
sinusoid_table[padding_idx] = 0. | |||
return torch.FloatTensor(sinusoid_table) |
@@ -0,0 +1 @@ | |||
@@ -0,0 +1,479 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it""" | |||
import os | |||
import re | |||
import glob | |||
import copy | |||
import random | |||
import json | |||
import collections | |||
from itertools import combinations | |||
import numpy as np | |||
from random import shuffle | |||
import torch.utils.data | |||
import time | |||
import pickle | |||
from nltk.tokenize import sent_tokenize | |||
import utils | |||
from logger import * | |||
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. | |||
SENTENCE_START = '<s>' | |||
SENTENCE_END = '</s>' | |||
PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence | |||
UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words | |||
START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence | |||
STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences | |||
# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file. | |||
class Vocab(object): | |||
"""Vocabulary class for mapping between words and ids (integers)""" | |||
def __init__(self, vocab_file, max_size): | |||
""" | |||
Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file. | |||
:param vocab_file: string; path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though. | |||
:param max_size: int; The maximum size of the resulting Vocabulary. | |||
""" | |||
self._word_to_id = {} | |||
self._id_to_word = {} | |||
self._count = 0 # keeps track of total number of words in the Vocab | |||
# [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3. | |||
for w in [PAD_TOKEN, UNKNOWN_TOKEN, START_DECODING, STOP_DECODING]: | |||
self._word_to_id[w] = self._count | |||
self._id_to_word[self._count] = w | |||
self._count += 1 | |||
# Read the vocab file and add words up to max_size | |||
with open(vocab_file, 'r', encoding='utf8') as vocab_f: #New : add the utf8 encoding to prevent error | |||
cnt = 0 | |||
for line in vocab_f: | |||
cnt += 1 | |||
pieces = line.split("\t") | |||
# pieces = line.split() | |||
w = pieces[0] | |||
# print(w) | |||
if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]: | |||
raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w) | |||
if w in self._word_to_id: | |||
logger.error('Duplicated word in vocabulary file Line %d : %s' % (cnt, w)) | |||
continue | |||
self._word_to_id[w] = self._count | |||
self._id_to_word[self._count] = w | |||
self._count += 1 | |||
if max_size != 0 and self._count >= max_size: | |||
logger.info("[INFO] max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)) | |||
break | |||
logger.info("[INFO] Finished constructing vocabulary of %i total words. Last word added: %s", self._count, self._id_to_word[self._count-1]) | |||
def word2id(self, word): | |||
"""Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV.""" | |||
if word not in self._word_to_id: | |||
return self._word_to_id[UNKNOWN_TOKEN] | |||
return self._word_to_id[word] | |||
def id2word(self, word_id): | |||
"""Returns the word (string) corresponding to an id (integer).""" | |||
if word_id not in self._id_to_word: | |||
raise ValueError('Id not found in vocab: %d' % word_id) | |||
return self._id_to_word[word_id] | |||
def size(self): | |||
"""Returns the total size of the vocabulary""" | |||
return self._count | |||
def word_list(self): | |||
"""Return the word list of the vocabulary""" | |||
return self._word_to_id.keys() | |||
class Word_Embedding(object): | |||
def __init__(self, path, vocab): | |||
""" | |||
:param path: string; the path of word embedding | |||
:param vocab: object; | |||
""" | |||
logger.info("[INFO] Loading external word embedding...") | |||
self._path = path | |||
self._vocablist = vocab.word_list() | |||
self._vocab = vocab | |||
def load_my_vecs(self, k=200): | |||
"""Load word embedding""" | |||
word_vecs = {} | |||
with open(self._path, encoding="utf-8") as f: | |||
count = 0 | |||
lines = f.readlines()[1:] | |||
for line in lines: | |||
values = line.split(" ") | |||
word = values[0] | |||
count += 1 | |||
if word in self._vocablist: # whether to judge if in vocab | |||
vector = [] | |||
for count, val in enumerate(values): | |||
if count == 0: | |||
continue | |||
if count <= k: | |||
vector.append(float(val)) | |||
word_vecs[word] = vector | |||
return word_vecs | |||
def add_unknown_words_by_zero(self, word_vecs, k=200): | |||
"""Solve unknown by zeros""" | |||
zero = [0.0] * k | |||
list_word2vec = [] | |||
oov = 0 | |||
iov = 0 | |||
for i in range(self._vocab.size()): | |||
word = self._vocab.id2word(i) | |||
if word not in word_vecs: | |||
oov += 1 | |||
word_vecs[word] = zero | |||
list_word2vec.append(word_vecs[word]) | |||
else: | |||
iov += 1 | |||
list_word2vec.append(word_vecs[word]) | |||
logger.info("[INFO] oov count %d, iov count %d", oov, iov) | |||
return list_word2vec | |||
def add_unknown_words_by_avg(self, word_vecs, k=200): | |||
"""Solve unknown by avg word embedding""" | |||
# solve unknown words inplaced by zero list | |||
word_vecs_numpy = [] | |||
for word in self._vocablist: | |||
if word in word_vecs: | |||
word_vecs_numpy.append(word_vecs[word]) | |||
col = [] | |||
for i in range(k): | |||
sum = 0.0 | |||
for j in range(int(len(word_vecs_numpy))): | |||
sum += word_vecs_numpy[j][i] | |||
sum = round(sum, 6) | |||
col.append(sum) | |||
zero = [] | |||
for m in range(k): | |||
avg = col[m] / int(len(word_vecs_numpy)) | |||
avg = round(avg, 6) | |||
zero.append(float(avg)) | |||
list_word2vec = [] | |||
oov = 0 | |||
iov = 0 | |||
for i in range(self._vocab.size()): | |||
word = self._vocab.id2word(i) | |||
if word not in word_vecs: | |||
oov += 1 | |||
word_vecs[word] = zero | |||
list_word2vec.append(word_vecs[word]) | |||
else: | |||
iov += 1 | |||
list_word2vec.append(word_vecs[word]) | |||
logger.info("[INFO] External Word Embedding iov count: %d, oov count: %d", iov, oov) | |||
return list_word2vec | |||
def add_unknown_words_by_uniform(self, word_vecs, uniform=0.25, k=200): | |||
"""Solve unknown word by uniform(-0.25,0.25)""" | |||
list_word2vec = [] | |||
oov = 0 | |||
iov = 0 | |||
for i in range(self._vocab.size()): | |||
word = self._vocab.id2word(i) | |||
if word not in word_vecs: | |||
oov += 1 | |||
word_vecs[word] = np.random.uniform(-1 * uniform, uniform, k).round(6).tolist() | |||
list_word2vec.append(word_vecs[word]) | |||
else: | |||
iov += 1 | |||
list_word2vec.append(word_vecs[word]) | |||
logger.info("[INFO] oov count %d, iov count %d", oov, iov) | |||
return list_word2vec | |||
# load word embedding | |||
def load_my_vecs_freq1(self, freqs, pro): | |||
word_vecs = {} | |||
with open(self._path, encoding="utf-8") as f: | |||
freq = 0 | |||
lines = f.readlines()[1:] | |||
for line in lines: | |||
values = line.split(" ") | |||
word = values[0] | |||
if word in self._vocablist: # whehter to judge if in vocab | |||
if freqs[word] == 1: | |||
a = np.random.uniform(0, 1, 1).round(2) | |||
if pro < a: | |||
continue | |||
vector = [] | |||
for count, val in enumerate(values): | |||
if count == 0: | |||
continue | |||
vector.append(float(val)) | |||
word_vecs[word] = vector | |||
return word_vecs | |||
class DomainDict(object): | |||
"""Domain embedding for Newsroom""" | |||
def __init__(self, path): | |||
self.domain_list = self.readDomainlist(path) | |||
# self.domain_list = ["foxnews.com", "cnn.com", "mashable.com", "nytimes.com", "washingtonpost.com"] | |||
self.domain_number = len(self.domain_list) | |||
self._domain_to_id = {} | |||
self._id_to_domain = {} | |||
self._cnt = 0 | |||
self._domain_to_id["X"] = self._cnt | |||
self._id_to_domain[self._cnt] = "X" | |||
self._cnt += 1 | |||
for i in range(self.domain_number): | |||
domain = self.domain_list[i] | |||
self._domain_to_id[domain] = self._cnt | |||
self._id_to_domain[self._cnt] = domain | |||
self._cnt += 1 | |||
def readDomainlist(self, path): | |||
domain_list = [] | |||
with open(path) as f: | |||
for line in f: | |||
domain_list.append(line.split("\t")[0].strip()) | |||
logger.info(domain_list) | |||
return domain_list | |||
def domain2id(self, domain): | |||
""" Returns the id (integer) of a domain (string). Returns "X" for unknow domain. | |||
:param domain: string | |||
:return: id; int | |||
""" | |||
if domain in self.domain_list: | |||
return self._domain_to_id[domain] | |||
else: | |||
logger.info(domain) | |||
return self._domain_to_id["X"] | |||
def id2domain(self, domain_id): | |||
""" Returns the domain (string) corresponding to an id (integer). | |||
:param id: int; | |||
:return: domain: string | |||
""" | |||
if domain_id not in self._id_to_domain: | |||
raise ValueError('Id not found in DomainDict: %d' % domain_id) | |||
return self._id_to_domain[id] | |||
def size(self): | |||
return self._cnt | |||
class Example(object): | |||
"""Class representing a train/val/test example for text summarization.""" | |||
def __init__(self, article_sents, abstract_sents, vocab, sent_max_len, label, domainid=None): | |||
""" Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self. | |||
:param article_sents: list of strings; one per article sentence. each token is separated by a single space. | |||
:param abstract_sents: list of strings; one per abstract sentence. In each sentence, each token is separated by a single space. | |||
:param domainid: int; publication of the example | |||
:param vocab: Vocabulary object | |||
:param sent_max_len: int; the maximum length of each sentence, padding all sentences to this length | |||
:param label: list of int; the index of selected sentences | |||
""" | |||
self.sent_max_len = sent_max_len | |||
self.enc_sent_len = [] | |||
self.enc_sent_input = [] | |||
self.enc_sent_input_pad = [] | |||
# origin_cnt = len(article_sents) | |||
# article_sents = [re.sub(r"\n+\t+", " ", sent) for sent in article_sents] | |||
# assert origin_cnt == len(article_sents) | |||
# Process the article | |||
for sent in article_sents: | |||
article_words = sent.split() | |||
self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding | |||
self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||
self._pad_encoder_input(vocab.word2id('[PAD]')) | |||
# Store the original strings | |||
self.original_article = " ".join(article_sents) | |||
self.original_article_sents = article_sents | |||
if isinstance(abstract_sents[0], list): | |||
logger.debug("[INFO] Multi Reference summaries!") | |||
self.original_abstract_sents = [] | |||
self.original_abstract = [] | |||
for summary in abstract_sents: | |||
self.original_abstract_sents.append([sent.strip() for sent in summary]) | |||
self.original_abstract.append("\n".join([sent.replace("\n", "") for sent in summary])) | |||
else: | |||
self.original_abstract_sents = [sent.replace("\n", "") for sent in abstract_sents] | |||
self.original_abstract = "\n".join(self.original_abstract_sents) | |||
# Store the label | |||
self.label = np.zeros(len(article_sents), dtype=int) | |||
if label != []: | |||
self.label[np.array(label)] = 1 | |||
self.label = list(self.label) | |||
# Store the publication | |||
if domainid != None: | |||
if domainid == 0: | |||
logger.debug("domain id = 0!") | |||
self.domain = domainid | |||
def _pad_encoder_input(self, pad_id): | |||
""" | |||
:param pad_id: int; token pad id | |||
:return: | |||
""" | |||
max_len = self.sent_max_len | |||
for i in range(len(self.enc_sent_input)): | |||
article_words = self.enc_sent_input[i] | |||
if len(article_words) > max_len: | |||
article_words = article_words[:max_len] | |||
while len(article_words) < max_len: | |||
article_words.append(pad_id) | |||
self.enc_sent_input_pad.append(article_words) | |||
class ExampleSet(torch.utils.data.Dataset): | |||
""" Constructor: Dataset of example(object) """ | |||
def __init__(self, data_path, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False): | |||
""" Initializes the ExampleSet with the path of data | |||
:param data_path: string; the path of data | |||
:param vocab: object; | |||
:param doc_max_timesteps: int; the maximum sentence number of a document, each example should pad sentences to this length | |||
:param sent_max_len: int; the maximum token number of a sentence, each sentence should pad tokens to this length | |||
:param domaindict: object; the domain dict to embed domain | |||
""" | |||
self.domaindict = domaindict | |||
if domaindict: | |||
logger.info("[INFO] Use domain information in the dateset!") | |||
if randomX==True: | |||
logger.info("[INFO] Random some example to unknow domain X!") | |||
self.randomP = 0.1 | |||
logger.info("[INFO] Start reading ExampleSet") | |||
start = time.time() | |||
self.example_list = [] | |||
self.doc_max_timesteps = doc_max_timesteps | |||
cnt = 0 | |||
with open(data_path, 'r') as reader: | |||
for line in reader: | |||
try: | |||
e = json.loads(line) | |||
article_sent = e['text'] | |||
tag = e["tag"][0] if usetag else e['publication'] | |||
# logger.info(tag) | |||
if "duc" in data_path: | |||
abstract_sent = e["summaryList"] if "summaryList" in e.keys() else [e['summary']] | |||
else: | |||
abstract_sent = e['summary'] | |||
if domaindict: | |||
if randomX == True: | |||
p = np.random.rand() | |||
if p <= self.randomP: | |||
domainid = domaindict.domain2id("X") | |||
else: | |||
domainid = domaindict.domain2id(tag) | |||
else: | |||
domainid = domaindict.domain2id(tag) | |||
else: | |||
domainid = None | |||
logger.debug((tag, domainid)) | |||
except (ValueError,EOFError) as e : | |||
logger.debug(e) | |||
break | |||
else: | |||
example = Example(article_sent, abstract_sent, vocab, sent_max_len, e["label"], domainid) # Process into an Example. | |||
self.example_list.append(example) | |||
cnt += 1 | |||
# print(cnt) | |||
logger.info("[INFO] Finish reading ExampleSet. Total time is %f, Total size is %d", time.time() - start, len(self.example_list)) | |||
self.size = len(self.example_list) | |||
# self.example_list.sort(key=lambda ex: ex.domain) | |||
def get_example(self, index): | |||
return self.example_list[index] | |||
def __getitem__(self, index): | |||
""" | |||
:param index: int; the index of the example | |||
:return | |||
input_pad: [N, seq_len] | |||
label: [N] | |||
input_mask: [N] | |||
domain: [1] | |||
""" | |||
item = self.example_list[index] | |||
input = np.array(item.enc_sent_input_pad) | |||
label = np.array(item.label, dtype=int) | |||
# pad input to doc_max_timesteps | |||
if len(input) < self.doc_max_timesteps: | |||
pad_number = self.doc_max_timesteps - len(input) | |||
pad_matrix = np.zeros((pad_number, len(input[0]))) | |||
input_pad = np.vstack((input, pad_matrix)) | |||
label = np.append(label, np.zeros(pad_number, dtype=int)) | |||
input_mask = np.append(np.ones(len(input)), np.zeros(pad_number)) | |||
else: | |||
input_pad = input[:self.doc_max_timesteps] | |||
label = label[:self.doc_max_timesteps] | |||
input_mask = np.ones(self.doc_max_timesteps) | |||
if self.domaindict: | |||
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long(), item.domain | |||
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long() | |||
def __len__(self): | |||
return self.size | |||
class MultiExampleSet(): | |||
def __init__(self, data_dir, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False): | |||
self.datasets = [None] * (domaindict.size() - 1) | |||
data_path_list = [os.path.join(data_dir, s) for s in os.listdir(data_dir) if s.endswith("label.jsonl")] | |||
for data_path in data_path_list: | |||
fname = data_path.split("/")[-1] # cnn.com.label.json | |||
dataname = ".".join(fname.split(".")[:-2]) | |||
domainid = domaindict.domain2id(dataname) | |||
logger.info("[INFO] domain name: %s, domain id: %d" % (dataname, domainid)) | |||
self.datasets[domainid - 1] = ExampleSet(data_path, vocab, doc_max_timesteps, sent_max_len, domaindict, randomX, usetag) | |||
def get(self, id): | |||
return self.datasets[id] | |||
from torch.utils.data.dataloader import default_collate | |||
def my_collate_fn(batch): | |||
''' | |||
:param batch: (input_pad, label, input_mask, domain) | |||
:return: | |||
''' | |||
start_domain = batch[0][-1] | |||
# for i in range(len(batch)): | |||
# print(batch[i][-1], end=',') | |||
batch = list(filter(lambda x: x[-1] == start_domain, batch)) | |||
print("start_domain %d" % start_domain) | |||
print("batch_len %d" % len(batch)) | |||
if len(batch) == 0: return torch.Tensor() | |||
return default_collate(batch) # 用默认方式拼接过滤后的batch数据 | |||
@@ -0,0 +1,27 @@ | |||
# -*- coding: utf-8 -*- | |||
import logging | |||
import sys | |||
# 获取logger实例,如果参数为空则返回root logger | |||
logger = logging.getLogger("Summarization logger") | |||
# logger = logging.getLogger() | |||
# 指定logger输出格式 | |||
formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s') | |||
# # 文件日志 | |||
# file_handler = logging.FileHandler("test.log") | |||
# file_handler.setFormatter(formatter) # 可以通过setFormatter指定输出格式 | |||
# 控制台日志 | |||
console_handler = logging.StreamHandler(sys.stdout) | |||
console_handler.formatter = formatter # 也可以直接给formatter赋值 | |||
console_handler.setLevel(logging.INFO) | |||
# 为logger添加的日志处理器 | |||
# logger.addHandler(file_handler) | |||
logger.addHandler(console_handler) | |||
# 指定日志的最低输出级别,默认为WARN级别 | |||
logger.setLevel(logging.DEBUG) |
@@ -0,0 +1,297 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
import re | |||
import os | |||
import shutil | |||
import copy | |||
import datetime | |||
import numpy as np | |||
from rouge import Rouge | |||
from .logger import * | |||
# from data import * | |||
import sys | |||
sys.setrecursionlimit(10000) | |||
REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", | |||
"-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} | |||
def clean(x): | |||
return re.sub( | |||
r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", | |||
lambda m: REMAP.get(m.group()), x) | |||
def rouge_eval(hyps, refer): | |||
rouge = Rouge() | |||
# print(hyps) | |||
# print(refer) | |||
# print(rouge.get_scores(hyps, refer)) | |||
try: | |||
score = rouge.get_scores(hyps, refer)[0] | |||
mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]]) | |||
except: | |||
mean_score = 0.0 | |||
return mean_score | |||
def rouge_all(hyps, refer): | |||
rouge = Rouge() | |||
score = rouge.get_scores(hyps, refer)[0] | |||
# mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]]) | |||
return score | |||
def eval_label(match_true, pred, true, total, match): | |||
match_true, pred, true, match = match_true.float(), pred.float(), true.float(), match.float() | |||
try: | |||
accu = match / total | |||
precision = match_true / pred | |||
recall = match_true / true | |||
F = 2 * precision * recall / (precision + recall) | |||
except ZeroDivisionError: | |||
F = 0.0 | |||
logger.error("[Error] float division by zero") | |||
return accu, precision, recall, F | |||
def pyrouge_score(hyps, refer, remap = True): | |||
from pyrouge import Rouge155 | |||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold') | |||
MODEL_PATH = os.path.join(PYROUGE_ROOT,'system') | |||
if os.path.exists(SYSTEM_PATH): | |||
shutil.rmtree(SYSTEM_PATH) | |||
os.makedirs(SYSTEM_PATH) | |||
if os.path.exists(MODEL_PATH): | |||
shutil.rmtree(MODEL_PATH) | |||
os.makedirs(MODEL_PATH) | |||
if remap == True: | |||
refer = clean(refer) | |||
hyps = clean(hyps) | |||
system_file = os.path.join(SYSTEM_PATH, 'Reference.0.txt') | |||
model_file = os.path.join(MODEL_PATH, 'Model.A.0.txt') | |||
with open(system_file, 'wb') as f: | |||
f.write(refer.encode('utf-8')) | |||
with open(model_file, 'wb') as f: | |||
f.write(hyps.encode('utf-8')) | |||
r = Rouge155('/home/dqwang/ROUGE/RELEASE-1.5.5') | |||
r.system_dir = SYSTEM_PATH | |||
r.model_dir = MODEL_PATH | |||
r.system_filename_pattern = 'Reference.(\d+).txt' | |||
r.model_filename_pattern = 'Model.[A-Z].#ID#.txt' | |||
output = r.convert_and_evaluate(rouge_args="-e /home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||
output_dict = r.output_to_dict(output) | |||
shutil.rmtree(PYROUGE_ROOT) | |||
scores = {} | |||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||
return scores | |||
def pyrouge_score_all(hyps_list, refer_list, remap = True): | |||
from pyrouge import Rouge155 | |||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold') | |||
MODEL_PATH = os.path.join(PYROUGE_ROOT,'system') | |||
if os.path.exists(SYSTEM_PATH): | |||
shutil.rmtree(SYSTEM_PATH) | |||
os.makedirs(SYSTEM_PATH) | |||
if os.path.exists(MODEL_PATH): | |||
shutil.rmtree(MODEL_PATH) | |||
os.makedirs(MODEL_PATH) | |||
assert len(hyps_list) == len(refer_list) | |||
for i in range(len(hyps_list)): | |||
system_file = os.path.join(SYSTEM_PATH, 'Reference.%d.txt' % i) | |||
model_file = os.path.join(MODEL_PATH, 'Model.A.%d.txt' % i) | |||
refer = clean(refer_list[i]) if remap else refer_list[i] | |||
hyps = clean(hyps_list[i]) if remap else hyps_list[i] | |||
with open(system_file, 'wb') as f: | |||
f.write(refer.encode('utf-8')) | |||
with open(model_file, 'wb') as f: | |||
f.write(hyps.encode('utf-8')) | |||
r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5') | |||
r.system_dir = SYSTEM_PATH | |||
r.model_dir = MODEL_PATH | |||
r.system_filename_pattern = 'Reference.(\d+).txt' | |||
r.model_filename_pattern = 'Model.[A-Z].#ID#.txt' | |||
output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||
output_dict = r.output_to_dict(output) | |||
shutil.rmtree(PYROUGE_ROOT) | |||
scores = {} | |||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||
return scores | |||
def pyrouge_score_all_multi(hyps_list, refer_list, remap = True): | |||
from pyrouge import Rouge155 | |||
nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT, 'system') | |||
MODEL_PATH = os.path.join(PYROUGE_ROOT, 'gold') | |||
if os.path.exists(SYSTEM_PATH): | |||
shutil.rmtree(SYSTEM_PATH) | |||
os.makedirs(SYSTEM_PATH) | |||
if os.path.exists(MODEL_PATH): | |||
shutil.rmtree(MODEL_PATH) | |||
os.makedirs(MODEL_PATH) | |||
assert len(hyps_list) == len(refer_list) | |||
for i in range(len(hyps_list)): | |||
system_file = os.path.join(SYSTEM_PATH, 'Model.%d.txt' % i) | |||
# model_file = os.path.join(MODEL_PATH, 'Reference.A.%d.txt' % i) | |||
hyps = clean(hyps_list[i]) if remap else hyps_list[i] | |||
with open(system_file, 'wb') as f: | |||
f.write(hyps.encode('utf-8')) | |||
referType = ["A", "B", "C", "D", "E", "F", "G"] | |||
for j in range(len(refer_list[i])): | |||
model_file = os.path.join(MODEL_PATH, "Reference.%s.%d.txt" % (referType[j], i)) | |||
refer = clean(refer_list[i][j]) if remap else refer_list[i][j] | |||
with open(model_file, 'wb') as f: | |||
f.write(refer.encode('utf-8')) | |||
r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5') | |||
r.system_dir = SYSTEM_PATH | |||
r.model_dir = MODEL_PATH | |||
r.system_filename_pattern = 'Model.(\d+).txt' | |||
r.model_filename_pattern = 'Reference.[A-Z].#ID#.txt' | |||
output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||
output_dict = r.output_to_dict(output) | |||
shutil.rmtree(PYROUGE_ROOT) | |||
scores = {} | |||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||
return scores | |||
def cal_label(article, abstract): | |||
hyps_list = article | |||
refer = abstract | |||
scores = [] | |||
for hyps in hyps_list: | |||
mean_score = rouge_eval(hyps, refer) | |||
scores.append(mean_score) | |||
selected = [] | |||
selected.append(int(np.argmax(scores))) | |||
selected_sent_cnt = 1 | |||
best_rouge = np.max(scores) | |||
while selected_sent_cnt < len(hyps_list): | |||
cur_max_rouge = 0.0 | |||
cur_max_idx = -1 | |||
for i in range(len(hyps_list)): | |||
if i not in selected: | |||
temp = copy.deepcopy(selected) | |||
temp.append(i) | |||
hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)]) | |||
cur_rouge = rouge_eval(hyps, refer) | |||
if cur_rouge > cur_max_rouge: | |||
cur_max_rouge = cur_rouge | |||
cur_max_idx = i | |||
if cur_max_rouge != 0.0 and cur_max_rouge >= best_rouge: | |||
selected.append(cur_max_idx) | |||
selected_sent_cnt += 1 | |||
best_rouge = cur_max_rouge | |||
else: | |||
break | |||
# label = np.zeros(len(hyps_list), dtype=int) | |||
# label[np.array(selected)] = 1 | |||
# return list(label) | |||
return selected | |||
def cal_label_limited3(article, abstract): | |||
hyps_list = article | |||
refer = abstract | |||
scores = [] | |||
for hyps in hyps_list: | |||
try: | |||
mean_score = rouge_eval(hyps, refer) | |||
scores.append(mean_score) | |||
except ValueError: | |||
scores.append(0.0) | |||
selected = [] | |||
selected.append(np.argmax(scores)) | |||
selected_sent_cnt = 1 | |||
best_rouge = np.max(scores) | |||
while selected_sent_cnt < len(hyps_list) and selected_sent_cnt < 3: | |||
cur_max_rouge = 0.0 | |||
cur_max_idx = -1 | |||
for i in range(len(hyps_list)): | |||
if i not in selected: | |||
temp = copy.deepcopy(selected) | |||
temp.append(i) | |||
hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)]) | |||
cur_rouge = rouge_eval(hyps, refer) | |||
if cur_rouge > cur_max_rouge: | |||
cur_max_rouge = cur_rouge | |||
cur_max_idx = i | |||
selected.append(cur_max_idx) | |||
selected_sent_cnt += 1 | |||
best_rouge = cur_max_rouge | |||
# logger.info(selected) | |||
# label = np.zeros(len(hyps_list), dtype=int) | |||
# label[np.array(selected)] = 1 | |||
# return list(label) | |||
return selected | |||
import torch | |||
def flip(x, dim): | |||
xsize = x.size() | |||
dim = x.dim() + dim if dim < 0 else dim | |||
x = x.contiguous() | |||
x = x.view(-1, *xsize[dim:]).contiguous() | |||
x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, | |||
-1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] | |||
return x.view(xsize) | |||
def get_attn_key_pad_mask(seq_k, seq_q): | |||
''' For masking out the padding part of key sequence. ''' | |||
# Expand to fit the shape of key query attention matrix. | |||
len_q = seq_q.size(1) | |||
padding_mask = seq_k.eq(0.0) | |||
padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk | |||
return padding_mask | |||
def get_non_pad_mask(seq): | |||
assert seq.dim() == 2 | |||
return seq.ne(0.0).type(torch.float).unsqueeze(-1) |
@@ -0,0 +1,263 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Train Model1: baseline model""" | |||
import os | |||
import sys | |||
import json | |||
import argparse | |||
import datetime | |||
import torch | |||
import torch.nn | |||
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.trainer import Trainer, Tester | |||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||
from fastNLP.io.embed_loader import EmbedLoader | |||
from tools.logger import * | |||
from data.dataloader import SummarizationLoader | |||
# from model.TransformerModel import TransformerModel | |||
from model.TForiginal import TransformerModel | |||
from model.Metric import LabelFMetric, FastRougeMetric, PyRougeMetric | |||
from model.Loss import MyCrossEntropyLoss | |||
from tools.Callback import TrainCallback | |||
def setup_training(model, train_loader, valid_loader, hps): | |||
"""Does setup before starting training (run_training)""" | |||
train_dir = os.path.join(hps.save_root, "train") | |||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||
if hps.restore_model != 'None': | |||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||
loader = ModelLoader() | |||
loader.load_pytorch(model, bestmodel_file) | |||
else: | |||
logger.info("[INFO] Create new model for training...") | |||
try: | |||
run_training(model, train_loader, valid_loader, hps) # this is an infinite loop until interrupted | |||
except KeyboardInterrupt: | |||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
def run_training(model, train_loader, valid_loader, hps): | |||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||
logger.info("[INFO] Starting run_training") | |||
train_dir = os.path.join(hps.save_root, "train") | |||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||
lr = hps.lr | |||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||
criterion = MyCrossEntropyLoss(pred = "p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none') | |||
# criterion = torch.nn.CrossEntropyLoss(reduce="none") | |||
trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion, | |||
n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")], | |||
metric_key="f", validate_every=-1, save_path=eval_dir, | |||
callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False) | |||
train_info = trainer.train(load_best_model=True) | |||
logger.info(' | end of Train | time: {:5.2f}s | '.format(train_info["seconds"])) | |||
logger.info('[INFO] best eval model in epoch %d and iter %d', train_info["best_epoch"], train_info["best_step"]) | |||
logger.info(train_info["best_eval"]) | |||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||
saver = ModelSaver(bestmodel_save_path) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path) | |||
def run_test(model, loader, hps, limited=False): | |||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||
eval_dir = os.path.join(hps.save_root, "eval") | |||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||
if not os.path.exists(eval_dir) : | |||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||
if hps.test_model == "evalbestmodel": | |||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||
elif hps.test_model == "earlystop": | |||
train_dir = os.path.join(hps.save_root, "train") | |||
bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl') | |||
else: | |||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||
modelloader = ModelLoader() | |||
modelloader.load_pytorch(model, bestmodel_load_path) | |||
if hps.use_pyrouge: | |||
logger.info("[INFO] Use PyRougeMetric for testing") | |||
tester = Tester(data=loader, model=model, | |||
metrics=[LabelFMetric(pred="prediction"), PyRougeMetric(hps, pred="prediction")], | |||
batch_size=hps.batch_size) | |||
else: | |||
logger.info("[INFO] Use FastRougeMetric for testing") | |||
tester = Tester(data=loader, model=model, | |||
metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")], | |||
batch_size=hps.batch_size) | |||
test_info = tester.test() | |||
logger.info(test_info) | |||
def main(): | |||
parser = argparse.ArgumentParser(description='Summarization Model') | |||
# Where to find data | |||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||
# Important settings | |||
parser.add_argument('--mode', choices=['train', 'test'], default='train', help='must be one of train/test') | |||
parser.add_argument('--embedding', type=str, default='glove', choices=['word2vec', 'glove', 'elmo', 'bert'], help='must be one of word2vec/glove/elmo/bert') | |||
parser.add_argument('--sentence_encoder', type=str, default='transformer', choices=['bilstm', 'deeplstm', 'transformer'], help='must be one of LSTM/Transformer') | |||
parser.add_argument('--sentence_decoder', type=str, default='SeqLab', choices=['PN', 'SeqLab'], help='must be one of PN/SeqLab') | |||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||
# Where to save output | |||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||
# Hyperparameters | |||
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]') | |||
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]') | |||
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]') | |||
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthnormal init for lstm [default: true]') | |||
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||
parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention') | |||
# Training | |||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||
parser.add_argument('--max_grad_norm', type=float, default=10, help='for gradient clipping max gradient normalization') | |||
# test | |||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||
args = parser.parse_args() | |||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||
torch.set_printoptions(threshold=50000) | |||
# File paths | |||
DATA_FILE = args.data_path | |||
VALID_FILE = args.valid_path | |||
VOCAL_FILE = args.vocab_path | |||
LOG_PATH = args.log_root | |||
# train_log setting | |||
if not os.path.exists(LOG_PATH): | |||
if args.mode == "train": | |||
os.makedirs(LOG_PATH) | |||
else: | |||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||
log_path = os.path.join(LOG_PATH, args.mode + "_" + nowTime) | |||
file_handler = logging.FileHandler(log_path) | |||
file_handler.setFormatter(formatter) | |||
logger.addHandler(file_handler) | |||
logger.info("Pytorch %s", torch.__version__) | |||
sum_loader = SummarizationLoader() | |||
hps = args | |||
if hps.mode == 'test': | |||
paths = {"test": DATA_FILE} | |||
hps.recurrent_dropout_prob = 0.0 | |||
hps.atten_dropout_prob = 0.0 | |||
hps.ffn_dropout_prob = 0.0 | |||
logger.info(hps) | |||
else: | |||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||
if args.embedding == "glove": | |||
vocab = dataInfo.vocabs["vocab"] | |||
embed = torch.nn.Embedding(len(vocab), hps.word_emb_dim) | |||
if hps.word_embedding: | |||
embed_loader = EmbedLoader() | |||
pretrained_weight = embed_loader.load_with_vocab(hps.embedding_path, vocab) # unfound with random init | |||
embed.weight.data.copy_(torch.from_numpy(pretrained_weight)) | |||
embed.weight.requires_grad = hps.embed_train | |||
else: | |||
logger.error("[ERROR] embedding To Be Continued!") | |||
sys.exit(1) | |||
if args.sentence_encoder == "transformer" and args.sentence_decoder == "SeqLab": | |||
model_param = json.load(open("config/transformer.config", "rb")) | |||
hps.__dict__.update(model_param) | |||
model = TransformerModel(hps, embed) | |||
else: | |||
logger.error("[ERROR] Model To Be Continued!") | |||
sys.exit(1) | |||
logger.info(hps) | |||
if hps.cuda: | |||
model = model.cuda() | |||
logger.info("[INFO] Use cuda") | |||
if hps.mode == 'train': | |||
dataInfo.datasets["valid"].set_target("text", "summary") | |||
setup_training(model, dataInfo.datasets["train"], dataInfo.datasets["valid"], hps) | |||
elif hps.mode == 'test': | |||
logger.info("[INFO] Decoding...") | |||
dataInfo.datasets["test"].set_target("text", "summary") | |||
run_test(model, dataInfo.datasets["test"], hps, limited=hps.limited) | |||
else: | |||
logger.error("The 'mode' flag must be one of train/eval/test") | |||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,706 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Train Model1: baseline model""" | |||
import os | |||
import sys | |||
import time | |||
import copy | |||
import pickle | |||
import datetime | |||
import argparse | |||
import logging | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from torch.autograd import Variable | |||
from rouge import Rouge | |||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||
from fastNLP.core.batch import DataSetIter | |||
from fastNLP.core.const import Const | |||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||
from fastNLP.core.sampler import BucketSampler | |||
from tools import utils | |||
from tools.logger import * | |||
from data.dataloader import SummarizationLoader | |||
from model.TForiginal import TransformerModel | |||
def setup_training(model, train_loader, valid_loader, hps): | |||
"""Does setup before starting training (run_training)""" | |||
train_dir = os.path.join(hps.save_root, "train") | |||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||
if hps.restore_model != 'None': | |||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||
loader = ModelLoader() | |||
loader.load_pytorch(model, bestmodel_file) | |||
else: | |||
logger.info("[INFO] Create new model for training...") | |||
try: | |||
run_training(model, train_loader, valid_loader, hps) # this is an infinite loop until interrupted | |||
except KeyboardInterrupt: | |||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
def run_training(model, train_loader, valid_loader, hps): | |||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||
logger.info("[INFO] Starting run_training") | |||
train_dir = os.path.join(hps.save_root, "train") | |||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||
lr = hps.lr | |||
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98), | |||
# eps=1e-09) | |||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||
best_train_loss = None | |||
best_train_F= None | |||
best_loss = None | |||
best_F = None | |||
step_num = 0 | |||
non_descent_cnt = 0 | |||
for epoch in range(1, hps.n_epochs + 1): | |||
epoch_loss = 0.0 | |||
train_loss = 0.0 | |||
total_example_num = 0 | |||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||
epoch_start_time = time.time() | |||
for i, (batch_x, batch_y) in enumerate(train_loader): | |||
# if i > 10: | |||
# break | |||
model.train() | |||
iter_start_time=time.time() | |||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||
label = batch_y[Const.TARGET] | |||
# logger.info(batch_x["text"][0]) | |||
# logger.info(input[0,:,:]) | |||
# logger.info(input_len[0:5,:]) | |||
# logger.info(batch_y["summary"][0:5]) | |||
# logger.info(label[0:5,:]) | |||
# logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0))) | |||
batch_size, N, seq_len = input.size() | |||
if hps.cuda: | |||
input = input.cuda() # [batch, N, seq_len] | |||
label = label.cuda() | |||
input_len = input_len.cuda() | |||
input = Variable(input) | |||
label = Variable(label) | |||
input_len = Variable(input_len) | |||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
outputs = model_outputs["p_sent"].view(-1, 2) | |||
label = label.view(-1) | |||
loss = criterion(outputs, label) # [batch_size, doc_max_timesteps] | |||
# input_len = input_len.float().view(-1) | |||
loss = loss.view(batch_size, -1) | |||
loss = loss.masked_fill(input_len.eq(0), 0) | |||
loss = loss.sum(1).mean() | |||
logger.debug("loss %f", loss) | |||
if not (np.isfinite(loss.data)).numpy(): | |||
logger.error("train Loss is not finite. Stopping.") | |||
logger.info(loss) | |||
for name, param in model.named_parameters(): | |||
if param.requires_grad: | |||
logger.info(name) | |||
logger.info(param.grad.data.sum()) | |||
raise Exception("train Loss is not finite. Stopping.") | |||
optimizer.zero_grad() | |||
loss.backward() | |||
if hps.grad_clip: | |||
torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm) | |||
optimizer.step() | |||
step_num += 1 | |||
train_loss += float(loss.data) | |||
epoch_loss += float(loss.data) | |||
if i % 100 == 0: | |||
# start debugger | |||
# import pdb; pdb.set_trace() | |||
for name, param in model.named_parameters(): | |||
if param.requires_grad: | |||
logger.debug(name) | |||
logger.debug(param.grad.data.sum()) | |||
logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | ' | |||
.format(i, (time.time() - iter_start_time), | |||
float(train_loss / 100))) | |||
train_loss = 0.0 | |||
# calculate the precision, recall and F | |||
prediction = outputs.max(1)[1] | |||
prediction = prediction.data | |||
label = label.data | |||
pred += prediction.sum() | |||
true += label.sum() | |||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||
match += (prediction == label).sum() | |||
total_example_num += int(batch_size * N) | |||
if hps.lr_descent: | |||
# new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5), | |||
# step_num * pow(hps.warmup_steps, -1.5)) | |||
new_lr = max(5e-6, lr / (epoch + 1)) | |||
for param_group in list(optimizer.param_groups): | |||
param_group['lr'] = new_lr | |||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||
epoch_avg_loss = epoch_loss / len(train_loader) | |||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | ' | |||
.format(epoch, (time.time() - epoch_start_time), | |||
float(epoch_avg_loss))) | |||
logger.info("[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||
if not best_train_loss or epoch_avg_loss < best_train_loss: | |||
save_file = os.path.join(train_dir, "bestmodel.pkl") | |||
logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss), save_file) | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
best_train_loss = epoch_avg_loss | |||
elif epoch_avg_loss > best_train_loss: | |||
logger.error("[Error] training loss does not descent. Stopping supervisor...") | |||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
return | |||
if not best_train_F or F > best_train_F: | |||
save_file = os.path.join(train_dir, "bestFmodel.pkl") | |||
logger.info('[INFO] Found new best model with %.3f F score. Saving to %s', float(F), save_file) | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
best_train_F = F | |||
best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt) | |||
if non_descent_cnt >= 3: | |||
logger.error("[Error] val loss does not descent for three times. Stopping supervisor...") | |||
save_file = os.path.join(train_dir, "earlystop") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
return | |||
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||
logger.info("[INFO] Starting eval for this model ...") | |||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||
model.eval() | |||
running_loss = 0.0 | |||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||
pairs = {} | |||
pairs["hyps"] = [] | |||
pairs["refer"] = [] | |||
total_example_num = 0 | |||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||
iter_start_time = time.time() | |||
with torch.no_grad(): | |||
for i, (batch_x, batch_y) in enumerate(loader): | |||
# if i > 10: | |||
# break | |||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||
label = batch_y[Const.TARGET] | |||
if hps.cuda: | |||
input = input.cuda() # [batch, N, seq_len] | |||
label = label.cuda() | |||
input_len = input_len.cuda() | |||
batch_size, N, _ = input.size() | |||
input = Variable(input, requires_grad=False) | |||
label = Variable(label) | |||
input_len = Variable(input_len, requires_grad=False) | |||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||
outputs = model_outputs["p_sent"] | |||
prediction = model_outputs["prediction"] | |||
outputs = outputs.view(-1, 2) # [batch * N, 2] | |||
label = label.view(-1) # [batch * N] | |||
loss = criterion(outputs, label) | |||
loss = loss.view(batch_size, -1) | |||
loss = loss.masked_fill(input_len.eq(0), 0) | |||
loss = loss.sum(1).mean() | |||
logger.debug("loss %f", loss) | |||
running_loss += float(loss.data) | |||
label = label.data.view(batch_size, -1) | |||
pred += prediction.sum() | |||
true += label.sum() | |||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||
match += (prediction == label).sum() | |||
total_example_num += batch_size * N | |||
# rouge | |||
prediction = prediction.view(batch_size, -1) | |||
for j in range(batch_size): | |||
original_article_sents = batch_x["text"][j] | |||
sent_max_number = len(original_article_sents) | |||
refer = "\n".join(batch_x["summary"][j]) | |||
hyps = "\n".join(original_article_sents[id] for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||
if sent_max_number < hps.m and len(hyps) <= 1: | |||
logger.error("sent_max_number is too short %d, Skip!" , sent_max_number) | |||
continue | |||
if len(hyps) >= 1 and hyps != '.': | |||
# logger.debug(prediction[j]) | |||
pairs["hyps"].append(hyps) | |||
pairs["refer"].append(refer) | |||
elif refer == "." or refer == "": | |||
logger.error("Refer is None!") | |||
logger.debug("label:") | |||
logger.debug(label[j]) | |||
logger.debug(refer) | |||
elif hyps == "." or hyps == "": | |||
logger.error("hyps is None!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug("prediction:") | |||
logger.debug(prediction[j]) | |||
logger.debug(hyps) | |||
else: | |||
logger.error("Do not select any sentences!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug(original_article_sents) | |||
logger.debug("label:") | |||
logger.debug(label[j]) | |||
continue | |||
running_avg_loss = running_loss / len(loader) | |||
if hps.use_pyrouge: | |||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||
logging.getLogger('global').setLevel(logging.WARNING) | |||
if not len(pairs["hyps"]): | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
if isinstance(pairs["refer"][0], list): | |||
logger.info("Multi Reference summaries!") | |||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||
else: | |||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||
else: | |||
if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0 : | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
rouge = Rouge() | |||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||
# try: | |||
# scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||
# except ValueError as e: | |||
# logger.error(repr(e)) | |||
# scores_all = [] | |||
# for idx in range(len(pairs["hyps"])): | |||
# try: | |||
# scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0] | |||
# scores_all.append(scores) | |||
# except ValueError as e: | |||
# logger.error(repr(e)) | |||
# logger.debug("HYPS:\t%s", pairs["hyps"][idx]) | |||
# logger.debug("REFER:\t%s", pairs["refer"][idx]) | |||
# finally: | |||
# logger.error("During testing, some errors happen!") | |||
# logger.error(len(scores_all)) | |||
# exit(1) | |||
logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | ' | |||
.format((time.time() - iter_start_time), | |||
float(running_avg_loss))) | |||
logger.info("[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", | |||
total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||
logger.info(res) | |||
# If running_avg_loss is best so far, save this checkpoint (early stopping). | |||
# These checkpoints will appear as bestmodel-<iteration_number> in the eval dir | |||
if best_loss is None or running_avg_loss < best_loss: | |||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||
if best_loss is not None: | |||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s', float(running_avg_loss), float(best_loss), bestmodel_save_path) | |||
else: | |||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s', float(running_avg_loss), bestmodel_save_path) | |||
saver = ModelSaver(bestmodel_save_path) | |||
saver.save_pytorch(model) | |||
best_loss = running_avg_loss | |||
non_descent_cnt = 0 | |||
else: | |||
non_descent_cnt += 1 | |||
if best_F is None or best_F < F: | |||
bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel.pkl') # this is where checkpoints of best models are saved | |||
if best_F is not None: | |||
logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F), float(best_F), bestmodel_save_path) | |||
else: | |||
logger.info('[INFO] Found new best model with %.6f F. The original loss is None, Saving to %s', float(F), bestmodel_save_path) | |||
saver = ModelSaver(bestmodel_save_path) | |||
saver.save_pytorch(model) | |||
best_F = F | |||
return best_loss, best_F, non_descent_cnt | |||
def run_test(model, loader, hps, limited=False): | |||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||
eval_dir = os.path.join(hps.save_root, "eval") | |||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||
if not os.path.exists(eval_dir) : | |||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||
if hps.test_model == "evalbestmodel": | |||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||
elif hps.test_model == "evalbestFmodel": | |||
bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl') | |||
elif hps.test_model == "trainbestmodel": | |||
train_dir = os.path.join(hps.save_root, "train") | |||
bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl') | |||
elif hps.test_model == "trainbestFmodel": | |||
train_dir = os.path.join(hps.save_root, "train") | |||
bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl') | |||
elif hps.test_model == "earlystop": | |||
train_dir = os.path.join(hps.save_root, "train") | |||
bestmodel_load_path = os.path.join(train_dir, 'earlystop,pkl') | |||
else: | |||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||
modelloader = ModelLoader() | |||
modelloader.load_pytorch(model, bestmodel_load_path) | |||
import datetime | |||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')#现在 | |||
if hps.save_label: | |||
log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1]) | |||
resfile = open(log_dir, "w") | |||
else: | |||
log_dir = os.path.join(test_dir, nowTime) | |||
resfile = open(log_dir, "wb") | |||
logger.info("[INFO] Write the Evaluation into %s", log_dir) | |||
model.eval() | |||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||
total_example_num = 0.0 | |||
pairs = {} | |||
pairs["hyps"] = [] | |||
pairs["refer"] = [] | |||
pred_list = [] | |||
iter_start_time=time.time() | |||
with torch.no_grad(): | |||
for i, (batch_x, batch_y) in enumerate(loader): | |||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||
label = batch_y[Const.TARGET] | |||
if hps.cuda: | |||
input = input.cuda() # [batch, N, seq_len] | |||
label = label.cuda() | |||
input_len = input_len.cuda() | |||
batch_size, N, _ = input.size() | |||
input = Variable(input) | |||
input_len = Variable(input_len, requires_grad=False) | |||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
prediction = model_outputs["prediction"] | |||
if hps.save_label: | |||
pred_list.extend(model_outputs["pred_idx"].data.cpu().view(-1).tolist()) | |||
continue | |||
pred += prediction.sum() | |||
true += label.sum() | |||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||
match += (prediction == label).sum() | |||
total_example_num += batch_size * N | |||
for j in range(batch_size): | |||
original_article_sents = batch_x["text"][j] | |||
sent_max_number = len(original_article_sents) | |||
refer = "\n".join(batch_x["summary"][j]) | |||
hyps = "\n".join(original_article_sents[id].replace("\n", "") for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||
if limited: | |||
k = len(refer.split()) | |||
hyps = " ".join(hyps.split()[:k]) | |||
logger.info((len(refer.split()),len(hyps.split()))) | |||
resfile.write(b"Original_article:") | |||
resfile.write("\n".join(batch_x["text"][j]).encode('utf-8')) | |||
resfile.write(b"\n") | |||
resfile.write(b"Reference:") | |||
if isinstance(refer, list): | |||
for ref in refer: | |||
resfile.write(ref.encode('utf-8')) | |||
resfile.write(b"\n") | |||
resfile.write(b'*' * 40) | |||
resfile.write(b"\n") | |||
else: | |||
resfile.write(refer.encode('utf-8')) | |||
resfile.write(b"\n") | |||
resfile.write(b"hypothesis:") | |||
resfile.write(hyps.encode('utf-8')) | |||
resfile.write(b"\n") | |||
if hps.use_pyrouge: | |||
pairs["hyps"].append(hyps) | |||
pairs["refer"].append(refer) | |||
else: | |||
try: | |||
scores = utils.rouge_all(hyps, refer) | |||
pairs["hyps"].append(hyps) | |||
pairs["refer"].append(refer) | |||
except ValueError: | |||
logger.error("Do not select any sentences!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug(original_article_sents) | |||
logger.debug("label:") | |||
logger.debug(label[j]) | |||
continue | |||
# single example res writer | |||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \ | |||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \ | |||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']) | |||
resfile.write(res.encode('utf-8')) | |||
resfile.write(b'-' * 89) | |||
resfile.write(b"\n") | |||
if hps.save_label: | |||
import json | |||
json.dump(pred_list, resfile) | |||
logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time))) | |||
return | |||
resfile.write(b"\n") | |||
resfile.write(b'=' * 89) | |||
resfile.write(b"\n") | |||
if hps.use_pyrouge: | |||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||
if not len(pairs["hyps"]): | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
if isinstance(pairs["refer"][0], list): | |||
logger.info("Multi Reference summaries!") | |||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||
else: | |||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||
else: | |||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||
if not len(pairs["hyps"]): | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
rouge = Rouge() | |||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||
# the whole model res writer | |||
resfile.write(b"The total testset is:") | |||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||
resfile.write(res.encode("utf-8")) | |||
logger.info(res) | |||
logger.info(' | end of test | time: {:5.2f}s | ' | |||
.format((time.time() - iter_start_time))) | |||
# label prediction | |||
logger.info("match_true %d, pred %d, true %d, total %d, match %d", match, pred, true, total_example_num, match) | |||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||
res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||
resfile.write(res.encode('utf-8')) | |||
logger.info("The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", len(loader), accu, precision, recall, F) | |||
def main(): | |||
parser = argparse.ArgumentParser(description='Transformer Model') | |||
# Where to find data | |||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||
# Important settings | |||
parser.add_argument('--mode', type=str, default='train', help='must be one of train/test') | |||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||
# Where to save output | |||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||
# Hyperparameters | |||
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]') | |||
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]') | |||
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]') | |||
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||
parser.add_argument('--n_layers', type=int, default=12, help='Number of deeplstm layers') | |||
parser.add_argument('--hidden_size', type=int, default=512, help='hidden size [default: 512]') | |||
parser.add_argument('--ffn_inner_hidden_size', type=int, default=2048, help='PositionwiseFeedForward inner hidden size [default: 2048]') | |||
parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]') | |||
parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]') | |||
parser.add_argument('--atten_dropout_prob', type=float, default=0.1,help='attention dropout prob [default: 0.1]') | |||
parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]') | |||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthnormal init for lstm [default: true]') | |||
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||
parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention') | |||
# Training | |||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization') | |||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||
args = parser.parse_args() | |||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||
torch.set_printoptions(threshold=50000) | |||
hps = args | |||
# File paths | |||
DATA_FILE = args.data_path | |||
VALID_FILE = args.valid_path | |||
VOCAL_FILE = args.vocab_path | |||
LOG_PATH = args.log_root | |||
# train_log setting | |||
if not os.path.exists(LOG_PATH): | |||
if hps.mode == "train": | |||
os.makedirs(LOG_PATH) | |||
else: | |||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||
log_path = os.path.join(LOG_PATH, hps.mode + "_" + nowTime) | |||
file_handler = logging.FileHandler(log_path) | |||
file_handler.setFormatter(formatter) | |||
logger.addHandler(file_handler) | |||
logger.info("Pytorch %s", torch.__version__) | |||
logger.info(args) | |||
logger.info(args) | |||
sum_loader = SummarizationLoader() | |||
if hps.mode == 'test': | |||
paths = {"test": DATA_FILE} | |||
hps.recurrent_dropout_prob = 0.0 | |||
hps.atten_dropout_prob = 0.0 | |||
hps.ffn_dropout_prob = 0.0 | |||
logger.info(hps) | |||
else: | |||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||
vocab = dataInfo.vocabs["vocab"] | |||
model = TransformerModel(hps, vocab) | |||
if len(hps.gpu) > 1: | |||
gpuid = hps.gpu.split(',') | |||
gpuid = [int(s) for s in gpuid] | |||
model = nn.DataParallel(model,device_ids=gpuid) | |||
logger.info("[INFO] Use Multi-gpu: %s", hps.gpu) | |||
if hps.cuda: | |||
model = model.cuda() | |||
logger.info("[INFO] Use cuda") | |||
if hps.mode == 'train': | |||
trainset = dataInfo.datasets["train"] | |||
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | |||
train_batch = DataSetIter(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler) | |||
validset = dataInfo.datasets["valid"] | |||
validset.set_input("text", "summary") | |||
valid_batch = DataSetIter(batch_size=hps.batch_size, dataset=validset) | |||
setup_training(model, train_batch, valid_batch, hps) | |||
elif hps.mode == 'test': | |||
logger.info("[INFO] Decoding...") | |||
testset = dataInfo.datasets["test"] | |||
testset.set_input("text", "summary") | |||
test_batch = DataSetIter(batch_size=hps.batch_size, dataset=testset) | |||
run_test(model, test_batch, hps, limited=hps.limited) | |||
else: | |||
logger.error("The 'mode' flag must be one of train/eval/test") | |||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,705 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Train Model1: baseline model""" | |||
import os | |||
import sys | |||
import time | |||
import copy | |||
import pickle | |||
import datetime | |||
import argparse | |||
import logging | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from torch.autograd import Variable | |||
from rouge import Rouge | |||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.const import Const | |||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||
from fastNLP.core.sampler import BucketSampler | |||
from tools import utils | |||
from tools.logger import * | |||
from data.dataloader import SummarizationLoader | |||
from model.TransformerModel import TransformerModel | |||
def setup_training(model, train_loader, valid_loader, hps): | |||
"""Does setup before starting training (run_training)""" | |||
train_dir = os.path.join(hps.save_root, "train") | |||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||
if hps.restore_model != 'None': | |||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||
loader = ModelLoader() | |||
loader.load_pytorch(model, bestmodel_file) | |||
else: | |||
logger.info("[INFO] Create new model for training...") | |||
try: | |||
run_training(model, train_loader, valid_loader, hps) # this is an infinite loop until interrupted | |||
except KeyboardInterrupt: | |||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
def run_training(model, train_loader, valid_loader, hps): | |||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||
logger.info("[INFO] Starting run_training") | |||
train_dir = os.path.join(hps.save_root, "train") | |||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||
lr = hps.lr | |||
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98), | |||
# eps=1e-09) | |||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||
best_train_loss = None | |||
best_train_F= None | |||
best_loss = None | |||
best_F = None | |||
step_num = 0 | |||
non_descent_cnt = 0 | |||
for epoch in range(1, hps.n_epochs + 1): | |||
epoch_loss = 0.0 | |||
train_loss = 0.0 | |||
total_example_num = 0 | |||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||
epoch_start_time = time.time() | |||
for i, (batch_x, batch_y) in enumerate(train_loader): | |||
# if i > 10: | |||
# break | |||
model.train() | |||
iter_start_time=time.time() | |||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||
label = batch_y[Const.TARGET] | |||
# logger.info(batch_x["text"][0]) | |||
# logger.info(input[0,:,:]) | |||
# logger.info(input_len[0:5,:]) | |||
# logger.info(batch_y["summary"][0:5]) | |||
# logger.info(label[0:5,:]) | |||
# logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0))) | |||
batch_size, N, seq_len = input.size() | |||
if hps.cuda: | |||
input = input.cuda() # [batch, N, seq_len] | |||
label = label.cuda() | |||
input_len = input_len.cuda() | |||
input = Variable(input) | |||
label = Variable(label) | |||
input_len = Variable(input_len) | |||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
outputs = model_outputs[Const.OUTPUT].view(-1, 2) | |||
label = label.view(-1) | |||
loss = criterion(outputs, label) # [batch_size, doc_max_timesteps] | |||
input_len = input_len.float().view(-1) | |||
loss = loss * input_len | |||
loss = loss.view(batch_size, -1) | |||
loss = loss.sum(1).mean() | |||
if not (np.isfinite(loss.data)).numpy(): | |||
logger.error("train Loss is not finite. Stopping.") | |||
logger.info(loss) | |||
for name, param in model.named_parameters(): | |||
if param.requires_grad: | |||
logger.info(name) | |||
logger.info(param.grad.data.sum()) | |||
raise Exception("train Loss is not finite. Stopping.") | |||
optimizer.zero_grad() | |||
loss.backward() | |||
if hps.grad_clip: | |||
torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm) | |||
optimizer.step() | |||
step_num += 1 | |||
train_loss += float(loss.data) | |||
epoch_loss += float(loss.data) | |||
if i % 100 == 0: | |||
# start debugger | |||
# import pdb; pdb.set_trace() | |||
for name, param in model.named_parameters(): | |||
if param.requires_grad: | |||
logger.debug(name) | |||
logger.debug(param.grad.data.sum()) | |||
logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | ' | |||
.format(i, (time.time() - iter_start_time), | |||
float(train_loss / 100))) | |||
train_loss = 0.0 | |||
# calculate the precision, recall and F | |||
prediction = outputs.max(1)[1] | |||
prediction = prediction.data | |||
label = label.data | |||
pred += prediction.sum() | |||
true += label.sum() | |||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||
match += (prediction == label).sum() | |||
total_example_num += int(batch_size * N) | |||
if hps.lr_descent: | |||
# new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5), | |||
# step_num * pow(hps.warmup_steps, -1.5)) | |||
new_lr = max(5e-6, lr / (epoch + 1)) | |||
for param_group in list(optimizer.param_groups): | |||
param_group['lr'] = new_lr | |||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||
epoch_avg_loss = epoch_loss / len(train_loader) | |||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | ' | |||
.format(epoch, (time.time() - epoch_start_time), | |||
float(epoch_avg_loss))) | |||
logger.info("[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||
if not best_train_loss or epoch_avg_loss < best_train_loss: | |||
save_file = os.path.join(train_dir, "bestmodel.pkl") | |||
logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss), save_file) | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
best_train_loss = epoch_avg_loss | |||
elif epoch_avg_loss > best_train_loss: | |||
logger.error("[Error] training loss does not descent. Stopping supervisor...") | |||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
return | |||
if not best_train_F or F > best_train_F: | |||
save_file = os.path.join(train_dir, "bestFmodel.pkl") | |||
logger.info('[INFO] Found new best model with %.3f F score. Saving to %s', float(F), save_file) | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
best_train_F = F | |||
best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt) | |||
if non_descent_cnt >= 3: | |||
logger.error("[Error] val loss does not descent for three times. Stopping supervisor...") | |||
save_file = os.path.join(train_dir, "earlystop") | |||
saver = ModelSaver(save_file) | |||
saver.save_pytorch(model) | |||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||
return | |||
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||
logger.info("[INFO] Starting eval for this model ...") | |||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||
model.eval() | |||
running_loss = 0.0 | |||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||
pairs = {} | |||
pairs["hyps"] = [] | |||
pairs["refer"] = [] | |||
total_example_num = 0 | |||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||
iter_start_time = time.time() | |||
with torch.no_grad(): | |||
for i, (batch_x, batch_y) in enumerate(loader): | |||
# if i > 10: | |||
# break | |||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||
label = batch_y[Const.TARGET] | |||
if hps.cuda: | |||
input = input.cuda() # [batch, N, seq_len] | |||
label = label.cuda() | |||
input_len = input_len.cuda() | |||
batch_size, N, _ = input.size() | |||
input = Variable(input, requires_grad=False) | |||
label = Variable(label) | |||
input_len = Variable(input_len, requires_grad=False) | |||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||
outputs = model_outputs[Const.OUTPUTS] | |||
prediction = model_outputs["prediction"] | |||
outputs = outputs.view(-1, 2) # [batch * N, 2] | |||
label = label.view(-1) # [batch * N] | |||
loss = criterion(outputs, label) | |||
input_len = input_len.float().view(-1) | |||
loss = loss * input_len | |||
loss = loss.view(batch_size, -1) | |||
loss = loss.sum(1).mean() | |||
running_loss += float(loss.data) | |||
label = label.data | |||
pred += prediction.sum() | |||
true += label.sum() | |||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||
match += (prediction == label).sum() | |||
total_example_num += batch_size * N | |||
# rouge | |||
prediction = prediction.view(batch_size, -1) | |||
for j in range(batch_size): | |||
original_article_sents = batch_x["text"][j] | |||
sent_max_number = len(original_article_sents) | |||
refer = "\n".join(batch_x["summary"][j]) | |||
hyps = "\n".join(original_article_sents[id] for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||
if sent_max_number < hps.m and len(hyps) <= 1: | |||
logger.error("sent_max_number is too short %d, Skip!" , sent_max_number) | |||
continue | |||
if len(hyps) >= 1 and hyps != '.': | |||
# logger.debug(prediction[j]) | |||
pairs["hyps"].append(hyps) | |||
pairs["refer"].append(refer) | |||
elif refer == "." or refer == "": | |||
logger.error("Refer is None!") | |||
logger.debug("label:") | |||
logger.debug(label[j]) | |||
logger.debug(refer) | |||
elif hyps == "." or hyps == "": | |||
logger.error("hyps is None!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug("prediction:") | |||
logger.debug(prediction[j]) | |||
logger.debug(hyps) | |||
else: | |||
logger.error("Do not select any sentences!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug(original_article_sents) | |||
logger.debug("label:") | |||
logger.debug(label[j]) | |||
continue | |||
running_avg_loss = running_loss / len(loader) | |||
if hps.use_pyrouge: | |||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||
logging.getLogger('global').setLevel(logging.WARNING) | |||
if not len(pairs["hyps"]): | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
if isinstance(pairs["refer"][0], list): | |||
logger.info("Multi Reference summaries!") | |||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||
else: | |||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||
else: | |||
if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0 : | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
rouge = Rouge() | |||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||
# try: | |||
# scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||
# except ValueError as e: | |||
# logger.error(repr(e)) | |||
# scores_all = [] | |||
# for idx in range(len(pairs["hyps"])): | |||
# try: | |||
# scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0] | |||
# scores_all.append(scores) | |||
# except ValueError as e: | |||
# logger.error(repr(e)) | |||
# logger.debug("HYPS:\t%s", pairs["hyps"][idx]) | |||
# logger.debug("REFER:\t%s", pairs["refer"][idx]) | |||
# finally: | |||
# logger.error("During testing, some errors happen!") | |||
# logger.error(len(scores_all)) | |||
# exit(1) | |||
logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | ' | |||
.format((time.time() - iter_start_time), | |||
float(running_avg_loss))) | |||
logger.info("[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", | |||
total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||
logger.info(res) | |||
# If running_avg_loss is best so far, save this checkpoint (early stopping). | |||
# These checkpoints will appear as bestmodel-<iteration_number> in the eval dir | |||
if best_loss is None or running_avg_loss < best_loss: | |||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||
if best_loss is not None: | |||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s', float(running_avg_loss), float(best_loss), bestmodel_save_path) | |||
else: | |||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s', float(running_avg_loss), bestmodel_save_path) | |||
saver = ModelSaver(bestmodel_save_path) | |||
saver.save_pytorch(model) | |||
best_loss = running_avg_loss | |||
non_descent_cnt = 0 | |||
else: | |||
non_descent_cnt += 1 | |||
if best_F is None or best_F < F: | |||
bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel.pkl') # this is where checkpoints of best models are saved | |||
if best_F is not None: | |||
logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F), float(best_F), bestmodel_save_path) | |||
else: | |||
logger.info('[INFO] Found new best model with %.6f F. The original loss is None, Saving to %s', float(F), bestmodel_save_path) | |||
saver = ModelSaver(bestmodel_save_path) | |||
saver.save_pytorch(model) | |||
best_F = F | |||
return best_loss, best_F, non_descent_cnt | |||
def run_test(model, loader, hps, limited=False): | |||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||
eval_dir = os.path.join(hps.save_root, "eval") | |||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||
if not os.path.exists(eval_dir) : | |||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||
if hps.test_model == "evalbestmodel": | |||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||
elif hps.test_model == "evalbestFmodel": | |||
bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl') | |||
elif hps.test_model == "trainbestmodel": | |||
train_dir = os.path.join(hps.save_root, "train") | |||
bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl') | |||
elif hps.test_model == "trainbestFmodel": | |||
train_dir = os.path.join(hps.save_root, "train") | |||
bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl') | |||
elif hps.test_model == "earlystop": | |||
train_dir = os.path.join(hps.save_root, "train") | |||
bestmodel_load_path = os.path.join(train_dir, 'earlystop,pkl') | |||
else: | |||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||
modelloader = ModelLoader() | |||
modelloader.load_pytorch(model, bestmodel_load_path) | |||
import datetime | |||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')#现在 | |||
if hps.save_label: | |||
log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1]) | |||
resfile = open(log_dir, "w") | |||
else: | |||
log_dir = os.path.join(test_dir, nowTime) | |||
resfile = open(log_dir, "wb") | |||
logger.info("[INFO] Write the Evaluation into %s", log_dir) | |||
model.eval() | |||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||
total_example_num = 0.0 | |||
pairs = {} | |||
pairs["hyps"] = [] | |||
pairs["refer"] = [] | |||
pred_list = [] | |||
iter_start_time=time.time() | |||
with torch.no_grad(): | |||
for i, (batch_x, batch_y) in enumerate(loader): | |||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||
label = batch_y[Const.TARGET] | |||
if hps.cuda: | |||
input = input.cuda() # [batch, N, seq_len] | |||
label = label.cuda() | |||
input_len = input_len.cuda() | |||
batch_size, N, _ = input.size() | |||
input = Variable(input) | |||
input_len = Variable(input_len, requires_grad=False) | |||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
prediction = model_outputs["pred"] | |||
if hps.save_label: | |||
pred_list.extend(model_outputs["pred_idx"].data.cpu().view(-1).tolist()) | |||
continue | |||
pred += prediction.sum() | |||
true += label.sum() | |||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||
match += (prediction == label).sum() | |||
total_example_num += batch_size * N | |||
for j in range(batch_size): | |||
original_article_sents = batch_x["text"][j] | |||
sent_max_number = len(original_article_sents) | |||
refer = "\n".join(batch_x["summary"][j]) | |||
hyps = "\n".join(original_article_sents[id].replace("\n", "") for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||
if limited: | |||
k = len(refer.split()) | |||
hyps = " ".join(hyps.split()[:k]) | |||
logger.info((len(refer.split()),len(hyps.split()))) | |||
resfile.write(b"Original_article:") | |||
resfile.write("\n".join(batch_x["text"][j]).encode('utf-8')) | |||
resfile.write(b"\n") | |||
resfile.write(b"Reference:") | |||
if isinstance(refer, list): | |||
for ref in refer: | |||
resfile.write(ref.encode('utf-8')) | |||
resfile.write(b"\n") | |||
resfile.write(b'*' * 40) | |||
resfile.write(b"\n") | |||
else: | |||
resfile.write(refer.encode('utf-8')) | |||
resfile.write(b"\n") | |||
resfile.write(b"hypothesis:") | |||
resfile.write(hyps.encode('utf-8')) | |||
resfile.write(b"\n") | |||
if hps.use_pyrouge: | |||
pairs["hyps"].append(hyps) | |||
pairs["refer"].append(refer) | |||
else: | |||
try: | |||
scores = utils.rouge_all(hyps, refer) | |||
pairs["hyps"].append(hyps) | |||
pairs["refer"].append(refer) | |||
except ValueError: | |||
logger.error("Do not select any sentences!") | |||
logger.debug("sent_max_number:%d", sent_max_number) | |||
logger.debug(original_article_sents) | |||
logger.debug("label:") | |||
logger.debug(label[j]) | |||
continue | |||
# single example res writer | |||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \ | |||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \ | |||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']) | |||
resfile.write(res.encode('utf-8')) | |||
resfile.write(b'-' * 89) | |||
resfile.write(b"\n") | |||
if hps.save_label: | |||
import json | |||
json.dump(pred_list, resfile) | |||
logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time))) | |||
return | |||
resfile.write(b"\n") | |||
resfile.write(b'=' * 89) | |||
resfile.write(b"\n") | |||
if hps.use_pyrouge: | |||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||
if not len(pairs["hyps"]): | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
if isinstance(pairs["refer"][0], list): | |||
logger.info("Multi Reference summaries!") | |||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||
else: | |||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||
else: | |||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||
if not len(pairs["hyps"]): | |||
logger.error("During testing, no hyps is selected!") | |||
return | |||
rouge = Rouge() | |||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||
# the whole model res writer | |||
resfile.write(b"The total testset is:") | |||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||
resfile.write(res.encode("utf-8")) | |||
logger.info(res) | |||
logger.info(' | end of test | time: {:5.2f}s | ' | |||
.format((time.time() - iter_start_time))) | |||
# label prediction | |||
logger.info("match_true %d, pred %d, true %d, total %d, match %d", match, pred, true, total_example_num, match) | |||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||
res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||
resfile.write(res.encode('utf-8')) | |||
logger.info("The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", len(loader), accu, precision, recall, F) | |||
def main(): | |||
parser = argparse.ArgumentParser(description='Transformer Model') | |||
# Where to find data | |||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||
# Important settings | |||
parser.add_argument('--mode', type=str, default='train', help='must be one of train/test') | |||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||
# Where to save output | |||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||
# Hyperparameters | |||
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]') | |||
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]') | |||
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]') | |||
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||
parser.add_argument('--n_layers', type=int, default=12, help='Number of deeplstm layers') | |||
parser.add_argument('--hidden_size', type=int, default=512, help='hidden size [default: 512]') | |||
parser.add_argument('--ffn_inner_hidden_size', type=int, default=2048, help='PositionwiseFeedForward inner hidden size [default: 2048]') | |||
parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]') | |||
parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]') | |||
parser.add_argument('--atten_dropout_prob', type=float, default=0.1,help='attention dropout prob [default: 0.1]') | |||
parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]') | |||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthnormal init for lstm [default: true]') | |||
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||
parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention') | |||
# Training | |||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization') | |||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||
args = parser.parse_args() | |||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||
torch.set_printoptions(threshold=50000) | |||
hps = args | |||
# File paths | |||
DATA_FILE = args.data_path | |||
VALID_FILE = args.valid_path | |||
VOCAL_FILE = args.vocab_path | |||
LOG_PATH = args.log_root | |||
# train_log setting | |||
if not os.path.exists(LOG_PATH): | |||
if hps.mode == "train": | |||
os.makedirs(LOG_PATH) | |||
else: | |||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||
log_path = os.path.join(LOG_PATH, hps.mode + "_" + nowTime) | |||
file_handler = logging.FileHandler(log_path) | |||
file_handler.setFormatter(formatter) | |||
logger.addHandler(file_handler) | |||
logger.info("Pytorch %s", torch.__version__) | |||
logger.info(args) | |||
logger.info(args) | |||
sum_loader = SummarizationLoader() | |||
if hps.mode == 'test': | |||
paths = {"test": DATA_FILE} | |||
hps.recurrent_dropout_prob = 0.0 | |||
hps.atten_dropout_prob = 0.0 | |||
hps.ffn_dropout_prob = 0.0 | |||
logger.info(hps) | |||
else: | |||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||
vocab = dataInfo.vocabs["vocab"] | |||
model = TransformerModel(hps, vocab) | |||
if len(hps.gpu) > 1: | |||
gpuid = hps.gpu.split(',') | |||
gpuid = [int(s) for s in gpuid] | |||
model = nn.DataParallel(model,device_ids=gpuid) | |||
logger.info("[INFO] Use Multi-gpu: %s", hps.gpu) | |||
if hps.cuda: | |||
model = model.cuda() | |||
logger.info("[INFO] Use cuda") | |||
if hps.mode == 'train': | |||
trainset = dataInfo.datasets["train"] | |||
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | |||
train_batch = Batch(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler) | |||
validset = dataInfo.datasets["valid"] | |||
validset.set_input("text", "summary") | |||
valid_batch = Batch(batch_size=hps.batch_size, dataset=validset) | |||
setup_training(model, train_batch, valid_batch, hps) | |||
elif hps.mode == 'test': | |||
logger.info("[INFO] Decoding...") | |||
testset = dataInfo.datasets["test"] | |||
testset.set_input("text", "summary") | |||
test_batch = Batch(batch_size=hps.batch_size, dataset=testset) | |||
run_test(model, test_batch, hps, limited=hps.limited) | |||
else: | |||
logger.error("The 'mode' flag must be one of train/eval/test") | |||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,103 @@ | |||
""" Manage beam search info structure. | |||
Heavily borrowed from OpenNMT-py. | |||
For code in OpenNMT-py, please check the following link: | |||
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py | |||
""" | |||
import torch | |||
import numpy as np | |||
import transformer.Constants as Constants | |||
class Beam(): | |||
''' Beam search ''' | |||
def __init__(self, size, device=False): | |||
self.size = size | |||
self._done = False | |||
# The score for each translation on the beam. | |||
self.scores = torch.zeros((size,), dtype=torch.float, device=device) | |||
self.all_scores = [] | |||
# The backpointers at each time-step. | |||
self.prev_ks = [] | |||
# The outputs at each time-step. | |||
self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] | |||
self.next_ys[0][0] = Constants.BOS | |||
def get_current_state(self): | |||
"Get the outputs for the current timestep." | |||
return self.get_tentative_hypothesis() | |||
def get_current_origin(self): | |||
"Get the backpointers for the current timestep." | |||
return self.prev_ks[-1] | |||
@property | |||
def done(self): | |||
return self._done | |||
def advance(self, word_prob): | |||
"Update beam status and check if finished or not." | |||
num_words = word_prob.size(1) | |||
# Sum the previous scores. | |||
if len(self.prev_ks) > 0: | |||
beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) | |||
else: | |||
beam_lk = word_prob[0] | |||
flat_beam_lk = beam_lk.view(-1) | |||
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort | |||
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort | |||
self.all_scores.append(self.scores) | |||
self.scores = best_scores | |||
# bestScoresId is flattened as a (beam x word) array, | |||
# so we need to calculate which word and beam each score came from | |||
prev_k = best_scores_id / num_words | |||
self.prev_ks.append(prev_k) | |||
self.next_ys.append(best_scores_id - prev_k * num_words) | |||
# End condition is when top-of-beam is EOS. | |||
if self.next_ys[-1][0].item() == Constants.EOS: | |||
self._done = True | |||
self.all_scores.append(self.scores) | |||
return self._done | |||
def sort_scores(self): | |||
"Sort the scores." | |||
return torch.sort(self.scores, 0, True) | |||
def get_the_best_score_and_idx(self): | |||
"Get the score of the best in the beam." | |||
scores, ids = self.sort_scores() | |||
return scores[1], ids[1] | |||
def get_tentative_hypothesis(self): | |||
"Get the decoded sequence for the current timestep." | |||
if len(self.next_ys) == 1: | |||
dec_seq = self.next_ys[0].unsqueeze(1) | |||
else: | |||
_, keys = self.sort_scores() | |||
hyps = [self.get_hypothesis(k) for k in keys] | |||
hyps = [[Constants.BOS] + h for h in hyps] | |||
dec_seq = torch.LongTensor(hyps) | |||
return dec_seq | |||
def get_hypothesis(self, k): | |||
""" Walk back to construct the full hypothesis. """ | |||
hyp = [] | |||
for j in range(len(self.prev_ks) - 1, -1, -1): | |||
hyp.append(self.next_ys[j+1][k]) | |||
k = self.prev_ks[j][k] | |||
return list(map(lambda x: x.item(), hyp[::-1])) |
@@ -0,0 +1,10 @@ | |||
PAD = 0 | |||
UNK = 1 | |||
BOS = 2 | |||
EOS = 3 | |||
PAD_WORD = '<blank>' | |||
UNK_WORD = '<unk>' | |||
BOS_WORD = '<s>' | |||
EOS_WORD = '</s>' |
@@ -0,0 +1,49 @@ | |||
''' Define the Layers ''' | |||
import torch.nn as nn | |||
from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward | |||
__author__ = "Yu-Hsiang Huang" | |||
class EncoderLayer(nn.Module): | |||
''' Compose with two layers ''' | |||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): | |||
super(EncoderLayer, self).__init__() | |||
self.slf_attn = MultiHeadAttention( | |||
n_head, d_model, d_k, d_v, dropout=dropout) | |||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) | |||
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): | |||
enc_output, enc_slf_attn = self.slf_attn( | |||
enc_input, enc_input, enc_input, mask=slf_attn_mask) | |||
enc_output *= non_pad_mask | |||
enc_output = self.pos_ffn(enc_output) | |||
enc_output *= non_pad_mask | |||
return enc_output, enc_slf_attn | |||
class DecoderLayer(nn.Module): | |||
''' Compose with three layers ''' | |||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): | |||
super(DecoderLayer, self).__init__() | |||
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) | |||
self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) | |||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) | |||
def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): | |||
dec_output, dec_slf_attn = self.slf_attn( | |||
dec_input, dec_input, dec_input, mask=slf_attn_mask) | |||
dec_output *= non_pad_mask | |||
dec_output, dec_enc_attn = self.enc_attn( | |||
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) | |||
dec_output *= non_pad_mask | |||
dec_output = self.pos_ffn(dec_output) | |||
dec_output *= non_pad_mask | |||
return dec_output, dec_slf_attn, dec_enc_attn |
@@ -0,0 +1,208 @@ | |||
''' Define the Transformer model ''' | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
import transformer.Constants as Constants | |||
from transformer.Layers import EncoderLayer, DecoderLayer | |||
__author__ = "Yu-Hsiang Huang" | |||
def get_non_pad_mask(seq): | |||
assert seq.dim() == 2 | |||
return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) | |||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
''' Sinusoid position encoding table ''' | |||
def cal_angle(position, hid_idx): | |||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||
def get_posi_angle_vec(position): | |||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||
if padding_idx is not None: | |||
# zero vector for padding dimension | |||
sinusoid_table[padding_idx] = 0. | |||
return torch.FloatTensor(sinusoid_table) | |||
def get_attn_key_pad_mask(seq_k, seq_q): | |||
''' For masking out the padding part of key sequence. ''' | |||
# Expand to fit the shape of key query attention matrix. | |||
len_q = seq_q.size(1) | |||
padding_mask = seq_k.eq(Constants.PAD) | |||
padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk | |||
return padding_mask | |||
def get_subsequent_mask(seq): | |||
''' For masking out the subsequent info. ''' | |||
sz_b, len_s = seq.size() | |||
subsequent_mask = torch.triu( | |||
torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) | |||
subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls | |||
return subsequent_mask | |||
class Encoder(nn.Module): | |||
''' A encoder model with self attention mechanism. ''' | |||
def __init__( | |||
self, | |||
n_src_vocab, len_max_seq, d_word_vec, | |||
n_layers, n_head, d_k, d_v, | |||
d_model, d_inner, dropout=0.1): | |||
super().__init__() | |||
n_position = len_max_seq + 1 | |||
self.src_word_emb = nn.Embedding( | |||
n_src_vocab, d_word_vec, padding_idx=Constants.PAD) | |||
self.position_enc = nn.Embedding.from_pretrained( | |||
get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), | |||
freeze=True) | |||
self.layer_stack = nn.ModuleList([ | |||
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) | |||
for _ in range(n_layers)]) | |||
def forward(self, src_seq, src_pos, return_attns=False): | |||
enc_slf_attn_list = [] | |||
# -- Prepare masks | |||
slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) | |||
non_pad_mask = get_non_pad_mask(src_seq) | |||
# -- Forward | |||
enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) | |||
for enc_layer in self.layer_stack: | |||
enc_output, enc_slf_attn = enc_layer( | |||
enc_output, | |||
non_pad_mask=non_pad_mask, | |||
slf_attn_mask=slf_attn_mask) | |||
if return_attns: | |||
enc_slf_attn_list += [enc_slf_attn] | |||
if return_attns: | |||
return enc_output, enc_slf_attn_list | |||
return enc_output, | |||
class Decoder(nn.Module): | |||
''' A decoder model with self attention mechanism. ''' | |||
def __init__( | |||
self, | |||
n_tgt_vocab, len_max_seq, d_word_vec, | |||
n_layers, n_head, d_k, d_v, | |||
d_model, d_inner, dropout=0.1): | |||
super().__init__() | |||
n_position = len_max_seq + 1 | |||
self.tgt_word_emb = nn.Embedding( | |||
n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) | |||
self.position_enc = nn.Embedding.from_pretrained( | |||
get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), | |||
freeze=True) | |||
self.layer_stack = nn.ModuleList([ | |||
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) | |||
for _ in range(n_layers)]) | |||
def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): | |||
dec_slf_attn_list, dec_enc_attn_list = [], [] | |||
# -- Prepare masks | |||
non_pad_mask = get_non_pad_mask(tgt_seq) | |||
slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) | |||
slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) | |||
slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) | |||
dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) | |||
# -- Forward | |||
dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) | |||
for dec_layer in self.layer_stack: | |||
dec_output, dec_slf_attn, dec_enc_attn = dec_layer( | |||
dec_output, enc_output, | |||
non_pad_mask=non_pad_mask, | |||
slf_attn_mask=slf_attn_mask, | |||
dec_enc_attn_mask=dec_enc_attn_mask) | |||
if return_attns: | |||
dec_slf_attn_list += [dec_slf_attn] | |||
dec_enc_attn_list += [dec_enc_attn] | |||
if return_attns: | |||
return dec_output, dec_slf_attn_list, dec_enc_attn_list | |||
return dec_output, | |||
class Transformer(nn.Module): | |||
''' A sequence to sequence model with attention mechanism. ''' | |||
def __init__( | |||
self, | |||
n_src_vocab, n_tgt_vocab, len_max_seq, | |||
d_word_vec=512, d_model=512, d_inner=2048, | |||
n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, | |||
tgt_emb_prj_weight_sharing=True, | |||
emb_src_tgt_weight_sharing=True): | |||
super().__init__() | |||
self.encoder = Encoder( | |||
n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, | |||
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, | |||
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, | |||
dropout=dropout) | |||
self.decoder = Decoder( | |||
n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, | |||
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, | |||
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, | |||
dropout=dropout) | |||
self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) | |||
nn.init.xavier_normal_(self.tgt_word_prj.weight) | |||
assert d_model == d_word_vec, \ | |||
'To facilitate the residual connections, \ | |||
the dimensions of all module outputs shall be the same.' | |||
if tgt_emb_prj_weight_sharing: | |||
# Share the weight matrix between target word embedding & the final logit dense layer | |||
self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight | |||
self.x_logit_scale = (d_model ** -0.5) | |||
else: | |||
self.x_logit_scale = 1. | |||
if emb_src_tgt_weight_sharing: | |||
# Share the weight matrix between source & target word embeddings | |||
assert n_src_vocab == n_tgt_vocab, \ | |||
"To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||
self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight | |||
def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): | |||
tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] | |||
enc_output, *_ = self.encoder(src_seq, src_pos) | |||
dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) | |||
seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale | |||
return seq_logit.view(-1, seq_logit.size(2)) |
@@ -0,0 +1,28 @@ | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
__author__ = "Yu-Hsiang Huang" | |||
class ScaledDotProductAttention(nn.Module): | |||
''' Scaled Dot-Product Attention ''' | |||
def __init__(self, temperature, attn_dropout=0.1): | |||
super().__init__() | |||
self.temperature = temperature | |||
self.dropout = nn.Dropout(attn_dropout) | |||
self.softmax = nn.Softmax(dim=2) | |||
def forward(self, q, k, v, mask=None): | |||
attn = torch.bmm(q, k.transpose(1, 2)) | |||
attn = attn / self.temperature | |||
if mask is not None: | |||
attn = attn.masked_fill(mask, -np.inf) | |||
attn = self.softmax(attn) | |||
attn = self.dropout(attn) | |||
output = torch.bmm(attn, v) | |||
return output, attn |
@@ -0,0 +1,35 @@ | |||
'''A wrapper class for optimizer ''' | |||
import numpy as np | |||
class ScheduledOptim(): | |||
'''A simple wrapper class for learning rate scheduling''' | |||
def __init__(self, optimizer, d_model, n_warmup_steps): | |||
self._optimizer = optimizer | |||
self.n_warmup_steps = n_warmup_steps | |||
self.n_current_steps = 0 | |||
self.init_lr = np.power(d_model, -0.5) | |||
def step_and_update_lr(self): | |||
"Step with the inner optimizer" | |||
self._update_learning_rate() | |||
self._optimizer.step() | |||
def zero_grad(self): | |||
"Zero out the gradients by the inner optimizer" | |||
self._optimizer.zero_grad() | |||
def _get_lr_scale(self): | |||
return np.min([ | |||
np.power(self.n_current_steps, -0.5), | |||
np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) | |||
def _update_learning_rate(self): | |||
''' Learning rate scheduling per step ''' | |||
self.n_current_steps += 1 | |||
lr = self.init_lr * self._get_lr_scale() | |||
for param_group in self._optimizer.param_groups: | |||
param_group['lr'] = lr | |||
@@ -0,0 +1,82 @@ | |||
''' Define the sublayers in encoder/decoder layer ''' | |||
import numpy as np | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from transformer.Modules import ScaledDotProductAttention | |||
__author__ = "Yu-Hsiang Huang" | |||
class MultiHeadAttention(nn.Module): | |||
''' Multi-Head Attention module ''' | |||
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): | |||
super().__init__() | |||
self.n_head = n_head | |||
self.d_k = d_k | |||
self.d_v = d_v | |||
self.w_qs = nn.Linear(d_model, n_head * d_k) | |||
self.w_ks = nn.Linear(d_model, n_head * d_k) | |||
self.w_vs = nn.Linear(d_model, n_head * d_v) | |||
nn.init.xavier_normal_(self.w_qs.weight) | |||
nn.init.xavier_normal_(self.w_ks.weight) | |||
nn.init.xavier_normal_(self.w_vs.weight) | |||
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) | |||
self.layer_norm = nn.LayerNorm(d_model) | |||
self.fc = nn.Linear(n_head * d_v, d_model) | |||
nn.init.xavier_normal_(self.fc.weight) | |||
self.dropout = nn.Dropout(dropout) | |||
def forward(self, q, k, v, mask=None): | |||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head | |||
sz_b, len_q, _ = q.size() | |||
sz_b, len_k, _ = k.size() | |||
sz_b, len_v, _ = v.size() | |||
residual = q | |||
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) | |||
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) | |||
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) | |||
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk | |||
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk | |||
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv | |||
if mask is not None: | |||
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. | |||
output, attn = self.attention(q, k, v, mask=mask) | |||
output = output.view(n_head, sz_b, len_q, d_v) | |||
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) | |||
output = self.dropout(self.fc(output)) | |||
output = self.layer_norm(output + residual) | |||
return output, attn | |||
class PositionwiseFeedForward(nn.Module): | |||
''' A two-feed-forward-layer module ''' | |||
def __init__(self, d_in, d_hid, dropout=0.1): | |||
super().__init__() | |||
self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise | |||
self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise | |||
self.layer_norm = nn.LayerNorm(d_in) | |||
self.dropout = nn.Dropout(dropout) | |||
def forward(self, x): | |||
residual = x | |||
output = x.transpose(1, 2) | |||
output = self.w_2(F.relu(self.w_1(output))) | |||
output = output.transpose(1, 2) | |||
output = self.dropout(output) | |||
output = self.layer_norm(output + residual) | |||
return output |
@@ -0,0 +1,166 @@ | |||
''' This module will handle the text generation with beam search. ''' | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from transformer.Models import Transformer | |||
from transformer.Beam import Beam | |||
class Translator(object): | |||
''' Load with trained model and handle the beam search ''' | |||
def __init__(self, opt): | |||
self.opt = opt | |||
self.device = torch.device('cuda' if opt.cuda else 'cpu') | |||
checkpoint = torch.load(opt.model) | |||
model_opt = checkpoint['settings'] | |||
self.model_opt = model_opt | |||
model = Transformer( | |||
model_opt.src_vocab_size, | |||
model_opt.tgt_vocab_size, | |||
model_opt.max_token_seq_len, | |||
tgt_emb_prj_weight_sharing=model_opt.proj_share_weight, | |||
emb_src_tgt_weight_sharing=model_opt.embs_share_weight, | |||
d_k=model_opt.d_k, | |||
d_v=model_opt.d_v, | |||
d_model=model_opt.d_model, | |||
d_word_vec=model_opt.d_word_vec, | |||
d_inner=model_opt.d_inner_hid, | |||
n_layers=model_opt.n_layers, | |||
n_head=model_opt.n_head, | |||
dropout=model_opt.dropout) | |||
model.load_state_dict(checkpoint['model']) | |||
print('[Info] Trained model state loaded.') | |||
model.word_prob_prj = nn.LogSoftmax(dim=1) | |||
model = model.to(self.device) | |||
self.model = model | |||
self.model.eval() | |||
def translate_batch(self, src_seq, src_pos): | |||
''' Translation work in one batch ''' | |||
def get_inst_idx_to_tensor_position_map(inst_idx_list): | |||
''' Indicate the position of an instance in a tensor. ''' | |||
return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} | |||
def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): | |||
''' Collect tensor parts associated to active instances. ''' | |||
_, *d_hs = beamed_tensor.size() | |||
n_curr_active_inst = len(curr_active_inst_idx) | |||
new_shape = (n_curr_active_inst * n_bm, *d_hs) | |||
beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) | |||
beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) | |||
beamed_tensor = beamed_tensor.view(*new_shape) | |||
return beamed_tensor | |||
def collate_active_info( | |||
src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): | |||
# Sentences which are still active are collected, | |||
# so the decoder will not run on completed sentences. | |||
n_prev_active_inst = len(inst_idx_to_position_map) | |||
active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] | |||
active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) | |||
active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) | |||
active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) | |||
active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) | |||
return active_src_seq, active_src_enc, active_inst_idx_to_position_map | |||
def beam_decode_step( | |||
inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): | |||
''' Decode and update beam status, and then return active beam idx ''' | |||
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): | |||
dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] | |||
dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) | |||
dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) | |||
return dec_partial_seq | |||
def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): | |||
dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) | |||
dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) | |||
return dec_partial_pos | |||
def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm): | |||
dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) | |||
dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h | |||
word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) | |||
word_prob = word_prob.view(n_active_inst, n_bm, -1) | |||
return word_prob | |||
def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): | |||
active_inst_idx_list = [] | |||
for inst_idx, inst_position in inst_idx_to_position_map.items(): | |||
is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) | |||
if not is_inst_complete: | |||
active_inst_idx_list += [inst_idx] | |||
return active_inst_idx_list | |||
n_active_inst = len(inst_idx_to_position_map) | |||
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) | |||
dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) | |||
word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) | |||
# Update the beam with predicted word prob information and collect incomplete instances | |||
active_inst_idx_list = collect_active_inst_idx_list( | |||
inst_dec_beams, word_prob, inst_idx_to_position_map) | |||
return active_inst_idx_list | |||
def collect_hypothesis_and_scores(inst_dec_beams, n_best): | |||
all_hyp, all_scores = [], [] | |||
for inst_idx in range(len(inst_dec_beams)): | |||
scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() | |||
all_scores += [scores[:n_best]] | |||
hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] | |||
all_hyp += [hyps] | |||
return all_hyp, all_scores | |||
with torch.no_grad(): | |||
#-- Encode | |||
src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) | |||
src_enc, *_ = self.model.encoder(src_seq, src_pos) | |||
#-- Repeat data for beam search | |||
n_bm = self.opt.beam_size | |||
n_inst, len_s, d_h = src_enc.size() | |||
src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) | |||
src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) | |||
#-- Prepare beams | |||
inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] | |||
#-- Bookkeeping for active or not | |||
active_inst_idx_list = list(range(n_inst)) | |||
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) | |||
#-- Decode | |||
for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): | |||
active_inst_idx_list = beam_decode_step( | |||
inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) | |||
if not active_inst_idx_list: | |||
break # all instances have finished their path to <EOS> | |||
src_seq, src_enc, inst_idx_to_position_map = collate_active_info( | |||
src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) | |||
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) | |||
return batch_hyp, batch_scores |
@@ -0,0 +1,13 @@ | |||
import transformer.Constants | |||
import transformer.Modules | |||
import transformer.Layers | |||
import transformer.SubLayers | |||
import transformer.Models | |||
import transformer.Translator | |||
import transformer.Beam | |||
import transformer.Optim | |||
__all__ = [ | |||
transformer.Constants, transformer.Modules, transformer.Layers, | |||
transformer.SubLayers, transformer.Models, transformer.Optim, | |||
transformer.Translator, transformer.Beam] |