From 610791ad78e72dea467339fb2326b4da2755cb86 Mon Sep 17 00:00:00 2001
From: Danqing Wang
Date: Mon, 1 Jul 2019 11:39:02 +0800
Subject: [PATCH] update Readme.md

---
 reproduction/Summarization/README.md          |  12 +-
 reproduction/Summarization/data/dataloader.py |   4 +-
 reproduction/Summarization/model/DeepLSTM.py  | 136 ++++++++++++++++++
 reproduction/Summarization/model/LSTMModel.py | 103 +++++++++++++
 reproduction/Summarization/model/Metric.py    |   4 +-
 .../Summarization/model/TForiginal.py         |   5 +-
 6 files changed, 253 insertions(+), 11 deletions(-)
 create mode 100644 reproduction/Summarization/model/DeepLSTM.py
 create mode 100644 reproduction/Summarization/model/LSTMModel.py

diff --git a/reproduction/Summarization/README.md b/reproduction/Summarization/README.md
index 6431e736..a988fd45 100644
--- a/reproduction/Summarization/README.md
+++ b/reproduction/Summarization/README.md
@@ -8,8 +8,7 @@
 Models implemented in FastNLP include:
 
 1. Get To The Point: Summarization with Pointer-Generator Networks (See et al. 2017)
-2. Extractive Summarization with SWAP-NET: Sentences and Words from Alternating Pointer Networks (Jadhav et al. 2018)
-3. Searching for Effective Neural Extractive Summarization What Works and What's Next (Zhong et al. 2019)
+2. Searching for Effective Neural Extractive Summarization: What Works and What's Next (Zhong et al. 2019)
 
 
@@ -32,8 +31,8 @@
 
 Download links for the preprocessed public datasets (CNN/DailyMail, Newsroom, arXiv, PubMed):
 
-- [Baidu Cloud](https://pan.baidu.com)
-- [Google Drive](https://drive.google.com)
+- [Baidu Cloud](https://pan.baidu.com/s/11qWnDjK9lb33mFZ9vuYlzA) (extraction code: h1px)
+- [Google Drive](https://drive.google.com/file/d/1uzeSdcLk5ilHaUTeJRNrf-_j59CQGe6r/view?usp=drivesdk)
 
 Preprocessing scripts for the non-public datasets (NYT, NYT50, DUC) are placed in the data folder
@@ -53,5 +52,10 @@
 
 ### Performance and Hyperparameters
 
+| Model | ROUGE-1 | ROUGE-2 | ROUGE-L | Paper |
+| :---: | :---: | :---: | :---: | :---: |
+| See et al. (2017) | | | | |
+
+
 ## Abstractive Summarization
 Still in Progress...
\ No newline at end of file

diff --git a/reproduction/Summarization/data/dataloader.py b/reproduction/Summarization/data/dataloader.py
index 55688c4d..fe787c31 100644
--- a/reproduction/Summarization/data/dataloader.py
+++ b/reproduction/Summarization/data/dataloader.py
@@ -158,11 +158,11 @@ class SummarizationLoader(JsonLoader):
         logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
         word_list = []
         with open(vocab_path, 'r', encoding='utf8') as vocab_f:
-            cnt = 0
+            cnt = 2  # pad and unk
             for line in vocab_f:
-                cnt += 1
                 pieces = line.split("\t")
                 word_list.append(pieces[0])
+                cnt += 1
                 if cnt > vocab_size:
                     break
         vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
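
The dataloader hunk above reserves vocabulary ids for the pad and unk tokens before counting entries, so the counter now starts at 2 and each word is counted only after it is appended. A minimal sketch of that counting logic (vocab_path is assumed to point at a word<TAB>count file):

    def load_word_list(vocab_path, vocab_size):
        # ids 0 and 1 are reserved for the pad and unk tokens,
        # so start counting at 2 and stop once vocab_size is reached
        word_list = []
        cnt = 2  # pad and unk
        with open(vocab_path, 'r', encoding='utf8') as vocab_f:
            for line in vocab_f:
                pieces = line.split("\t")
                word_list.append(pieces[0])
                cnt += 1
                if cnt > vocab_size:
                    break
        return word_list
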
diff --git a/reproduction/Summarization/model/DeepLSTM.py b/reproduction/Summarization/model/DeepLSTM.py
new file mode 100644
index 00000000..80842bf7
--- /dev/null
+++ b/reproduction/Summarization/model/DeepLSTM.py
@@ -0,0 +1,136 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torch.distributions import Bernoulli
+
+
+class DeepLSTM(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, recurrent_dropout,
+                 use_orthnormal_init=True, fix_mask=True, use_cuda=True):
+        super(DeepLSTM, self).__init__()
+
+        self.fix_mask = fix_mask
+        self.use_cuda = use_cuda
+        self.input_size = input_size
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.recurrent_dropout = recurrent_dropout
+
+        self.lstms = nn.ModuleList([None] * self.num_layers)
+        self.highway_gate_input = nn.ModuleList([None] * self.num_layers)
+        # one independent Linear per layer; [module] * n would share a single module across layers
+        self.highway_gate_state = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(self.num_layers)])
+        self.highway_linear_input = nn.ModuleList([None] * self.num_layers)
+
+        # self._input_w = nn.Parameter(torch.Tensor(input_size, hidden_size))
+        # init.xavier_normal_(self._input_w)
+
+        for l in range(self.num_layers):
+            input_dim = input_size if l == 0 else hidden_size
+
+            self.lstms[l] = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_size)
+            self.highway_gate_input[l] = nn.Linear(input_dim, hidden_size)
+            self.highway_linear_input[l] = nn.Linear(input_dim, hidden_size, bias=False)
+
+        # logger.info("[INFO] Initing W for LSTM .......")
+        for l in range(self.num_layers):
+            if use_orthnormal_init:
+                # logger.info("[INFO] Initing W using orthnormal init .......")
+                init.orthogonal_(self.lstms[l].weight_ih)
+                init.orthogonal_(self.lstms[l].weight_hh)
+                init.orthogonal_(self.highway_gate_input[l].weight.data)
+                init.orthogonal_(self.highway_gate_state[l].weight.data)
+                init.orthogonal_(self.highway_linear_input[l].weight.data)
+            else:
+                # logger.info("[INFO] Initing W using xavier_normal .......")
+                init_weight_value = 6.0
+                init.xavier_normal_(self.lstms[l].weight_ih, gain=np.sqrt(init_weight_value))
+                init.xavier_normal_(self.lstms[l].weight_hh, gain=np.sqrt(init_weight_value))
+                init.xavier_normal_(self.highway_gate_input[l].weight.data, gain=np.sqrt(init_weight_value))
+                init.xavier_normal_(self.highway_gate_state[l].weight.data, gain=np.sqrt(init_weight_value))
+                init.xavier_normal_(self.highway_linear_input[l].weight.data, gain=np.sqrt(init_weight_value))
+
+    def init_hidden(self, batch_size, hidden_size):
+        # the first is the hidden h, the second is the cell c
+        if self.use_cuda:
+            return (torch.zeros(batch_size, hidden_size).cuda(),
+                    torch.zeros(batch_size, hidden_size).cuda())
+        else:
+            return (torch.zeros(batch_size, hidden_size),
+                    torch.zeros(batch_size, hidden_size))
+
+    def forward(self, inputs, input_masks, Train):
+        '''
+        inputs: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list)
+        input_masks: [[seq_len, batch, 1], n_layer * [None]] (list)
+        '''
+        batch_size, seq_len = inputs[0].size(1), inputs[0].size(0)
+
+        # inputs[0] = torch.matmul(inputs[0], self._input_w)
+        # input_masks[0] = input_masks[0].unsqueeze(-1).expand(seq_len, batch_size, self.hidden_size)
+
+        self.inputs = inputs
+        self.input_masks = input_masks
+
+        if self.fix_mask:
+            self.output_dropout_layers = [None] * self.num_layers
+            for l in range(self.num_layers):
+                # one mask per layer, reused at every time step (variational dropout)
+                binary_mask = torch.rand((batch_size, self.hidden_size)) > self.recurrent_dropout
+                # This scaling ensures expected values and variances of the output of applying
+                # this mask and the original tensor are the same. (from allennlp.nn.util)
+                self.output_dropout_layers[l] = binary_mask.float().div(1.0 - self.recurrent_dropout)
+                if self.use_cuda:
+                    self.output_dropout_layers[l] = self.output_dropout_layers[l].cuda()
+
+        for l in range(self.num_layers):
+            h, c = self.init_hidden(batch_size, self.hidden_size)
+            outputs_list = []
+            for t in range(len(self.inputs[l])):
+                x = self.inputs[l][t]
+                m = self.input_masks[l][t].float()
+                h_temp, c_temp = self.lstms[l].forward(x, (h, c))  # [batch, hidden_size]
+                r = torch.sigmoid(self.highway_gate_input[l](x) + self.highway_gate_state[l](h))  # highway gate
+                lx = self.highway_linear_input[l](x)  # [batch, hidden_size]
+                h_temp = r * h_temp + (1 - r) * lx
+
+                if Train:
+                    if self.fix_mask:
+                        h_temp = self.output_dropout_layers[l] * h_temp
+                    else:
+                        h_temp = F.dropout(h_temp, p=self.recurrent_dropout)
+
+                h = m * h_temp + (1 - m) * h  # keep the previous state where the mask marks padding
+                c = m * c_temp + (1 - m) * c
+                outputs_list.append(h)
+            outputs = torch.stack(outputs_list, 0)  # [seq_len, batch, hidden_size]
+            # reverse time order so the next layer runs in the opposite direction
+            self.inputs[l + 1] = DeepLSTM.flip(outputs, 0)  # [seq_len, batch, hidden_size]
+            self.input_masks[l + 1] = DeepLSTM.flip(self.input_masks[l], 0)
+
+        self.output_state = self.inputs  # num_layers * [seq_len, batch, hidden_size]
+
+        # flip -2 layer
+        # self.output_state[-2] = DeepLSTM.flip(self.output_state[-2], 0)
+
+        # concat last two layers
+        # self.output_state = torch.cat([self.output_state[-1], self.output_state[-2]], dim=-1).transpose(0, 1)
+
+        self.output_state = self.output_state[-1].transpose(0, 1)
+
+        assert self.output_state.size() == (batch_size, seq_len, self.hidden_size)
+
+        return self.output_state
+
+    @staticmethod
+    def flip(x, dim):
+        # reverse x along the given dim
+        xsize = x.size()
+        dim = x.dim() + dim if dim < 0 else dim
+        x = x.contiguous()
+        x = x.view(-1, *xsize[dim:]).contiguous()
+        x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, -1, -1),
+                                                        ('cpu', 'cuda')[x.is_cuda])().long(), :]
+        return x.view(xsize)
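
A quick smoke test for the DeepLSTM module above (a sketch with toy sizes; the list layout mirrors SummarizationModel.forward below: slot 0 carries the first layer's input and mask, and the forward pass fills slots 1..n_layers):

    import torch
    from DeepLSTM import DeepLSTM  # assuming the module above is importable

    seq_len, batch, in_dim, hid, n_layers = 5, 2, 8, 16, 3
    model = DeepLSTM(in_dim, hid, n_layers, recurrent_dropout=0.1,
                     use_orthnormal_init=True, fix_mask=True, use_cuda=False)
    inputs = [torch.randn(seq_len, batch, in_dim)] + [None] * n_layers
    masks = [torch.ones(seq_len, batch, 1)] + [None] * n_layers
    out = model(inputs, masks, Train=False)  # dropout masks are only applied when Train=True
    assert out.size() == (batch, seq_len, hid)  # [batch, seq_len, hidden_size]
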
diff --git a/reproduction/Summarization/model/LSTMModel.py b/reproduction/Summarization/model/LSTMModel.py
new file mode 100644
index 00000000..1fae03dd
--- /dev/null
+++ b/reproduction/Summarization/model/LSTMModel.py
@@ -0,0 +1,103 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+import torch.nn as nn
+from torch.autograd import *
+from torch.distributions import *
+
+from .Encoder import Encoder
+from .DeepLSTM import DeepLSTM
+
+from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward
+
+
+class SummarizationModel(nn.Module):
+    def __init__(self, hps, embed):
+        """
+        :param hps: hyperparameters for the model
+        :param embed: word embedding used by the sentence encoder
+        """
+        super(SummarizationModel, self).__init__()
+
+        self._hps = hps
+
+        # sentence encoder
+        self.encoder = Encoder(hps, embed)
+
+        # multi-layer highway LSTM
+        self.num_layers = hps.n_layers
+        self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel
+        self.lstm_hidden_size = hps.lstm_hidden_size
+        self.recurrent_dropout = hps.recurrent_dropout_prob
+
+        self.deep_lstm = DeepLSTM(self.sent_embedding_size, self.lstm_hidden_size, self.num_layers,
+                                  self.recurrent_dropout, hps.use_orthnormal_init, hps.fix_mask, hps.cuda)
+
+        # multi-head attention
+        self.n_head = hps.n_head
+        self.d_v = self.d_k = int(self.lstm_hidden_size / hps.n_head)
+        self.d_inner = hps.ffn_inner_hidden_size
+        self.slf_attn = MultiHeadAttention(hps.n_head, self.lstm_hidden_size, self.d_k, self.d_v, dropout=hps.atten_dropout_prob)
+        self.pos_ffn = PositionwiseFeedForward(self.d_v, self.d_inner, dropout=hps.ffn_dropout_prob)
+
+        self.wh = nn.Linear(self.d_v, 2)
+
+    def forward(self, input, input_len, Train):
+        """
+        :param input: [batch_size, N, seq_len], word idx long tensor
+        :param input_len: [batch_size, N], 1 for a real sentence and 0 for padding
+        :param Train: True for train and False for eval and test
+        :return: dict with
+            p_sent: [batch_size, N, 2]
+            prediction: [batch_size, N], 0/1 label for each sentence
+            pred_idx: indices of the selected sentences (None when m == 0)
+        """
+        # -- Sentence Encoder
+        self.sent_embedding = self.encoder(input)  # [batch, N, Co * kernel_sizes]
+
+        # -- Multi-layer highway LSTM
+        input_len = input_len.float()  # [batch, N]
+        self.inputs = [None] * (self.num_layers + 1)
+        self.input_masks = [None] * (self.num_layers + 1)
+        self.inputs[0] = self.sent_embedding.permute(1, 0, 2)  # [N, batch, Co * kernel_sizes]
+        self.input_masks[0] = input_len.permute(1, 0).unsqueeze(2)
+
+        self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train)  # [batch, N, hidden_size]
+
+        # -- Prepare masks
+        batch_size, N = input_len.size()
+        slf_attn_mask = input_len.eq(0.0)  # [batch, N], 1 for padding
+        slf_attn_mask = slf_attn_mask.unsqueeze(1).expand(-1, N, -1)  # [batch, N, N]
+
+        # -- Multi-head attention
+        self.atten_output, self.output_slf_attn = self.slf_attn(self.lstm_output_state, self.lstm_output_state, self.lstm_output_state, mask=slf_attn_mask)
+        self.atten_output *= input_len.unsqueeze(2)  # [batch_size, N, lstm_hidden_size = (n_head * d_v)]
+        self.multi_atten_output = self.atten_output.view(batch_size, N, self.n_head, self.d_v)  # [batch_size, N, n_head, d_v]
+        self.multi_atten_context = self.multi_atten_output[:, :, 0::2, :].sum(2) - self.multi_atten_output[:, :, 1::2, :].sum(2)  # [batch_size, N, d_v]: even heads added, odd heads subtracted
+
+        # -- Position-wise Feed-Forward Networks
+        self.output_state = self.pos_ffn(self.multi_atten_context)
+        self.output_state = self.output_state * input_len.unsqueeze(2)  # [batch_size, N, d_v]
+
+        p_sent = self.wh(self.output_state)  # [batch, N, 2]
+
+        idx = None
+        if self._hps.m == 0:
+            prediction = p_sent.view(-1, 2).max(1)[1]  # argmax label per sentence
+            prediction = prediction.view(batch_size, -1)
+        else:
+            mask_output = torch.exp(p_sent[:, :, 1])  # [batch, N]
+            mask_output = mask_output.masked_fill(input_len.eq(0), 0)  # padding can never be selected
+            topk, idx = torch.topk(mask_output, self._hps.m)
+            prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1)
+            prediction = prediction.long().view(batch_size, -1)
+
+        if self._hps.cuda:
+            prediction = prediction.cuda()
+
+        return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx}
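
The head-combination step in SummarizationModel.forward collapses the n_head * d_v attention features back to d_v by adding the even-indexed heads and subtracting the odd-indexed ones. A standalone illustration with toy sizes:

    import torch

    batch, N, n_head, d_v = 2, 4, 6, 8
    atten_output = torch.randn(batch, N, n_head * d_v)  # stand-in for the self-attention output
    heads = atten_output.view(batch, N, n_head, d_v)
    context = heads[:, :, 0::2, :].sum(2) - heads[:, :, 1::2, :].sum(2)
    assert context.size() == (batch, N, d_v)
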
diff --git a/reproduction/Summarization/model/Metric.py b/reproduction/Summarization/model/Metric.py
index 54b333f8..441c27b1 100644
--- a/reproduction/Summarization/model/Metric.py
+++ b/reproduction/Summarization/model/Metric.py
@@ -50,8 +50,8 @@ class LabelFMetric(MetricBase):
         """
         target = target.data
         pred = pred.data
-        logger.debug(pred.size())
-        logger.debug(pred[:5,:])
+        # logger.debug(pred.size())
+        # logger.debug(pred[:5,:])
         batch, N = pred.size()
         self.pred += pred.sum()
         self.true += target.sum()
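
LabelFMetric accumulates the predicted-positive and gold-positive label counts across batches. Assuming a right counter of correctly predicted positive labels is accumulated the same way (it is not shown in this hunk), the final score would typically be derived as:

    def f_metric(right, pred, true, eps=1e-13):
        # precision: correct positives / predicted positives
        # recall:    correct positives / gold positives
        precision = right / (pred + eps)
        recall = right / (true + eps)
        return 2.0 * precision * recall / (precision + recall + eps)
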
diff --git a/reproduction/Summarization/model/TForiginal.py b/reproduction/Summarization/model/TForiginal.py
index 9bcec292..e66bc061 100644
--- a/reproduction/Summarization/model/TForiginal.py
+++ b/reproduction/Summarization/model/TForiginal.py
@@ -83,7 +83,6 @@ class TransformerModel(nn.Module):
 
         :param input: [batch_size, N, seq_len]
         :param input_len: [batch_size, N]
-        :param return_atten: bool
         :return:
         """
         # Sentence Encoder
@@ -125,12 +124,12 @@ class TransformerModel(nn.Module):
         p_sent = self.wh(self.dec_output_state)  # [batch, N, 2]
 
         idx = None
-        if self._hps == 0:
+        if self._hps.m == 0:
             prediction = p_sent.view(-1, 2).max(1)[1]
             prediction = prediction.view(batch_size, -1)
         else:
             mask_output = torch.exp(p_sent[:, :, 1])  # [batch, N]
-            mask_output = mask_output * input_len.float()
+            mask_output = mask_output.masked_fill(input_len.eq(0), 0)
             topk, idx = torch.topk(mask_output, self._hps.m)
             prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1)
             prediction = prediction.long().view(batch_size, -1)
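
Both TransformerModel and SummarizationModel share the top-m sentence selection that this patch fixes: masked_fill zeroes the scores of padded sentences so they can never enter the top m. A minimal sketch with a toy batch, assuming m = 2:

    import torch

    batch, N, m = 2, 5, 2
    scores = torch.rand(batch, N)  # stand-in for exp(p_sent[:, :, 1])
    input_len = torch.tensor([[1, 1, 1, 0, 0],
                              [1, 1, 1, 1, 1]])
    scores = scores.masked_fill(input_len.eq(0), 0)  # padded sentences drop out of the ranking
    topk, idx = torch.topk(scores, m)                # [batch, m]
    prediction = torch.zeros(batch, N).scatter_(1, idx, 1).long()  # 0/1 rows with exactly m ones
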