@@ -0,0 +1,12 @@ | |||||
{ | |||||
"n_layers": 16, | |||||
"layer_sum": false, | |||||
"layer_cat": false, | |||||
"lstm_hidden_size": 300, | |||||
"ffn_inner_hidden_size": 2048, | |||||
"n_head": 6, | |||||
"recurrent_dropout_prob": 0.1, | |||||
"atten_dropout_prob": 0.1, | |||||
"ffn_dropout_prob": 0.1, | |||||
"fix_mask": true | |||||
} |
@@ -0,0 +1,3 @@ | |||||
{ | |||||
} |
@@ -0,0 +1,9 @@ | |||||
{ | |||||
"n_layers": 12, | |||||
"hidden_size": 512, | |||||
"ffn_inner_hidden_size": 2048, | |||||
"n_head": 8, | |||||
"recurrent_dropout_prob": 0.1, | |||||
"atten_dropout_prob": 0.1, | |||||
"ffn_dropout_prob": 0.1 | |||||
} |
@@ -0,0 +1,188 @@ | |||||
import pickle | |||||
import numpy as np | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.base_loader import DataInfo | |||||
from fastNLP.io.dataset_loader import JsonLoader | |||||
from fastNLP.core.const import Const | |||||
from tools.logger import * | |||||
WORD_PAD = "[PAD]" | |||||
WORD_UNK = "[UNK]" | |||||
DOMAIN_UNK = "X" | |||||
TAG_UNK = "X" | |||||
class SummarizationLoader(JsonLoader): | |||||
""" | |||||
读取summarization数据集,读取的DataSet包含fields:: | |||||
text: list(str),document | |||||
summary: list(str), summary | |||||
text_wd: list(list(str)),tokenized document | |||||
summary_wd: list(list(str)), tokenized summary | |||||
labels: list(int), | |||||
flatten_label: list(int), 0 or 1, flatten labels | |||||
domain: str, optional | |||||
tag: list(str), optional | |||||
数据来源: CNN_DailyMail Newsroom DUC | |||||
""" | |||||
def __init__(self): | |||||
super(SummarizationLoader, self).__init__() | |||||
def _load(self, path): | |||||
ds = super(SummarizationLoader, self)._load(path) | |||||
def _lower_text(text_list): | |||||
return [text.lower() for text in text_list] | |||||
def _split_list(text_list): | |||||
return [text.split() for text in text_list] | |||||
def _convert_label(label, sent_len): | |||||
np_label = np.zeros(sent_len, dtype=int) | |||||
if label != []: | |||||
np_label[np.array(label)] = 1 | |||||
return np_label.tolist() | |||||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||||
return ds | |||||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab=True): | |||||
""" | |||||
:param paths: dict path for each dataset | |||||
:param vocab_size: int max_size for vocab | |||||
:param vocab_path: str vocab path | |||||
:param sent_max_len: int max token number of the sentence | |||||
:param doc_max_timesteps: int max sentence number of the document | |||||
:param domain: bool build vocab for publication, use 'X' for unknown | |||||
:param tag: bool build vocab for tag, use 'X' for unknown | |||||
:param load_vocab: bool build vocab (False) or load vocab (True) | |||||
:return: DataInfo | |||||
datasets: dict keys correspond to the paths dict | |||||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||||
embeddings: optional | |||||
""" | |||||
def _pad_sent(text_wd): | |||||
pad_text_wd = [] | |||||
for sent_wd in text_wd: | |||||
if len(sent_wd) < sent_max_len: | |||||
pad_num = sent_max_len - len(sent_wd) | |||||
sent_wd.extend([WORD_PAD] * pad_num) | |||||
else: | |||||
sent_wd = sent_wd[:sent_max_len] | |||||
pad_text_wd.append(sent_wd) | |||||
return pad_text_wd | |||||
def _token_mask(text_wd): | |||||
token_mask_list = [] | |||||
for sent_wd in text_wd: | |||||
token_num = len(sent_wd) | |||||
if token_num < sent_max_len: | |||||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||||
else: | |||||
mask = [1] * sent_max_len | |||||
token_mask_list.append(mask) | |||||
return token_mask_list | |||||
def _pad_label(label): | |||||
text_len = len(label) | |||||
if text_len < doc_max_timesteps: | |||||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_label = label[:doc_max_timesteps] | |||||
return pad_label | |||||
def _pad_doc(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
padding = [WORD_PAD] * sent_max_len | |||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_text = text_wd[:doc_max_timesteps] | |||||
return pad_text | |||||
def _sent_mask(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
sent_mask = [1] * doc_max_timesteps | |||||
return sent_mask | |||||
datasets = {} | |||||
train_ds = None | |||||
for key, value in paths.items(): | |||||
ds = self.load(value) | |||||
# pad sent | |||||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||||
# pad document | |||||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||||
# rename field | |||||
ds.rename_field("pad_text", Const.INPUT) | |||||
ds.rename_field("seq_len", Const.INPUT_LEN) | |||||
ds.rename_field("pad_label", Const.TARGET) | |||||
# set input and target | |||||
ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
datasets[key] = ds | |||||
if "train" in key: | |||||
train_ds = datasets[key] | |||||
vocab_dict = {} | |||||
if load_vocab == False: | |||||
logger.info("[INFO] Build new vocab from training dataset!") | |||||
if train_ds == None: | |||||
raise ValueError("Lack train file to build vocabulary!") | |||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||||
vocab_dict["vocab"] = vocabs | |||||
else: | |||||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||||
word_list = [] | |||||
with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||||
cnt = 2 # pad and unk | |||||
for line in vocab_f: | |||||
pieces = line.split("\t") | |||||
word_list.append(pieces[0]) | |||||
cnt += 1 | |||||
if cnt > vocab_size: | |||||
break | |||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.add_word_lst(word_list) | |||||
vocabs.build_vocab() | |||||
vocab_dict["vocab"] = vocabs | |||||
if domain == True: | |||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||||
domaindict.from_dataset(train_ds, field_name="publication") | |||||
vocab_dict["domain"] = domaindict | |||||
if tag == True: | |||||
tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||||
tagdict.from_dataset(train_ds, field_name="tag") | |||||
vocab_dict["tag"] = tagdict | |||||
for ds in datasets.values(): | |||||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||||
return DataInfo(vocabs=vocab_dict, datasets=datasets) | |||||
@@ -0,0 +1,136 @@ | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.init as init | |||||
import torch.nn.functional as F | |||||
from torch.autograd import Variable | |||||
from torch.distributions import Bernoulli | |||||
class DeepLSTM(nn.Module): | |||||
def __init__(self, input_size, hidden_size, num_layers, recurrent_dropout, use_orthnormal_init=True, fix_mask=True, use_cuda=True): | |||||
super(DeepLSTM, self).__init__() | |||||
self.fix_mask = fix_mask | |||||
self.use_cuda = use_cuda | |||||
self.input_size = input_size | |||||
self.num_layers = num_layers | |||||
self.hidden_size = hidden_size | |||||
self.recurrent_dropout = recurrent_dropout | |||||
self.lstms = nn.ModuleList([None] * self.num_layers) | |||||
self.highway_gate_input = nn.ModuleList([None] * self.num_layers) | |||||
self.highway_gate_state = nn.ModuleList([nn.Linear(hidden_size, hidden_size)] * self.num_layers) | |||||
self.highway_linear_input = nn.ModuleList([None] * self.num_layers) | |||||
# self._input_w = nn.Parameter(torch.Tensor(input_size, hidden_size)) | |||||
# init.xavier_normal_(self._input_w) | |||||
for l in range(self.num_layers): | |||||
input_dim = input_size if l == 0 else hidden_size | |||||
self.lstms[l] = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_size) | |||||
self.highway_gate_input[l] = nn.Linear(input_dim, hidden_size) | |||||
self.highway_linear_input[l] = nn.Linear(input_dim, hidden_size, bias=False) | |||||
# logger.info("[INFO] Initing W for LSTM .......") | |||||
for l in range(self.num_layers): | |||||
if use_orthnormal_init: | |||||
# logger.info("[INFO] Initing W using orthnormal init .......") | |||||
init.orthogonal_(self.lstms[l].weight_ih) | |||||
init.orthogonal_(self.lstms[l].weight_hh) | |||||
init.orthogonal_(self.highway_gate_input[l].weight.data) | |||||
init.orthogonal_(self.highway_gate_state[l].weight.data) | |||||
init.orthogonal_(self.highway_linear_input[l].weight.data) | |||||
else: | |||||
# logger.info("[INFO] Initing W using xavier_normal .......") | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(self.lstms[l].weight_ih, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.lstms[l].weight_hh, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.highway_gate_input[l].weight.data, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.highway_gate_state[l].weight.data, gain=np.sqrt(init_weight_value)) | |||||
init.xavier_normal_(self.highway_linear_input[l].weight.data, gain=np.sqrt(init_weight_value)) | |||||
def init_hidden(self, batch_size, hidden_size): | |||||
# the first is the hidden h | |||||
# the second is the cell c | |||||
if self.use_cuda: | |||||
return (torch.zeros(batch_size, hidden_size).cuda(), | |||||
torch.zeros(batch_size, hidden_size).cuda()) | |||||
else: | |||||
return (torch.zeros(batch_size, hidden_size), | |||||
torch.zeros(batch_size, hidden_size)) | |||||
def forward(self, inputs, input_masks, Train): | |||||
''' | |||||
inputs: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list) | |||||
input_masks: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list) | |||||
''' | |||||
batch_size, seq_len = inputs[0].size(1), inputs[0].size(0) | |||||
# inputs[0] = torch.matmul(inputs[0], self._input_w) | |||||
# input_masks[0] = input_masks[0].unsqueeze(-1).expand(seq_len, batch_size, self.hidden_size) | |||||
self.inputs = inputs | |||||
self.input_masks = input_masks | |||||
if self.fix_mask: | |||||
self.output_dropout_layers = [None] * self.num_layers | |||||
for l in range(self.num_layers): | |||||
binary_mask = torch.rand((batch_size, self.hidden_size)) > self.recurrent_dropout | |||||
# This scaling ensures expected values and variances of the output of applying this mask and the original tensor are the same. | |||||
# from allennlp.nn.util.py | |||||
self.output_dropout_layers[l] = binary_mask.float().div(1.0 - self.recurrent_dropout) | |||||
if self.use_cuda: | |||||
self.output_dropout_layers[l] = self.output_dropout_layers[l].cuda() | |||||
for l in range(self.num_layers): | |||||
h, c = self.init_hidden(batch_size, self.hidden_size) | |||||
outputs_list = [] | |||||
for t in range(len(self.inputs[l])): | |||||
x = self.inputs[l][t] | |||||
m = self.input_masks[l][t].float() | |||||
h_temp, c_temp = self.lstms[l].forward(x, (h, c)) # [batch, hidden_size] | |||||
r = torch.sigmoid(self.highway_gate_input[l](x) + self.highway_gate_state[l](h)) | |||||
lx = self.highway_linear_input[l](x) # [batch, hidden_size] | |||||
h_temp = r * h_temp + (1 - r) * lx | |||||
if Train: | |||||
if self.fix_mask: | |||||
h_temp = self.output_dropout_layers[l] * h_temp | |||||
else: | |||||
h_temp = F.dropout(h_temp, p=self.recurrent_dropout) | |||||
h = m * h_temp + (1 - m) * h | |||||
c = m * c_temp + (1 - m) * c | |||||
outputs_list.append(h) | |||||
outputs = torch.stack(outputs_list, 0) # [seq_len, batch, hidden_size] | |||||
self.inputs[l + 1] = DeepLSTM.flip(outputs, 0) # reverse [seq_len, batch, hidden_size] | |||||
self.input_masks[l + 1] = DeepLSTM.flip(self.input_masks[l], 0) | |||||
self.output_state = self.inputs # num_layers * [seq_len, batch, hidden_size] | |||||
# flip -2 layer | |||||
# self.output_state[-2] = DeepLSTM.flip(self.output_state[-2], 0) | |||||
# concat last two layer | |||||
# self.output_state = torch.cat([self.output_state[-1], self.output_state[-2]], dim=-1).transpose(0, 1) | |||||
self.output_state = self.output_state[-1].transpose(0, 1) | |||||
assert self.output_state.size() == (batch_size, seq_len, self.hidden_size) | |||||
return self.output_state | |||||
@staticmethod | |||||
def flip(x, dim): | |||||
xsize = x.size() | |||||
dim = x.dim() + dim if dim < 0 else dim | |||||
x = x.contiguous() | |||||
x = x.view(-1, *xsize[dim:]).contiguous() | |||||
x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, | |||||
-1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] | |||||
return x.view(xsize) |
@@ -0,0 +1,566 @@ | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import torch.nn.init as init | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
# from tools.logger import * | |||||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||||
WORD_PAD = "[PAD]" | |||||
class Encoder(nn.Module): | |||||
def __init__(self, hps, embed): | |||||
""" | |||||
:param hps: | |||||
word_emb_dim: word embedding dimension | |||||
sent_max_len: max token number in the sentence | |||||
output_channel: output channel for cnn | |||||
min_kernel_size: min kernel size for cnn | |||||
max_kernel_size: max kernel size for cnn | |||||
word_embedding: bool, use word embedding or not | |||||
embedding_path: word embedding path | |||||
embed_train: bool, whether to train word embedding | |||||
cuda: bool, use cuda or not | |||||
:param vocab: FastNLP.Vocabulary | |||||
""" | |||||
super(Encoder, self).__init__() | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self.embed = embed | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
print("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
print("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len] | |||||
batch_size, N, _ = input.size() | |||||
input = input.view(-1, input.size(2)) # [batch_size*N, L] | |||||
input_sent_len = ((input!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||||
enc_embed_input = self.embed(input) # [batch_size*N, L, D] | |||||
input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class DomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(DomainEncoder, self).__init__(hps, vocab) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
sent_embedding = super().forward(input) | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class MultiDomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(MultiDomainEncoder, self).__init__(hps, vocab) | |||||
self.domain_size = domaindict.size() | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size, domain_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
# logger.info(domain[:5, :]) | |||||
sent_embedding = super().forward(input) | |||||
domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1) | |||||
domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size] | |||||
enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D] | |||||
enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D] | |||||
# logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D] | |||||
enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, hps): | |||||
super(BertEncoder, self).__init__() | |||||
from pytorch_pretrained_bert.modeling import BertModel | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
self._cuda = hps.cuda | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16") | |||||
self._bert.eval() | |||||
for p in self._bert.parameters(): | |||||
p.requires_grad = False | |||||
self.word_embedding_proj = nn.Linear(4096, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
""" | |||||
# Use Bert to get word embedding | |||||
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
for i in range(batch_size): | |||||
tokens_id = inputs[i] | |||||
input_mask = input_masks[i] | |||||
sent_len = enc_sent_len[i] | |||||
input_ids = tokens_id.unsqueeze(0) | |||||
input_mask = input_mask.unsqueeze(0) | |||||
out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask) | |||||
out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 1 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 511: | |||||
enc_sent_input.append(out[last_end: min(511, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D] | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class BertTagEncoder(BertEncoder): | |||||
def __init__(self, hps, domaindict): | |||||
super(BertTagEncoder, self).__init__(hps) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, inputs, input_masks, enc_sent_len, domain): | |||||
sent_embedding = super().forward(inputs, input_masks, enc_sent_len) | |||||
batch_size, N = enc_sent_len.size() | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class ELMoEndoer(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer, self).__init__() | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len, character_len] | |||||
batch_size, N, seq_len, _ = input.size() | |||||
input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len] | |||||
input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||||
# logger.debug(input_sent_len.view(batch_size, -1)) | |||||
enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D] | |||||
enc_embed_input = self.embed_proj(enc_embed_input) | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in input_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class ELMoEndoer2(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer2, self).__init__() | |||||
self._hps = hps | |||||
self._cuda = hps.cuda | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
sent_embedding: [batch, N, D] | |||||
""" | |||||
# Use Bert to get word embedding | |||||
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D] | |||||
elmo_output = elmo_output * input_masks.unsqueeze(-1).float() | |||||
# print("END elmo") | |||||
for i in range(batch_size): | |||||
sent_len = enc_sent_len[i] # [1, N] | |||||
out = elmo_output[i] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 0 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 512: | |||||
enc_sent_input.append(out[last_end : min(512, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state] | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D] | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding |
@@ -0,0 +1,103 @@ | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import * | |||||
from torch.distributions import * | |||||
from .Encoder import Encoder | |||||
from .DeepLSTM import DeepLSTM | |||||
from transformer.SubLayers import MultiHeadAttention,PositionwiseFeedForward | |||||
class SummarizationModel(nn.Module): | |||||
def __init__(self, hps, embed): | |||||
""" | |||||
:param hps: hyperparameters for the model | |||||
:param vocab: vocab object | |||||
""" | |||||
super(SummarizationModel, self).__init__() | |||||
self._hps = hps | |||||
# sentence encoder | |||||
self.encoder = Encoder(hps, embed) | |||||
# Multi-layer highway lstm | |||||
self.num_layers = hps.n_layers | |||||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||||
self.lstm_hidden_size = hps.lstm_hidden_size | |||||
self.recurrent_dropout = hps.recurrent_dropout_prob | |||||
self.deep_lstm = DeepLSTM(self.sent_embedding_size, self.lstm_hidden_size, self.num_layers, self.recurrent_dropout, | |||||
hps.use_orthnormal_init, hps.fix_mask, hps.cuda) | |||||
# Multi-head attention | |||||
self.n_head = hps.n_head | |||||
self.d_v = self.d_k = int(self.lstm_hidden_size / hps.n_head) | |||||
self.d_inner = hps.ffn_inner_hidden_size | |||||
self.slf_attn = MultiHeadAttention(hps.n_head, self.lstm_hidden_size , self.d_k, self.d_v, dropout=hps.atten_dropout_prob) | |||||
self.pos_ffn = PositionwiseFeedForward(self.d_v, self.d_inner, dropout = hps.ffn_dropout_prob) | |||||
self.wh = nn.Linear(self.d_v, 2) | |||||
def forward(self, input, input_len, Train): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], word idx long tensor | |||||
:param input_len: [batch_size, N], 1 for sentence and 0 for padding | |||||
:param Train: True for train and False for eval and test | |||||
:param return_atten: True or False to return multi-head attention output self.output_slf_attn | |||||
:return: | |||||
p_sent: [batch_size, N, 2] | |||||
output_slf_attn: (option) [n_head, batch_size, N, N] | |||||
""" | |||||
# -- Sentence Encoder | |||||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||||
# -- Multi-layer highway lstm | |||||
input_len = input_len.float() # [batch, N] | |||||
self.inputs = [None] * (self.num_layers + 1) | |||||
self.input_masks = [None] * (self.num_layers + 1) | |||||
self.inputs[0] = self.sent_embedding.permute(1, 0, 2) # [N, batch, Co * kernel_sizes] | |||||
self.input_masks[0] = input_len.permute(1, 0).unsqueeze(2) | |||||
self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train) # [batch, N, hidden_size] | |||||
# -- Prepare masks | |||||
batch_size, N = input_len.size() | |||||
slf_attn_mask = input_len.eq(0.0) # [batch, N], 1 for padding | |||||
slf_attn_mask = slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||||
# -- Multi-head attention | |||||
self.atten_output, self.output_slf_attn = self.slf_attn(self.lstm_output_state, self.lstm_output_state, self.lstm_output_state, mask=slf_attn_mask) | |||||
self.atten_output *= input_len.unsqueeze(2) # [batch_size, N, lstm_hidden_size = (n_head * d_v)] | |||||
self.multi_atten_output = self.atten_output.view(batch_size, N, self.n_head, self.d_v) # [batch_size, N, n_head, d_v] | |||||
self.multi_atten_context = self.multi_atten_output[:, :, 0::2, :].sum(2) - self.multi_atten_output[:, :, 1::2, :].sum(2) # [batch_size, N, d_v] | |||||
# -- Position-wise Feed-Forward Networks | |||||
self.output_state = self.pos_ffn(self.multi_atten_context) | |||||
self.output_state = self.output_state * input_len.unsqueeze(2) # [batch_size, N, d_v] | |||||
p_sent = self.wh(self.output_state) # [batch, N, 2] | |||||
idx = None | |||||
if self._hps.m == 0: | |||||
prediction = p_sent.view(-1, 2).max(1)[1] | |||||
prediction = prediction.view(batch_size, -1) | |||||
else: | |||||
mask_output = torch.exp(p_sent[:, :, 1]) # # [batch, N] | |||||
mask_output = mask_output.masked_fill(input_len.eq(0), 0) | |||||
topk, idx = torch.topk(mask_output, self._hps.m) | |||||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||||
prediction = prediction.long().view(batch_size, -1) | |||||
if self._hps.cuda: | |||||
prediction = prediction.cuda() | |||||
return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx} |
@@ -0,0 +1,55 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from fastNLP.core.losses import LossBase | |||||
from tools.logger import * | |||||
class MyCrossEntropyLoss(LossBase): | |||||
def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'): | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, mask=mask) | |||||
self.padding_idx = padding_idx | |||||
self.reduce = reduce | |||||
def get_loss(self, pred, target, mask): | |||||
""" | |||||
:param pred: [batch, N, 2] | |||||
:param target: [batch, N] | |||||
:param input_mask: [batch, N] | |||||
:return: | |||||
""" | |||||
# logger.debug(pred[0:5, :, :]) | |||||
# logger.debug(target[0:5, :]) | |||||
batch, N, _ = pred.size() | |||||
pred = pred.view(-1, 2) | |||||
target = target.view(-1) | |||||
loss = F.cross_entropy(input=pred, target=target, | |||||
ignore_index=self.padding_idx, reduction=self.reduce) | |||||
loss = loss.view(batch, -1) | |||||
loss = loss.masked_fill(mask.eq(0), 0) | |||||
loss = loss.sum(1).mean() | |||||
logger.debug("loss %f", loss) | |||||
return loss | |||||
@@ -0,0 +1,171 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
from __future__ import division | |||||
import torch | |||||
from rouge import Rouge | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.core.metrics import MetricBase | |||||
from tools.logger import * | |||||
from tools.utils import pyrouge_score_all, pyrouge_score_all_multi | |||||
class LabelFMetric(MetricBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
self.match = 0.0 | |||||
self.pred = 0.0 | |||||
self.true = 0.0 | |||||
self.match_true = 0.0 | |||||
self.total = 0.0 | |||||
def evaluate(self, pred, target): | |||||
""" | |||||
:param pred: [batch, N] int | |||||
:param target: [batch, N] int | |||||
:return: | |||||
""" | |||||
target = target.data | |||||
pred = pred.data | |||||
# logger.debug(pred.size()) | |||||
# logger.debug(pred[:5,:]) | |||||
batch, N = pred.size() | |||||
self.pred += pred.sum() | |||||
self.true += target.sum() | |||||
self.match += (pred == target).sum() | |||||
self.match_true += ((pred == target) & (pred == 1)).sum() | |||||
self.total += batch * N | |||||
def get_metric(self, reset=True): | |||||
self.match,self.pred, self.true, self.match_true, self.total = self.match.float(),self.pred.float(), self.true.float(), self.match_true.float(), self.total | |||||
logger.debug((self.match,self.pred, self.true, self.match_true, self.total)) | |||||
try: | |||||
accu = self.match / self.total | |||||
precision = self.match_true / self.pred | |||||
recall = self.match_true / self.true | |||||
F = 2 * precision * recall / (precision + recall) | |||||
except ZeroDivisionError: | |||||
F = 0.0 | |||||
logger.error("[Error] float division by zero") | |||||
if reset: | |||||
self.pred, self.true, self.match_true, self.match, self.total = 0, 0, 0, 0, 0 | |||||
ret = {"accu": accu.cpu(), "p":precision.cpu(), "r":recall.cpu(), "f": F.cpu()} | |||||
logger.info(ret) | |||||
return ret | |||||
class RougeMetric(MetricBase): | |||||
def __init__(self, hps, pred=None, text=None, refer=None): | |||||
super().__init__() | |||||
self._hps = hps | |||||
self._init_param_map(pred=pred, text=text, summary=refer) | |||||
self.hyps = [] | |||||
self.refers = [] | |||||
def evaluate(self, pred, text, summary): | |||||
""" | |||||
:param prediction: [batch, N] | |||||
:param text: [batch, N] | |||||
:param summary: [batch, N] | |||||
:return: | |||||
""" | |||||
batch_size, N = pred.size() | |||||
for j in range(batch_size): | |||||
original_article_sents = text[j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(summary[j]) | |||||
hyps = "\n".join(original_article_sents[id] for id in range(len(pred[j])) if | |||||
pred[j][id] == 1 and id < sent_max_number) | |||||
if sent_max_number < self._hps.m and len(hyps) <= 1: | |||||
print("sent_max_number is too short %d, Skip!", sent_max_number) | |||||
continue | |||||
if len(hyps) >= 1 and hyps != '.': | |||||
self.hyps.append(hyps) | |||||
self.refers.append(refer) | |||||
elif refer == "." or refer == "": | |||||
logger.error("Refer is None!") | |||||
logger.debug(refer) | |||||
elif hyps == "." or hyps == "": | |||||
logger.error("hyps is None!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug("pred:") | |||||
logger.debug(pred[j]) | |||||
logger.debug(hyps) | |||||
else: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug(refer) | |||||
continue | |||||
def get_metric(self, reset=True): | |||||
pass | |||||
class FastRougeMetric(RougeMetric): | |||||
def __init__(self, hps, pred=None, text=None, refer=None): | |||||
super().__init__(hps, pred, text, refer) | |||||
def get_metric(self, reset=True): | |||||
logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers)) | |||||
if len(self.hyps) == 0 or len(self.refers) == 0 : | |||||
logger.error("During testing, no hyps or refers is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(self.hyps, self.refers, avg=True) | |||||
if reset: | |||||
self.hyps = [] | |||||
self.refers = [] | |||||
logger.info(scores_all) | |||||
return scores_all | |||||
class PyRougeMetric(RougeMetric): | |||||
def __init__(self, hps, pred=None, text=None, refer=None): | |||||
super().__init__(hps, pred, text, refer) | |||||
def get_metric(self, reset=True): | |||||
logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers)) | |||||
if len(self.hyps) == 0 or len(self.refers) == 0: | |||||
logger.error("During testing, no hyps or refers is selected!") | |||||
return | |||||
if isinstance(self.refers[0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = pyrouge_score_all_multi(self.hyps, self.refers) | |||||
else: | |||||
scores_all = pyrouge_score_all(self.hyps, self.refers) | |||||
if reset: | |||||
self.hyps = [] | |||||
self.refers = [] | |||||
logger.info(scores_all) | |||||
return scores_all | |||||
@@ -0,0 +1,143 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from .Encoder import Encoder | |||||
# from tools.Encoder import Encoder | |||||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||||
from tools.logger import * | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from transformer.Layers import EncoderLayer | |||||
class TransformerModel(nn.Module): | |||||
def __init__(self, hps, embed): | |||||
""" | |||||
:param hps: | |||||
min_kernel_size: min kernel size for cnn encoder | |||||
max_kernel_size: max kernel size for cnn encoder | |||||
output_channel: output_channel number for cnn encoder | |||||
hidden_size: hidden size for transformer | |||||
n_layers: transfromer encoder layer | |||||
n_head: multi head attention for transformer | |||||
ffn_inner_hidden_size: FFN hiddens size | |||||
atten_dropout_prob: dropout size | |||||
doc_max_timesteps: max sentence number of the document | |||||
:param vocab: | |||||
""" | |||||
super(TransformerModel, self).__init__() | |||||
self._hps = hps | |||||
self.encoder = Encoder(hps, embed) | |||||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||||
self.hidden_size = hps.hidden_size | |||||
self.n_head = hps.n_head | |||||
self.d_v = self.d_k = int(self.hidden_size / self.n_head) | |||||
self.d_inner = hps.ffn_inner_hidden_size | |||||
self.num_layers = hps.n_layers | |||||
self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size) | |||||
self.sent_pos_embed = nn.Embedding.from_pretrained( | |||||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||||
self.layer_stack = nn.ModuleList([ | |||||
EncoderLayer(self.hidden_size, self.d_inner, self.n_head, self.d_k, self.d_v, | |||||
dropout=hps.atten_dropout_prob) | |||||
for _ in range(self.num_layers)]) | |||||
self.wh = nn.Linear(self.hidden_size, 2) | |||||
def forward(self, words, seq_len): | |||||
""" | |||||
:param input: [batch_size, N, seq_len] | |||||
:param input_len: [batch_size, N] | |||||
:return: | |||||
""" | |||||
# Sentence Encoder | |||||
input = words | |||||
input_len = seq_len | |||||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||||
input_len = input_len.float() # [batch, N] | |||||
# -- Prepare masks | |||||
batch_size, N = input_len.size() | |||||
self.slf_attn_mask = input_len.eq(0.0) # [batch, N] | |||||
self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||||
self.non_pad_mask = input_len.unsqueeze(-1) # [batch, N, 1] | |||||
input_doc_len = input_len.sum(dim=1).int() # [batch] | |||||
sent_pos = torch.Tensor( | |||||
[np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len]) | |||||
sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long() | |||||
enc_output_state = self.projection(self.sent_embedding) | |||||
enc_input = enc_output_state + self.sent_pos_embed(sent_pos) | |||||
# self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask | |||||
enc_input_list = [] | |||||
for enc_layer in self.layer_stack: | |||||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||||
# enc_slf_attn = [n_head * batch_size, N, N] | |||||
enc_input, enc_slf_atten = enc_layer(enc_input, non_pad_mask=self.non_pad_mask, | |||||
slf_attn_mask=self.slf_attn_mask) | |||||
enc_input_list += [enc_input] | |||||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | |||||
self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1) | |||||
self.dec_output_state = self.dec_output_state.sum(0) | |||||
p_sent = self.wh(self.dec_output_state) # [batch, N, 2] | |||||
idx = None | |||||
if self._hps.m == 0: | |||||
prediction = p_sent.view(-1, 2).max(1)[1] | |||||
prediction = prediction.view(batch_size, -1) | |||||
else: | |||||
mask_output = torch.exp(p_sent[:, :, 1]) # # [batch, N] | |||||
mask_output = mask_output.masked_fill(input_len.eq(0), 0) | |||||
topk, idx = torch.topk(mask_output, self._hps.m) | |||||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||||
prediction = prediction.long().view(batch_size, -1) | |||||
if self._hps.cuda: | |||||
prediction = prediction.cuda() | |||||
# logger.debug(((p_sent.size(), prediction.size(), idx.size()))) | |||||
return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx} | |||||
@@ -0,0 +1,138 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from .Encoder import Encoder | |||||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
class TransformerModel(nn.Module): | |||||
def __init__(self, hps, vocab): | |||||
""" | |||||
:param hps: | |||||
min_kernel_size: min kernel size for cnn encoder | |||||
max_kernel_size: max kernel size for cnn encoder | |||||
output_channel: output_channel number for cnn encoder | |||||
hidden_size: hidden size for transformer | |||||
n_layers: transfromer encoder layer | |||||
n_head: multi head attention for transformer | |||||
ffn_inner_hidden_size: FFN hiddens size | |||||
atten_dropout_prob: dropout size | |||||
doc_max_timesteps: max sentence number of the document | |||||
:param vocab: | |||||
""" | |||||
super(TransformerModel, self).__init__() | |||||
self._hps = hps | |||||
self._vocab = vocab | |||||
self.encoder = Encoder(hps, vocab) | |||||
self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel | |||||
self.hidden_size = hps.hidden_size | |||||
self.n_head = hps.n_head | |||||
self.d_v = self.d_k = int(self.hidden_size / self.n_head) | |||||
self.d_inner = hps.ffn_inner_hidden_size | |||||
self.num_layers = hps.n_layers | |||||
self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size) | |||||
self.sent_pos_embed = nn.Embedding.from_pretrained( | |||||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||||
self.layer_stack = nn.ModuleList([ | |||||
TransformerEncoder.SubLayer(model_size=self.hidden_size, inner_size=self.d_inner, key_size=self.d_k, value_size=self.d_v,num_head=self.n_head, dropout=hps.atten_dropout_prob) | |||||
for _ in range(self.num_layers)]) | |||||
self.wh = nn.Linear(self.hidden_size, 2) | |||||
def forward(self, words, seq_len): | |||||
""" | |||||
:param input: [batch_size, N, seq_len] | |||||
:param input_len: [batch_size, N] | |||||
:param return_atten: bool | |||||
:return: | |||||
""" | |||||
# Sentence Encoder | |||||
input = words | |||||
input_len = seq_len | |||||
self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes] | |||||
input_len = input_len.float() # [batch, N] | |||||
# -- Prepare masks | |||||
batch_size, N = input_len.size() | |||||
self.slf_attn_mask = input_len.eq(0.0) # [batch, N] | |||||
self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N] | |||||
self.non_pad_mask = input_len.unsqueeze(-1) # [batch, N, 1] | |||||
input_doc_len = input_len.sum(dim=1).int() # [batch] | |||||
sent_pos = torch.Tensor([np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len]) | |||||
sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long() | |||||
enc_output_state = self.projection(self.sent_embedding) | |||||
enc_input = enc_output_state + self.sent_pos_embed(sent_pos) | |||||
# self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask | |||||
enc_input_list = [] | |||||
for enc_layer in self.layer_stack: | |||||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||||
# enc_slf_attn = [n_head * batch_size, N, N] | |||||
enc_input = enc_layer(enc_input, seq_mask=self.non_pad_mask, atte_mask_out=self.slf_attn_mask) | |||||
enc_input_list += [enc_input] | |||||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | |||||
self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1) | |||||
self.dec_output_state = self.dec_output_state.sum(0) | |||||
p_sent = self.wh(self.dec_output_state) # [batch, N, 2] | |||||
idx = None | |||||
if self._hps.m == 0: | |||||
prediction = p_sent.view(-1, 2).max(1)[1] | |||||
prediction = prediction.view(batch_size, -1) | |||||
else: | |||||
mask_output = torch.exp(p_sent[:, :, 1]) # # [batch, N] | |||||
mask_output = mask_output * input_len.float() | |||||
topk, idx = torch.topk(mask_output, self._hps.m) | |||||
prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1) | |||||
prediction = prediction.long().view(batch_size, -1) | |||||
if self._hps.cuda: | |||||
prediction = prediction.cuda() | |||||
# print((p_sent.size(), prediction.size(), idx.size())) | |||||
# [batch, N, 2], [batch, N], [batch, hps.m] | |||||
return {"pred": p_sent, "prediction": prediction, "pred_idx": idx} | |||||
@@ -0,0 +1,24 @@ | |||||
import unittest | |||||
from ..data.dataloader import SummarizationLoader | |||||
class TestSummarizationLoader(unittest.TestCase): | |||||
def test_case1(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train":"testdata/train.jsonl", "valid":"testdata/val.jsonl", "test":"testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths) | |||||
print(data.datasets) | |||||
def test_case2(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, domain=True) | |||||
print(data.datasets, data.vocabs) | |||||
def test_case3(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, tag=True) | |||||
print(data.datasets, data.vocabs) |
@@ -0,0 +1,36 @@ | |||||
import unittest | |||||
import sys | |||||
sys.path.append('..') | |||||
from data.dataloader import SummarizationLoader | |||||
vocab_size = 100000 | |||||
vocab_path = "testdata/vocab" | |||||
sent_max_len = 100 | |||||
doc_max_timesteps = 50 | |||||
class TestSummarizationLoader(unittest.TestCase): | |||||
def test_case1(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train":"testdata/train.jsonl", "valid":"testdata/val.jsonl", "test":"testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps) | |||||
print(data.datasets) | |||||
def test_case2(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, domain=True) | |||||
print(data.datasets, data.vocabs) | |||||
def test_case3(self): | |||||
sum_loader = SummarizationLoader() | |||||
paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"} | |||||
data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, tag=True) | |||||
print(data.datasets, data.vocabs) | |||||
@@ -0,0 +1,56 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import os | |||||
import sys | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/') | |||||
from fastNLP.core.const import Const | |||||
from data.dataloader import SummarizationLoader | |||||
from tools.data import ExampleSet, Vocab | |||||
vocab_size = 100000 | |||||
vocab_path = "test/testdata/vocab" | |||||
sent_max_len = 100 | |||||
doc_max_timesteps = 50 | |||||
# paths = {"train": "test/testdata/train.jsonl", "valid": "test/testdata/val.jsonl"} | |||||
paths = {"train": "/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl", "valid": "/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl"} | |||||
sum_loader = SummarizationLoader() | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, load_vocab_file=True) | |||||
trainset = dataInfo.datasets["train"] | |||||
vocab = Vocab(vocab_path, vocab_size) | |||||
dataset = ExampleSet(paths["train"], vocab, doc_max_timesteps, sent_max_len) | |||||
# print(trainset[0]["text"]) | |||||
# print(dataset.get_example(0).original_article_sents) | |||||
# print(trainset[0]["words"]) | |||||
# print(dataset[0][0].numpy().tolist()) | |||||
b_size = len(trainset) | |||||
for i in range(b_size): | |||||
if i <= 7327: | |||||
continue | |||||
print(trainset[i][Const.INPUT]) | |||||
print(dataset[i][0].numpy().tolist()) | |||||
assert trainset[i][Const.INPUT] == dataset[i][0].numpy().tolist(), i | |||||
assert trainset[i][Const.INPUT_LEN] == dataset[i][2].numpy().tolist(), i | |||||
assert trainset[i][Const.TARGET] == dataset[i][1].numpy().tolist(), i |
@@ -0,0 +1,135 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import os | |||||
import sys | |||||
import time | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.io.model_io import ModelSaver | |||||
from fastNLP.core.callback import Callback, EarlyStopError | |||||
from tools.logger import * | |||||
class TrainCallback(Callback): | |||||
def __init__(self, hps, patience=3, quit_all=True): | |||||
super().__init__() | |||||
self._hps = hps | |||||
self.patience = patience | |||||
self.wait = 0 | |||||
if type(quit_all) != bool: | |||||
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") | |||||
self.quit_all = quit_all | |||||
def on_epoch_begin(self): | |||||
self.epoch_start_time = time.time() | |||||
# def on_loss_begin(self, batch_y, predict_y): | |||||
# """ | |||||
# | |||||
# :param batch_y: dict | |||||
# input_len: [batch, N] | |||||
# :param predict_y: dict | |||||
# p_sent: [batch, N, 2] | |||||
# :return: | |||||
# """ | |||||
# input_len = batch_y[Const.INPUT_LEN] | |||||
# batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100) | |||||
# # predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1) | |||||
# # logger.debug(predict_y["p_sent"][0:5,:,:]) | |||||
def on_backward_begin(self, loss): | |||||
""" | |||||
:param loss: [] | |||||
:return: | |||||
""" | |||||
if not (np.isfinite(loss.data)).numpy(): | |||||
logger.error("train Loss is not finite. Stopping.") | |||||
logger.info(loss) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.info(name) | |||||
logger.info(param.grad.data.sum()) | |||||
raise Exception("train Loss is not finite. Stopping.") | |||||
def on_backward_end(self): | |||||
if self._hps.grad_clip: | |||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm) | |||||
def on_epoch_end(self): | |||||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | ' | |||||
.format(self.epoch, (time.time() - self.epoch_start_time))) | |||||
def on_valid_begin(self): | |||||
self.valid_start_time = time.time() | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
logger.info(' | end of valid {:3d} | time: {:5.2f}s | ' | |||||
.format(self.epoch, (time.time() - self.valid_start_time))) | |||||
# early stop | |||||
if not is_better_eval: | |||||
if self.wait == self.patience: | |||||
train_dir = os.path.join(self._hps.save_root, "train") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(self.model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
raise EarlyStopError("Early stopping raised.") | |||||
else: | |||||
self.wait += 1 | |||||
else: | |||||
self.wait = 0 | |||||
# lr descent | |||||
if self._hps.lr_descent: | |||||
new_lr = max(5e-6, self._hps.lr / (self.epoch + 1)) | |||||
for param_group in list(optimizer.param_groups): | |||||
param_group['lr'] = new_lr | |||||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||||
def on_exception(self, exception): | |||||
if isinstance(exception, KeyboardInterrupt): | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
train_dir = os.path.join(self._hps.save_root, "train") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(self.model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
if self.quit_all is True: | |||||
sys.exit(0) # 直接退出程序 | |||||
else: | |||||
pass | |||||
else: | |||||
raise exception # 抛出陌生Error | |||||
@@ -0,0 +1,562 @@ | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from torch.autograd import * | |||||
import torch.nn.init as init | |||||
import data | |||||
from tools.logger import * | |||||
from transformer.Models import get_sinusoid_encoding_table | |||||
class Encoder(nn.Module): | |||||
def __init__(self, hps, vocab): | |||||
super(Encoder, self).__init__() | |||||
self._hps = hps | |||||
self._vocab = vocab | |||||
self.sent_max_len = hps.sent_max_len | |||||
vocab_size = len(vocab) | |||||
logger.info("[INFO] Vocabulary size is %d", vocab_size) | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=vocab.word2id('[PAD]')) | |||||
if hps.word_embedding: | |||||
word2vec = data.Word_Embedding(hps.embedding_path, vocab) | |||||
word_vecs = word2vec.load_my_vecs(embed_size) | |||||
# pretrained_weight = word2vec.add_unknown_words_by_zero(word_vecs, embed_size) | |||||
pretrained_weight = word2vec.add_unknown_words_by_avg(word_vecs, embed_size) | |||||
pretrained_weight = np.array(pretrained_weight) | |||||
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight)) | |||||
self.embed.weight.requires_grad = hps.embed_train | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len] | |||||
vocab = self._vocab | |||||
batch_size, N, _ = input.size() | |||||
input = input.view(-1, input.size(2)) # [batch_size*N, L] | |||||
input_sent_len = ((input!=vocab.word2id('[PAD]')).sum(dim=1)).int() # [batch_size*N, 1] | |||||
enc_embed_input = self.embed(input) # [batch_size*N, L, D] | |||||
input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class DomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(DomainEncoder, self).__init__(hps, vocab) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
sent_embedding = super().forward(input) | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class MultiDomainEncoder(Encoder): | |||||
def __init__(self, hps, vocab, domaindict): | |||||
super(MultiDomainEncoder, self).__init__(hps, vocab) | |||||
self.domain_size = domaindict.size() | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, input, domain): | |||||
""" | |||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number | |||||
:param domain: [batch_size, domain_size] | |||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes] | |||||
""" | |||||
batch_size, N, _ = input.size() | |||||
# logger.info(domain[:5, :]) | |||||
sent_embedding = super().forward(input) | |||||
domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1) | |||||
domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size] | |||||
enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D] | |||||
enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D] | |||||
# logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D] | |||||
enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, hps): | |||||
super(BertEncoder, self).__init__() | |||||
from pytorch_pretrained_bert.modeling import BertModel | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
self._cuda = hps.cuda | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# word embedding | |||||
self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16") | |||||
self._bert.eval() | |||||
for p in self._bert.parameters(): | |||||
p.requires_grad = False | |||||
self.word_embedding_proj = nn.Linear(4096, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
""" | |||||
# Use Bert to get word embedding | |||||
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
for i in range(batch_size): | |||||
tokens_id = inputs[i] | |||||
input_mask = input_masks[i] | |||||
sent_len = enc_sent_len[i] | |||||
input_ids = tokens_id.unsqueeze(0) | |||||
input_mask = input_mask.unsqueeze(0) | |||||
out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask) | |||||
out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 1 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 511: | |||||
enc_sent_input.append(out[last_end: min(511, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D] | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class BertTagEncoder(BertEncoder): | |||||
def __init__(self, hps, domaindict): | |||||
super(BertTagEncoder, self).__init__(hps) | |||||
# domain embedding | |||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim) | |||||
self.domain_embedding.weight.requires_grad = True | |||||
def forward(self, inputs, input_masks, enc_sent_len, domain): | |||||
sent_embedding = super().forward(inputs, input_masks, enc_sent_len) | |||||
batch_size, N = enc_sent_len.size() | |||||
enc_domain_input = self.domain_embedding(domain) # [batch, D] | |||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D] | |||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2) | |||||
return sent_embedding | |||||
class ELMoEndoer(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer, self).__init__() | |||||
self._hps = hps | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def forward(self, input): | |||||
# input: a batch of Example object [batch_size, N, seq_len, character_len] | |||||
batch_size, N, seq_len, _ = input.size() | |||||
input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len] | |||||
input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1] | |||||
logger.debug(input_sent_len.view(batch_size, -1)) | |||||
enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D] | |||||
enc_embed_input = self.embed_proj(enc_embed_input) | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in input_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding | |||||
class ELMoEndoer2(nn.Module): | |||||
def __init__(self, hps): | |||||
super(ELMoEndoer2, self).__init__() | |||||
self._hps = hps | |||||
self._cuda = hps.cuda | |||||
self.sent_max_len = hps.sent_max_len | |||||
from allennlp.modules.elmo import Elmo | |||||
elmo_dim = 1024 | |||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" | |||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" | |||||
# elmo_dim = 512 | |||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json" | |||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" | |||||
embed_size = hps.word_emb_dim | |||||
sent_max_len = hps.sent_max_len | |||||
input_channels = 1 | |||||
out_channels = hps.output_channel | |||||
min_kernel_size = hps.min_kernel_size | |||||
max_kernel_size = hps.max_kernel_size | |||||
width = embed_size | |||||
# elmo embedding | |||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0) | |||||
self.embed_proj = nn.Linear(elmo_dim, embed_size) | |||||
# position embedding | |||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True) | |||||
# cnn | |||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)]) | |||||
logger.info("[INFO] Initing W for CNN.......") | |||||
for conv in self.convs: | |||||
init_weight_value = 6.0 | |||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value)) | |||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data) | |||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out)) | |||||
def calculate_fan_in_and_fan_out(tensor): | |||||
dimensions = tensor.ndimension() | |||||
if dimensions < 2: | |||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions") | |||||
if dimensions == 2: # Linear | |||||
fan_in = tensor.size(1) | |||||
fan_out = tensor.size(0) | |||||
else: | |||||
num_input_fmaps = tensor.size(1) | |||||
num_output_fmaps = tensor.size(0) | |||||
receptive_field_size = 1 | |||||
if tensor.dim() > 2: | |||||
receptive_field_size = tensor[0][0].numel() | |||||
fan_in = num_input_fmaps * receptive_field_size | |||||
fan_out = num_output_fmaps * receptive_field_size | |||||
return fan_in, fan_out | |||||
def pad_encoder_input(self, input_list): | |||||
""" | |||||
:param input_list: N [seq_len, hidden_state] | |||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state] | |||||
""" | |||||
max_len = self.sent_max_len | |||||
enc_sent_input_pad = [] | |||||
_, hidden_size = input_list[0].size() | |||||
for i in range(len(input_list)): | |||||
article_words = input_list[i] # [seq_len, hidden_size] | |||||
seq_len = article_words.size(0) | |||||
if seq_len > max_len: | |||||
pad_words = article_words[:max_len, :] | |||||
else: | |||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size) | |||||
pad_words = torch.cat([article_words, pad_tensor], dim=0) | |||||
enc_sent_input_pad.append(pad_words) | |||||
return enc_sent_input_pad | |||||
def forward(self, inputs, input_masks, enc_sent_len): | |||||
""" | |||||
:param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50] | |||||
:param input_masks: 0 or 1, [batch, doc_len=512] | |||||
:param enc_sent_len: sentence original length [batch, N] | |||||
:return: | |||||
sent_embedding: [batch, N, D] | |||||
""" | |||||
# Use Bert to get word embedding | |||||
batch_size, N = enc_sent_len.size() | |||||
input_pad_list = [] | |||||
elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D] | |||||
elmo_output = elmo_output * input_masks.unsqueeze(-1).float() | |||||
# print("END elmo") | |||||
for i in range(batch_size): | |||||
sent_len = enc_sent_len[i] # [1, N] | |||||
out = elmo_output[i] | |||||
_, hidden_size = out.size() | |||||
# restore the sentence | |||||
last_end = 0 | |||||
enc_sent_input = [] | |||||
for length in sent_len: | |||||
if length != 0 and last_end < 512: | |||||
enc_sent_input.append(out[last_end : min(512, last_end + length), :]) | |||||
last_end += length | |||||
else: | |||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size) | |||||
enc_sent_input.append(pad_tensor) | |||||
# pad the sentence | |||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096] | |||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state] | |||||
input_pad = torch.stack(input_pad_list) | |||||
input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1) | |||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N] | |||||
enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D] | |||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len]) | |||||
sent_pos_list = [] | |||||
for sentlen in enc_sent_len: | |||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1)) | |||||
for k in range(self.sent_max_len - sentlen): | |||||
sent_pos.append(0) | |||||
sent_pos_list.append(sent_pos) | |||||
input_pos = torch.Tensor(sent_pos_list).long() | |||||
if self._hps.cuda: | |||||
input_pos = input_pos.cuda() | |||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D] | |||||
enc_conv_input = enc_embed_input + enc_pos_embed_input | |||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D) | |||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W) | |||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co) | |||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes) | |||||
sent_embedding = sent_embedding.view(batch_size, N, -1) | |||||
return sent_embedding |
@@ -0,0 +1,41 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
import torch | |||||
import numpy as np | |||||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||||
''' Sinusoid position encoding table ''' | |||||
def cal_angle(position, hid_idx): | |||||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||||
def get_posi_angle_vec(position): | |||||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||||
if padding_idx is not None: | |||||
# zero vector for padding dimension | |||||
sinusoid_table[padding_idx] = 0. | |||||
return torch.FloatTensor(sinusoid_table) |
@@ -0,0 +1 @@ | |||||
@@ -0,0 +1,479 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it""" | |||||
import os | |||||
import re | |||||
import glob | |||||
import copy | |||||
import random | |||||
import json | |||||
import collections | |||||
from itertools import combinations | |||||
import numpy as np | |||||
from random import shuffle | |||||
import torch.utils.data | |||||
import time | |||||
import pickle | |||||
from nltk.tokenize import sent_tokenize | |||||
import utils | |||||
from logger import * | |||||
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. | |||||
SENTENCE_START = '<s>' | |||||
SENTENCE_END = '</s>' | |||||
PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence | |||||
UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words | |||||
START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence | |||||
STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences | |||||
# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file. | |||||
class Vocab(object): | |||||
"""Vocabulary class for mapping between words and ids (integers)""" | |||||
def __init__(self, vocab_file, max_size): | |||||
""" | |||||
Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file. | |||||
:param vocab_file: string; path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though. | |||||
:param max_size: int; The maximum size of the resulting Vocabulary. | |||||
""" | |||||
self._word_to_id = {} | |||||
self._id_to_word = {} | |||||
self._count = 0 # keeps track of total number of words in the Vocab | |||||
# [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3. | |||||
for w in [PAD_TOKEN, UNKNOWN_TOKEN, START_DECODING, STOP_DECODING]: | |||||
self._word_to_id[w] = self._count | |||||
self._id_to_word[self._count] = w | |||||
self._count += 1 | |||||
# Read the vocab file and add words up to max_size | |||||
with open(vocab_file, 'r', encoding='utf8') as vocab_f: #New : add the utf8 encoding to prevent error | |||||
cnt = 0 | |||||
for line in vocab_f: | |||||
cnt += 1 | |||||
pieces = line.split("\t") | |||||
# pieces = line.split() | |||||
w = pieces[0] | |||||
# print(w) | |||||
if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]: | |||||
raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w) | |||||
if w in self._word_to_id: | |||||
logger.error('Duplicated word in vocabulary file Line %d : %s' % (cnt, w)) | |||||
continue | |||||
self._word_to_id[w] = self._count | |||||
self._id_to_word[self._count] = w | |||||
self._count += 1 | |||||
if max_size != 0 and self._count >= max_size: | |||||
logger.info("[INFO] max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)) | |||||
break | |||||
logger.info("[INFO] Finished constructing vocabulary of %i total words. Last word added: %s", self._count, self._id_to_word[self._count-1]) | |||||
def word2id(self, word): | |||||
"""Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV.""" | |||||
if word not in self._word_to_id: | |||||
return self._word_to_id[UNKNOWN_TOKEN] | |||||
return self._word_to_id[word] | |||||
def id2word(self, word_id): | |||||
"""Returns the word (string) corresponding to an id (integer).""" | |||||
if word_id not in self._id_to_word: | |||||
raise ValueError('Id not found in vocab: %d' % word_id) | |||||
return self._id_to_word[word_id] | |||||
def size(self): | |||||
"""Returns the total size of the vocabulary""" | |||||
return self._count | |||||
def word_list(self): | |||||
"""Return the word list of the vocabulary""" | |||||
return self._word_to_id.keys() | |||||
class Word_Embedding(object): | |||||
def __init__(self, path, vocab): | |||||
""" | |||||
:param path: string; the path of word embedding | |||||
:param vocab: object; | |||||
""" | |||||
logger.info("[INFO] Loading external word embedding...") | |||||
self._path = path | |||||
self._vocablist = vocab.word_list() | |||||
self._vocab = vocab | |||||
def load_my_vecs(self, k=200): | |||||
"""Load word embedding""" | |||||
word_vecs = {} | |||||
with open(self._path, encoding="utf-8") as f: | |||||
count = 0 | |||||
lines = f.readlines()[1:] | |||||
for line in lines: | |||||
values = line.split(" ") | |||||
word = values[0] | |||||
count += 1 | |||||
if word in self._vocablist: # whether to judge if in vocab | |||||
vector = [] | |||||
for count, val in enumerate(values): | |||||
if count == 0: | |||||
continue | |||||
if count <= k: | |||||
vector.append(float(val)) | |||||
word_vecs[word] = vector | |||||
return word_vecs | |||||
def add_unknown_words_by_zero(self, word_vecs, k=200): | |||||
"""Solve unknown by zeros""" | |||||
zero = [0.0] * k | |||||
list_word2vec = [] | |||||
oov = 0 | |||||
iov = 0 | |||||
for i in range(self._vocab.size()): | |||||
word = self._vocab.id2word(i) | |||||
if word not in word_vecs: | |||||
oov += 1 | |||||
word_vecs[word] = zero | |||||
list_word2vec.append(word_vecs[word]) | |||||
else: | |||||
iov += 1 | |||||
list_word2vec.append(word_vecs[word]) | |||||
logger.info("[INFO] oov count %d, iov count %d", oov, iov) | |||||
return list_word2vec | |||||
def add_unknown_words_by_avg(self, word_vecs, k=200): | |||||
"""Solve unknown by avg word embedding""" | |||||
# solve unknown words inplaced by zero list | |||||
word_vecs_numpy = [] | |||||
for word in self._vocablist: | |||||
if word in word_vecs: | |||||
word_vecs_numpy.append(word_vecs[word]) | |||||
col = [] | |||||
for i in range(k): | |||||
sum = 0.0 | |||||
for j in range(int(len(word_vecs_numpy))): | |||||
sum += word_vecs_numpy[j][i] | |||||
sum = round(sum, 6) | |||||
col.append(sum) | |||||
zero = [] | |||||
for m in range(k): | |||||
avg = col[m] / int(len(word_vecs_numpy)) | |||||
avg = round(avg, 6) | |||||
zero.append(float(avg)) | |||||
list_word2vec = [] | |||||
oov = 0 | |||||
iov = 0 | |||||
for i in range(self._vocab.size()): | |||||
word = self._vocab.id2word(i) | |||||
if word not in word_vecs: | |||||
oov += 1 | |||||
word_vecs[word] = zero | |||||
list_word2vec.append(word_vecs[word]) | |||||
else: | |||||
iov += 1 | |||||
list_word2vec.append(word_vecs[word]) | |||||
logger.info("[INFO] External Word Embedding iov count: %d, oov count: %d", iov, oov) | |||||
return list_word2vec | |||||
def add_unknown_words_by_uniform(self, word_vecs, uniform=0.25, k=200): | |||||
"""Solve unknown word by uniform(-0.25,0.25)""" | |||||
list_word2vec = [] | |||||
oov = 0 | |||||
iov = 0 | |||||
for i in range(self._vocab.size()): | |||||
word = self._vocab.id2word(i) | |||||
if word not in word_vecs: | |||||
oov += 1 | |||||
word_vecs[word] = np.random.uniform(-1 * uniform, uniform, k).round(6).tolist() | |||||
list_word2vec.append(word_vecs[word]) | |||||
else: | |||||
iov += 1 | |||||
list_word2vec.append(word_vecs[word]) | |||||
logger.info("[INFO] oov count %d, iov count %d", oov, iov) | |||||
return list_word2vec | |||||
# load word embedding | |||||
def load_my_vecs_freq1(self, freqs, pro): | |||||
word_vecs = {} | |||||
with open(self._path, encoding="utf-8") as f: | |||||
freq = 0 | |||||
lines = f.readlines()[1:] | |||||
for line in lines: | |||||
values = line.split(" ") | |||||
word = values[0] | |||||
if word in self._vocablist: # whehter to judge if in vocab | |||||
if freqs[word] == 1: | |||||
a = np.random.uniform(0, 1, 1).round(2) | |||||
if pro < a: | |||||
continue | |||||
vector = [] | |||||
for count, val in enumerate(values): | |||||
if count == 0: | |||||
continue | |||||
vector.append(float(val)) | |||||
word_vecs[word] = vector | |||||
return word_vecs | |||||
class DomainDict(object): | |||||
"""Domain embedding for Newsroom""" | |||||
def __init__(self, path): | |||||
self.domain_list = self.readDomainlist(path) | |||||
# self.domain_list = ["foxnews.com", "cnn.com", "mashable.com", "nytimes.com", "washingtonpost.com"] | |||||
self.domain_number = len(self.domain_list) | |||||
self._domain_to_id = {} | |||||
self._id_to_domain = {} | |||||
self._cnt = 0 | |||||
self._domain_to_id["X"] = self._cnt | |||||
self._id_to_domain[self._cnt] = "X" | |||||
self._cnt += 1 | |||||
for i in range(self.domain_number): | |||||
domain = self.domain_list[i] | |||||
self._domain_to_id[domain] = self._cnt | |||||
self._id_to_domain[self._cnt] = domain | |||||
self._cnt += 1 | |||||
def readDomainlist(self, path): | |||||
domain_list = [] | |||||
with open(path) as f: | |||||
for line in f: | |||||
domain_list.append(line.split("\t")[0].strip()) | |||||
logger.info(domain_list) | |||||
return domain_list | |||||
def domain2id(self, domain): | |||||
""" Returns the id (integer) of a domain (string). Returns "X" for unknow domain. | |||||
:param domain: string | |||||
:return: id; int | |||||
""" | |||||
if domain in self.domain_list: | |||||
return self._domain_to_id[domain] | |||||
else: | |||||
logger.info(domain) | |||||
return self._domain_to_id["X"] | |||||
def id2domain(self, domain_id): | |||||
""" Returns the domain (string) corresponding to an id (integer). | |||||
:param id: int; | |||||
:return: domain: string | |||||
""" | |||||
if domain_id not in self._id_to_domain: | |||||
raise ValueError('Id not found in DomainDict: %d' % domain_id) | |||||
return self._id_to_domain[id] | |||||
def size(self): | |||||
return self._cnt | |||||
class Example(object): | |||||
"""Class representing a train/val/test example for text summarization.""" | |||||
def __init__(self, article_sents, abstract_sents, vocab, sent_max_len, label, domainid=None): | |||||
""" Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self. | |||||
:param article_sents: list of strings; one per article sentence. each token is separated by a single space. | |||||
:param abstract_sents: list of strings; one per abstract sentence. In each sentence, each token is separated by a single space. | |||||
:param domainid: int; publication of the example | |||||
:param vocab: Vocabulary object | |||||
:param sent_max_len: int; the maximum length of each sentence, padding all sentences to this length | |||||
:param label: list of int; the index of selected sentences | |||||
""" | |||||
self.sent_max_len = sent_max_len | |||||
self.enc_sent_len = [] | |||||
self.enc_sent_input = [] | |||||
self.enc_sent_input_pad = [] | |||||
# origin_cnt = len(article_sents) | |||||
# article_sents = [re.sub(r"\n+\t+", " ", sent) for sent in article_sents] | |||||
# assert origin_cnt == len(article_sents) | |||||
# Process the article | |||||
for sent in article_sents: | |||||
article_words = sent.split() | |||||
self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding | |||||
self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
self._pad_encoder_input(vocab.word2id('[PAD]')) | |||||
# Store the original strings | |||||
self.original_article = " ".join(article_sents) | |||||
self.original_article_sents = article_sents | |||||
if isinstance(abstract_sents[0], list): | |||||
logger.debug("[INFO] Multi Reference summaries!") | |||||
self.original_abstract_sents = [] | |||||
self.original_abstract = [] | |||||
for summary in abstract_sents: | |||||
self.original_abstract_sents.append([sent.strip() for sent in summary]) | |||||
self.original_abstract.append("\n".join([sent.replace("\n", "") for sent in summary])) | |||||
else: | |||||
self.original_abstract_sents = [sent.replace("\n", "") for sent in abstract_sents] | |||||
self.original_abstract = "\n".join(self.original_abstract_sents) | |||||
# Store the label | |||||
self.label = np.zeros(len(article_sents), dtype=int) | |||||
if label != []: | |||||
self.label[np.array(label)] = 1 | |||||
self.label = list(self.label) | |||||
# Store the publication | |||||
if domainid != None: | |||||
if domainid == 0: | |||||
logger.debug("domain id = 0!") | |||||
self.domain = domainid | |||||
def _pad_encoder_input(self, pad_id): | |||||
""" | |||||
:param pad_id: int; token pad id | |||||
:return: | |||||
""" | |||||
max_len = self.sent_max_len | |||||
for i in range(len(self.enc_sent_input)): | |||||
article_words = self.enc_sent_input[i] | |||||
if len(article_words) > max_len: | |||||
article_words = article_words[:max_len] | |||||
while len(article_words) < max_len: | |||||
article_words.append(pad_id) | |||||
self.enc_sent_input_pad.append(article_words) | |||||
class ExampleSet(torch.utils.data.Dataset): | |||||
""" Constructor: Dataset of example(object) """ | |||||
def __init__(self, data_path, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False): | |||||
""" Initializes the ExampleSet with the path of data | |||||
:param data_path: string; the path of data | |||||
:param vocab: object; | |||||
:param doc_max_timesteps: int; the maximum sentence number of a document, each example should pad sentences to this length | |||||
:param sent_max_len: int; the maximum token number of a sentence, each sentence should pad tokens to this length | |||||
:param domaindict: object; the domain dict to embed domain | |||||
""" | |||||
self.domaindict = domaindict | |||||
if domaindict: | |||||
logger.info("[INFO] Use domain information in the dateset!") | |||||
if randomX==True: | |||||
logger.info("[INFO] Random some example to unknow domain X!") | |||||
self.randomP = 0.1 | |||||
logger.info("[INFO] Start reading ExampleSet") | |||||
start = time.time() | |||||
self.example_list = [] | |||||
self.doc_max_timesteps = doc_max_timesteps | |||||
cnt = 0 | |||||
with open(data_path, 'r') as reader: | |||||
for line in reader: | |||||
try: | |||||
e = json.loads(line) | |||||
article_sent = e['text'] | |||||
tag = e["tag"][0] if usetag else e['publication'] | |||||
# logger.info(tag) | |||||
if "duc" in data_path: | |||||
abstract_sent = e["summaryList"] if "summaryList" in e.keys() else [e['summary']] | |||||
else: | |||||
abstract_sent = e['summary'] | |||||
if domaindict: | |||||
if randomX == True: | |||||
p = np.random.rand() | |||||
if p <= self.randomP: | |||||
domainid = domaindict.domain2id("X") | |||||
else: | |||||
domainid = domaindict.domain2id(tag) | |||||
else: | |||||
domainid = domaindict.domain2id(tag) | |||||
else: | |||||
domainid = None | |||||
logger.debug((tag, domainid)) | |||||
except (ValueError,EOFError) as e : | |||||
logger.debug(e) | |||||
break | |||||
else: | |||||
example = Example(article_sent, abstract_sent, vocab, sent_max_len, e["label"], domainid) # Process into an Example. | |||||
self.example_list.append(example) | |||||
cnt += 1 | |||||
# print(cnt) | |||||
logger.info("[INFO] Finish reading ExampleSet. Total time is %f, Total size is %d", time.time() - start, len(self.example_list)) | |||||
self.size = len(self.example_list) | |||||
# self.example_list.sort(key=lambda ex: ex.domain) | |||||
def get_example(self, index): | |||||
return self.example_list[index] | |||||
def __getitem__(self, index): | |||||
""" | |||||
:param index: int; the index of the example | |||||
:return | |||||
input_pad: [N, seq_len] | |||||
label: [N] | |||||
input_mask: [N] | |||||
domain: [1] | |||||
""" | |||||
item = self.example_list[index] | |||||
input = np.array(item.enc_sent_input_pad) | |||||
label = np.array(item.label, dtype=int) | |||||
# pad input to doc_max_timesteps | |||||
if len(input) < self.doc_max_timesteps: | |||||
pad_number = self.doc_max_timesteps - len(input) | |||||
pad_matrix = np.zeros((pad_number, len(input[0]))) | |||||
input_pad = np.vstack((input, pad_matrix)) | |||||
label = np.append(label, np.zeros(pad_number, dtype=int)) | |||||
input_mask = np.append(np.ones(len(input)), np.zeros(pad_number)) | |||||
else: | |||||
input_pad = input[:self.doc_max_timesteps] | |||||
label = label[:self.doc_max_timesteps] | |||||
input_mask = np.ones(self.doc_max_timesteps) | |||||
if self.domaindict: | |||||
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long(), item.domain | |||||
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long() | |||||
def __len__(self): | |||||
return self.size | |||||
class MultiExampleSet(): | |||||
def __init__(self, data_dir, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False): | |||||
self.datasets = [None] * (domaindict.size() - 1) | |||||
data_path_list = [os.path.join(data_dir, s) for s in os.listdir(data_dir) if s.endswith("label.jsonl")] | |||||
for data_path in data_path_list: | |||||
fname = data_path.split("/")[-1] # cnn.com.label.json | |||||
dataname = ".".join(fname.split(".")[:-2]) | |||||
domainid = domaindict.domain2id(dataname) | |||||
logger.info("[INFO] domain name: %s, domain id: %d" % (dataname, domainid)) | |||||
self.datasets[domainid - 1] = ExampleSet(data_path, vocab, doc_max_timesteps, sent_max_len, domaindict, randomX, usetag) | |||||
def get(self, id): | |||||
return self.datasets[id] | |||||
from torch.utils.data.dataloader import default_collate | |||||
def my_collate_fn(batch): | |||||
''' | |||||
:param batch: (input_pad, label, input_mask, domain) | |||||
:return: | |||||
''' | |||||
start_domain = batch[0][-1] | |||||
# for i in range(len(batch)): | |||||
# print(batch[i][-1], end=',') | |||||
batch = list(filter(lambda x: x[-1] == start_domain, batch)) | |||||
print("start_domain %d" % start_domain) | |||||
print("batch_len %d" % len(batch)) | |||||
if len(batch) == 0: return torch.Tensor() | |||||
return default_collate(batch) # 用默认方式拼接过滤后的batch数据 | |||||
@@ -0,0 +1,27 @@ | |||||
# -*- coding: utf-8 -*- | |||||
import logging | |||||
import sys | |||||
# 获取logger实例,如果参数为空则返回root logger | |||||
logger = logging.getLogger("Summarization logger") | |||||
# logger = logging.getLogger() | |||||
# 指定logger输出格式 | |||||
formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s') | |||||
# # 文件日志 | |||||
# file_handler = logging.FileHandler("test.log") | |||||
# file_handler.setFormatter(formatter) # 可以通过setFormatter指定输出格式 | |||||
# 控制台日志 | |||||
console_handler = logging.StreamHandler(sys.stdout) | |||||
console_handler.formatter = formatter # 也可以直接给formatter赋值 | |||||
console_handler.setLevel(logging.INFO) | |||||
# 为logger添加的日志处理器 | |||||
# logger.addHandler(file_handler) | |||||
logger.addHandler(console_handler) | |||||
# 指定日志的最低输出级别,默认为WARN级别 | |||||
logger.setLevel(logging.DEBUG) |
@@ -0,0 +1,297 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
import re | |||||
import os | |||||
import shutil | |||||
import copy | |||||
import datetime | |||||
import numpy as np | |||||
from rouge import Rouge | |||||
from .logger import * | |||||
# from data import * | |||||
import sys | |||||
sys.setrecursionlimit(10000) | |||||
REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", | |||||
"-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} | |||||
def clean(x): | |||||
return re.sub( | |||||
r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", | |||||
lambda m: REMAP.get(m.group()), x) | |||||
def rouge_eval(hyps, refer): | |||||
rouge = Rouge() | |||||
# print(hyps) | |||||
# print(refer) | |||||
# print(rouge.get_scores(hyps, refer)) | |||||
try: | |||||
score = rouge.get_scores(hyps, refer)[0] | |||||
mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]]) | |||||
except: | |||||
mean_score = 0.0 | |||||
return mean_score | |||||
def rouge_all(hyps, refer): | |||||
rouge = Rouge() | |||||
score = rouge.get_scores(hyps, refer)[0] | |||||
# mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]]) | |||||
return score | |||||
def eval_label(match_true, pred, true, total, match): | |||||
match_true, pred, true, match = match_true.float(), pred.float(), true.float(), match.float() | |||||
try: | |||||
accu = match / total | |||||
precision = match_true / pred | |||||
recall = match_true / true | |||||
F = 2 * precision * recall / (precision + recall) | |||||
except ZeroDivisionError: | |||||
F = 0.0 | |||||
logger.error("[Error] float division by zero") | |||||
return accu, precision, recall, F | |||||
def pyrouge_score(hyps, refer, remap = True): | |||||
from pyrouge import Rouge155 | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold') | |||||
MODEL_PATH = os.path.join(PYROUGE_ROOT,'system') | |||||
if os.path.exists(SYSTEM_PATH): | |||||
shutil.rmtree(SYSTEM_PATH) | |||||
os.makedirs(SYSTEM_PATH) | |||||
if os.path.exists(MODEL_PATH): | |||||
shutil.rmtree(MODEL_PATH) | |||||
os.makedirs(MODEL_PATH) | |||||
if remap == True: | |||||
refer = clean(refer) | |||||
hyps = clean(hyps) | |||||
system_file = os.path.join(SYSTEM_PATH, 'Reference.0.txt') | |||||
model_file = os.path.join(MODEL_PATH, 'Model.A.0.txt') | |||||
with open(system_file, 'wb') as f: | |||||
f.write(refer.encode('utf-8')) | |||||
with open(model_file, 'wb') as f: | |||||
f.write(hyps.encode('utf-8')) | |||||
r = Rouge155('/home/dqwang/ROUGE/RELEASE-1.5.5') | |||||
r.system_dir = SYSTEM_PATH | |||||
r.model_dir = MODEL_PATH | |||||
r.system_filename_pattern = 'Reference.(\d+).txt' | |||||
r.model_filename_pattern = 'Model.[A-Z].#ID#.txt' | |||||
output = r.convert_and_evaluate(rouge_args="-e /home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||||
output_dict = r.output_to_dict(output) | |||||
shutil.rmtree(PYROUGE_ROOT) | |||||
scores = {} | |||||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||||
return scores | |||||
def pyrouge_score_all(hyps_list, refer_list, remap = True): | |||||
from pyrouge import Rouge155 | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold') | |||||
MODEL_PATH = os.path.join(PYROUGE_ROOT,'system') | |||||
if os.path.exists(SYSTEM_PATH): | |||||
shutil.rmtree(SYSTEM_PATH) | |||||
os.makedirs(SYSTEM_PATH) | |||||
if os.path.exists(MODEL_PATH): | |||||
shutil.rmtree(MODEL_PATH) | |||||
os.makedirs(MODEL_PATH) | |||||
assert len(hyps_list) == len(refer_list) | |||||
for i in range(len(hyps_list)): | |||||
system_file = os.path.join(SYSTEM_PATH, 'Reference.%d.txt' % i) | |||||
model_file = os.path.join(MODEL_PATH, 'Model.A.%d.txt' % i) | |||||
refer = clean(refer_list[i]) if remap else refer_list[i] | |||||
hyps = clean(hyps_list[i]) if remap else hyps_list[i] | |||||
with open(system_file, 'wb') as f: | |||||
f.write(refer.encode('utf-8')) | |||||
with open(model_file, 'wb') as f: | |||||
f.write(hyps.encode('utf-8')) | |||||
r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5') | |||||
r.system_dir = SYSTEM_PATH | |||||
r.model_dir = MODEL_PATH | |||||
r.system_filename_pattern = 'Reference.(\d+).txt' | |||||
r.model_filename_pattern = 'Model.[A-Z].#ID#.txt' | |||||
output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||||
output_dict = r.output_to_dict(output) | |||||
shutil.rmtree(PYROUGE_ROOT) | |||||
scores = {} | |||||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||||
return scores | |||||
def pyrouge_score_all_multi(hyps_list, refer_list, remap = True): | |||||
from pyrouge import Rouge155 | |||||
nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime) | |||||
SYSTEM_PATH = os.path.join(PYROUGE_ROOT, 'system') | |||||
MODEL_PATH = os.path.join(PYROUGE_ROOT, 'gold') | |||||
if os.path.exists(SYSTEM_PATH): | |||||
shutil.rmtree(SYSTEM_PATH) | |||||
os.makedirs(SYSTEM_PATH) | |||||
if os.path.exists(MODEL_PATH): | |||||
shutil.rmtree(MODEL_PATH) | |||||
os.makedirs(MODEL_PATH) | |||||
assert len(hyps_list) == len(refer_list) | |||||
for i in range(len(hyps_list)): | |||||
system_file = os.path.join(SYSTEM_PATH, 'Model.%d.txt' % i) | |||||
# model_file = os.path.join(MODEL_PATH, 'Reference.A.%d.txt' % i) | |||||
hyps = clean(hyps_list[i]) if remap else hyps_list[i] | |||||
with open(system_file, 'wb') as f: | |||||
f.write(hyps.encode('utf-8')) | |||||
referType = ["A", "B", "C", "D", "E", "F", "G"] | |||||
for j in range(len(refer_list[i])): | |||||
model_file = os.path.join(MODEL_PATH, "Reference.%s.%d.txt" % (referType[j], i)) | |||||
refer = clean(refer_list[i][j]) if remap else refer_list[i][j] | |||||
with open(model_file, 'wb') as f: | |||||
f.write(refer.encode('utf-8')) | |||||
r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5') | |||||
r.system_dir = SYSTEM_PATH | |||||
r.model_dir = MODEL_PATH | |||||
r.system_filename_pattern = 'Model.(\d+).txt' | |||||
r.model_filename_pattern = 'Reference.[A-Z].#ID#.txt' | |||||
output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d") | |||||
output_dict = r.output_to_dict(output) | |||||
shutil.rmtree(PYROUGE_ROOT) | |||||
scores = {} | |||||
scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {} | |||||
scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score'] | |||||
scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score'] | |||||
scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score'] | |||||
return scores | |||||
def cal_label(article, abstract): | |||||
hyps_list = article | |||||
refer = abstract | |||||
scores = [] | |||||
for hyps in hyps_list: | |||||
mean_score = rouge_eval(hyps, refer) | |||||
scores.append(mean_score) | |||||
selected = [] | |||||
selected.append(int(np.argmax(scores))) | |||||
selected_sent_cnt = 1 | |||||
best_rouge = np.max(scores) | |||||
while selected_sent_cnt < len(hyps_list): | |||||
cur_max_rouge = 0.0 | |||||
cur_max_idx = -1 | |||||
for i in range(len(hyps_list)): | |||||
if i not in selected: | |||||
temp = copy.deepcopy(selected) | |||||
temp.append(i) | |||||
hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)]) | |||||
cur_rouge = rouge_eval(hyps, refer) | |||||
if cur_rouge > cur_max_rouge: | |||||
cur_max_rouge = cur_rouge | |||||
cur_max_idx = i | |||||
if cur_max_rouge != 0.0 and cur_max_rouge >= best_rouge: | |||||
selected.append(cur_max_idx) | |||||
selected_sent_cnt += 1 | |||||
best_rouge = cur_max_rouge | |||||
else: | |||||
break | |||||
# label = np.zeros(len(hyps_list), dtype=int) | |||||
# label[np.array(selected)] = 1 | |||||
# return list(label) | |||||
return selected | |||||
def cal_label_limited3(article, abstract): | |||||
hyps_list = article | |||||
refer = abstract | |||||
scores = [] | |||||
for hyps in hyps_list: | |||||
try: | |||||
mean_score = rouge_eval(hyps, refer) | |||||
scores.append(mean_score) | |||||
except ValueError: | |||||
scores.append(0.0) | |||||
selected = [] | |||||
selected.append(np.argmax(scores)) | |||||
selected_sent_cnt = 1 | |||||
best_rouge = np.max(scores) | |||||
while selected_sent_cnt < len(hyps_list) and selected_sent_cnt < 3: | |||||
cur_max_rouge = 0.0 | |||||
cur_max_idx = -1 | |||||
for i in range(len(hyps_list)): | |||||
if i not in selected: | |||||
temp = copy.deepcopy(selected) | |||||
temp.append(i) | |||||
hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)]) | |||||
cur_rouge = rouge_eval(hyps, refer) | |||||
if cur_rouge > cur_max_rouge: | |||||
cur_max_rouge = cur_rouge | |||||
cur_max_idx = i | |||||
selected.append(cur_max_idx) | |||||
selected_sent_cnt += 1 | |||||
best_rouge = cur_max_rouge | |||||
# logger.info(selected) | |||||
# label = np.zeros(len(hyps_list), dtype=int) | |||||
# label[np.array(selected)] = 1 | |||||
# return list(label) | |||||
return selected | |||||
import torch | |||||
def flip(x, dim): | |||||
xsize = x.size() | |||||
dim = x.dim() + dim if dim < 0 else dim | |||||
x = x.contiguous() | |||||
x = x.view(-1, *xsize[dim:]).contiguous() | |||||
x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, | |||||
-1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] | |||||
return x.view(xsize) | |||||
def get_attn_key_pad_mask(seq_k, seq_q): | |||||
''' For masking out the padding part of key sequence. ''' | |||||
# Expand to fit the shape of key query attention matrix. | |||||
len_q = seq_q.size(1) | |||||
padding_mask = seq_k.eq(0.0) | |||||
padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk | |||||
return padding_mask | |||||
def get_non_pad_mask(seq): | |||||
assert seq.dim() == 2 | |||||
return seq.ne(0.0).type(torch.float).unsqueeze(-1) |
@@ -0,0 +1,263 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""Train Model1: baseline model""" | |||||
import os | |||||
import sys | |||||
import json | |||||
import argparse | |||||
import datetime | |||||
import torch | |||||
import torch.nn | |||||
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.core.trainer import Trainer, Tester | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from tools.logger import * | |||||
from data.dataloader import SummarizationLoader | |||||
# from model.TransformerModel import TransformerModel | |||||
from model.TForiginal import TransformerModel | |||||
from model.Metric import LabelFMetric, FastRougeMetric, PyRougeMetric | |||||
from model.Loss import MyCrossEntropyLoss | |||||
from tools.Callback import TrainCallback | |||||
def setup_training(model, train_loader, valid_loader, hps): | |||||
"""Does setup before starting training (run_training)""" | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
if hps.restore_model != 'None': | |||||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||||
loader = ModelLoader() | |||||
loader.load_pytorch(model, bestmodel_file) | |||||
else: | |||||
logger.info("[INFO] Create new model for training...") | |||||
try: | |||||
run_training(model, train_loader, valid_loader, hps) # this is an infinite loop until interrupted | |||||
except KeyboardInterrupt: | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
def run_training(model, train_loader, valid_loader, hps): | |||||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||||
logger.info("[INFO] Starting run_training") | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||||
lr = hps.lr | |||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||||
criterion = MyCrossEntropyLoss(pred = "p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none') | |||||
# criterion = torch.nn.CrossEntropyLoss(reduce="none") | |||||
trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion, | |||||
n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")], | |||||
metric_key="f", validate_every=-1, save_path=eval_dir, | |||||
callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False) | |||||
train_info = trainer.train(load_best_model=True) | |||||
logger.info(' | end of Train | time: {:5.2f}s | '.format(train_info["seconds"])) | |||||
logger.info('[INFO] best eval model in epoch %d and iter %d', train_info["best_epoch"], train_info["best_step"]) | |||||
logger.info(train_info["best_eval"]) | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path) | |||||
def run_test(model, loader, hps, limited=False): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||||
eval_dir = os.path.join(hps.save_root, "eval") | |||||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||||
if not os.path.exists(eval_dir) : | |||||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||||
if hps.test_model == "evalbestmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
elif hps.test_model == "earlystop": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl') | |||||
else: | |||||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||||
modelloader = ModelLoader() | |||||
modelloader.load_pytorch(model, bestmodel_load_path) | |||||
if hps.use_pyrouge: | |||||
logger.info("[INFO] Use PyRougeMetric for testing") | |||||
tester = Tester(data=loader, model=model, | |||||
metrics=[LabelFMetric(pred="prediction"), PyRougeMetric(hps, pred="prediction")], | |||||
batch_size=hps.batch_size) | |||||
else: | |||||
logger.info("[INFO] Use FastRougeMetric for testing") | |||||
tester = Tester(data=loader, model=model, | |||||
metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")], | |||||
batch_size=hps.batch_size) | |||||
test_info = tester.test() | |||||
logger.info(test_info) | |||||
def main(): | |||||
parser = argparse.ArgumentParser(description='Summarization Model') | |||||
# Where to find data | |||||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||||
# Important settings | |||||
parser.add_argument('--mode', choices=['train', 'test'], default='train', help='must be one of train/test') | |||||
parser.add_argument('--embedding', type=str, default='glove', choices=['word2vec', 'glove', 'elmo', 'bert'], help='must be one of word2vec/glove/elmo/bert') | |||||
parser.add_argument('--sentence_encoder', type=str, default='transformer', choices=['bilstm', 'deeplstm', 'transformer'], help='must be one of LSTM/Transformer') | |||||
parser.add_argument('--sentence_decoder', type=str, default='SeqLab', choices=['PN', 'SeqLab'], help='must be one of PN/SeqLab') | |||||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||||
# Where to save output | |||||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||||
# Hyperparameters | |||||
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]') | |||||
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]') | |||||
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]') | |||||
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthnormal init for lstm [default: true]') | |||||
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||||
parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention') | |||||
# Training | |||||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||||
parser.add_argument('--max_grad_norm', type=float, default=10, help='for gradient clipping max gradient normalization') | |||||
# test | |||||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||||
args = parser.parse_args() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||||
torch.set_printoptions(threshold=50000) | |||||
# File paths | |||||
DATA_FILE = args.data_path | |||||
VALID_FILE = args.valid_path | |||||
VOCAL_FILE = args.vocab_path | |||||
LOG_PATH = args.log_root | |||||
# train_log setting | |||||
if not os.path.exists(LOG_PATH): | |||||
if args.mode == "train": | |||||
os.makedirs(LOG_PATH) | |||||
else: | |||||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
log_path = os.path.join(LOG_PATH, args.mode + "_" + nowTime) | |||||
file_handler = logging.FileHandler(log_path) | |||||
file_handler.setFormatter(formatter) | |||||
logger.addHandler(file_handler) | |||||
logger.info("Pytorch %s", torch.__version__) | |||||
sum_loader = SummarizationLoader() | |||||
hps = args | |||||
if hps.mode == 'test': | |||||
paths = {"test": DATA_FILE} | |||||
hps.recurrent_dropout_prob = 0.0 | |||||
hps.atten_dropout_prob = 0.0 | |||||
hps.ffn_dropout_prob = 0.0 | |||||
logger.info(hps) | |||||
else: | |||||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||||
if args.embedding == "glove": | |||||
vocab = dataInfo.vocabs["vocab"] | |||||
embed = torch.nn.Embedding(len(vocab), hps.word_emb_dim) | |||||
if hps.word_embedding: | |||||
embed_loader = EmbedLoader() | |||||
pretrained_weight = embed_loader.load_with_vocab(hps.embedding_path, vocab) # unfound with random init | |||||
embed.weight.data.copy_(torch.from_numpy(pretrained_weight)) | |||||
embed.weight.requires_grad = hps.embed_train | |||||
else: | |||||
logger.error("[ERROR] embedding To Be Continued!") | |||||
sys.exit(1) | |||||
if args.sentence_encoder == "transformer" and args.sentence_decoder == "SeqLab": | |||||
model_param = json.load(open("config/transformer.config", "rb")) | |||||
hps.__dict__.update(model_param) | |||||
model = TransformerModel(hps, embed) | |||||
else: | |||||
logger.error("[ERROR] Model To Be Continued!") | |||||
sys.exit(1) | |||||
logger.info(hps) | |||||
if hps.cuda: | |||||
model = model.cuda() | |||||
logger.info("[INFO] Use cuda") | |||||
if hps.mode == 'train': | |||||
dataInfo.datasets["valid"].set_target("text", "summary") | |||||
setup_training(model, dataInfo.datasets["train"], dataInfo.datasets["valid"], hps) | |||||
elif hps.mode == 'test': | |||||
logger.info("[INFO] Decoding...") | |||||
dataInfo.datasets["test"].set_target("text", "summary") | |||||
run_test(model, dataInfo.datasets["test"], hps, limited=hps.limited) | |||||
else: | |||||
logger.error("The 'mode' flag must be one of train/eval/test") | |||||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,706 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""Train Model1: baseline model""" | |||||
import os | |||||
import sys | |||||
import time | |||||
import copy | |||||
import pickle | |||||
import datetime | |||||
import argparse | |||||
import logging | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import Variable | |||||
from rouge import Rouge | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
from fastNLP.core.batch import DataSetIter | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.sampler import BucketSampler | |||||
from tools import utils | |||||
from tools.logger import * | |||||
from data.dataloader import SummarizationLoader | |||||
from model.TForiginal import TransformerModel | |||||
def setup_training(model, train_loader, valid_loader, hps): | |||||
"""Does setup before starting training (run_training)""" | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
if hps.restore_model != 'None': | |||||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||||
loader = ModelLoader() | |||||
loader.load_pytorch(model, bestmodel_file) | |||||
else: | |||||
logger.info("[INFO] Create new model for training...") | |||||
try: | |||||
run_training(model, train_loader, valid_loader, hps) # this is an infinite loop until interrupted | |||||
except KeyboardInterrupt: | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
def run_training(model, train_loader, valid_loader, hps): | |||||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||||
logger.info("[INFO] Starting run_training") | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
lr = hps.lr | |||||
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98), | |||||
# eps=1e-09) | |||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
best_train_loss = None | |||||
best_train_F= None | |||||
best_loss = None | |||||
best_F = None | |||||
step_num = 0 | |||||
non_descent_cnt = 0 | |||||
for epoch in range(1, hps.n_epochs + 1): | |||||
epoch_loss = 0.0 | |||||
train_loss = 0.0 | |||||
total_example_num = 0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
epoch_start_time = time.time() | |||||
for i, (batch_x, batch_y) in enumerate(train_loader): | |||||
# if i > 10: | |||||
# break | |||||
model.train() | |||||
iter_start_time=time.time() | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
# logger.info(batch_x["text"][0]) | |||||
# logger.info(input[0,:,:]) | |||||
# logger.info(input_len[0:5,:]) | |||||
# logger.info(batch_y["summary"][0:5]) | |||||
# logger.info(label[0:5,:]) | |||||
# logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0))) | |||||
batch_size, N, seq_len = input.size() | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
input = Variable(input) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
outputs = model_outputs["p_sent"].view(-1, 2) | |||||
label = label.view(-1) | |||||
loss = criterion(outputs, label) # [batch_size, doc_max_timesteps] | |||||
# input_len = input_len.float().view(-1) | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.masked_fill(input_len.eq(0), 0) | |||||
loss = loss.sum(1).mean() | |||||
logger.debug("loss %f", loss) | |||||
if not (np.isfinite(loss.data)).numpy(): | |||||
logger.error("train Loss is not finite. Stopping.") | |||||
logger.info(loss) | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.info(name) | |||||
logger.info(param.grad.data.sum()) | |||||
raise Exception("train Loss is not finite. Stopping.") | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
if hps.grad_clip: | |||||
torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm) | |||||
optimizer.step() | |||||
step_num += 1 | |||||
train_loss += float(loss.data) | |||||
epoch_loss += float(loss.data) | |||||
if i % 100 == 0: | |||||
# start debugger | |||||
# import pdb; pdb.set_trace() | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.debug(name) | |||||
logger.debug(param.grad.data.sum()) | |||||
logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | ' | |||||
.format(i, (time.time() - iter_start_time), | |||||
float(train_loss / 100))) | |||||
train_loss = 0.0 | |||||
# calculate the precision, recall and F | |||||
prediction = outputs.max(1)[1] | |||||
prediction = prediction.data | |||||
label = label.data | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += int(batch_size * N) | |||||
if hps.lr_descent: | |||||
# new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5), | |||||
# step_num * pow(hps.warmup_steps, -1.5)) | |||||
new_lr = max(5e-6, lr / (epoch + 1)) | |||||
for param_group in list(optimizer.param_groups): | |||||
param_group['lr'] = new_lr | |||||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||||
epoch_avg_loss = epoch_loss / len(train_loader) | |||||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | ' | |||||
.format(epoch, (time.time() - epoch_start_time), | |||||
float(epoch_avg_loss))) | |||||
logger.info("[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
if not best_train_loss or epoch_avg_loss < best_train_loss: | |||||
save_file = os.path.join(train_dir, "bestmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_loss = epoch_avg_loss | |||||
elif epoch_avg_loss > best_train_loss: | |||||
logger.error("[Error] training loss does not descent. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
if not best_train_F or F > best_train_F: | |||||
save_file = os.path.join(train_dir, "bestFmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f F score. Saving to %s', float(F), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_F = F | |||||
best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt) | |||||
if non_descent_cnt >= 3: | |||||
logger.error("[Error] val loss does not descent for three times. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
logger.info("[INFO] Starting eval for this model ...") | |||||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||||
model.eval() | |||||
running_loss = 0.0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
total_example_num = 0 | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
iter_start_time = time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
# if i > 10: | |||||
# break | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input, requires_grad=False) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||||
outputs = model_outputs["p_sent"] | |||||
prediction = model_outputs["prediction"] | |||||
outputs = outputs.view(-1, 2) # [batch * N, 2] | |||||
label = label.view(-1) # [batch * N] | |||||
loss = criterion(outputs, label) | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.masked_fill(input_len.eq(0), 0) | |||||
loss = loss.sum(1).mean() | |||||
logger.debug("loss %f", loss) | |||||
running_loss += float(loss.data) | |||||
label = label.data.view(batch_size, -1) | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
# rouge | |||||
prediction = prediction.view(batch_size, -1) | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id] for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if sent_max_number < hps.m and len(hyps) <= 1: | |||||
logger.error("sent_max_number is too short %d, Skip!" , sent_max_number) | |||||
continue | |||||
if len(hyps) >= 1 and hyps != '.': | |||||
# logger.debug(prediction[j]) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
elif refer == "." or refer == "": | |||||
logger.error("Refer is None!") | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
logger.debug(refer) | |||||
elif hyps == "." or hyps == "": | |||||
logger.error("hyps is None!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug("prediction:") | |||||
logger.debug(prediction[j]) | |||||
logger.debug(hyps) | |||||
else: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
running_avg_loss = running_loss / len(loader) | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
logging.getLogger('global').setLevel(logging.WARNING) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0 : | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# try: | |||||
# scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# scores_all = [] | |||||
# for idx in range(len(pairs["hyps"])): | |||||
# try: | |||||
# scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0] | |||||
# scores_all.append(scores) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# logger.debug("HYPS:\t%s", pairs["hyps"][idx]) | |||||
# logger.debug("REFER:\t%s", pairs["refer"][idx]) | |||||
# finally: | |||||
# logger.error("During testing, some errors happen!") | |||||
# logger.error(len(scores_all)) | |||||
# exit(1) | |||||
logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | ' | |||||
.format((time.time() - iter_start_time), | |||||
float(running_avg_loss))) | |||||
logger.info("[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", | |||||
total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
logger.info(res) | |||||
# If running_avg_loss is best so far, save this checkpoint (early stopping). | |||||
# These checkpoints will appear as bestmodel-<iteration_number> in the eval dir | |||||
if best_loss is None or running_avg_loss < best_loss: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_loss is not None: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s', float(running_avg_loss), float(best_loss), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s', float(running_avg_loss), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_loss = running_avg_loss | |||||
non_descent_cnt = 0 | |||||
else: | |||||
non_descent_cnt += 1 | |||||
if best_F is None or best_F < F: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_F is not None: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F), float(best_F), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original loss is None, Saving to %s', float(F), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_F = F | |||||
return best_loss, best_F, non_descent_cnt | |||||
def run_test(model, loader, hps, limited=False): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||||
eval_dir = os.path.join(hps.save_root, "eval") | |||||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||||
if not os.path.exists(eval_dir) : | |||||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||||
if hps.test_model == "evalbestmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
elif hps.test_model == "evalbestFmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "trainbestmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl') | |||||
elif hps.test_model == "trainbestFmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "earlystop": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'earlystop,pkl') | |||||
else: | |||||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||||
modelloader = ModelLoader() | |||||
modelloader.load_pytorch(model, bestmodel_load_path) | |||||
import datetime | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')#现在 | |||||
if hps.save_label: | |||||
log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1]) | |||||
resfile = open(log_dir, "w") | |||||
else: | |||||
log_dir = os.path.join(test_dir, nowTime) | |||||
resfile = open(log_dir, "wb") | |||||
logger.info("[INFO] Write the Evaluation into %s", log_dir) | |||||
model.eval() | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
total_example_num = 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
pred_list = [] | |||||
iter_start_time=time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
prediction = model_outputs["prediction"] | |||||
if hps.save_label: | |||||
pred_list.extend(model_outputs["pred_idx"].data.cpu().view(-1).tolist()) | |||||
continue | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id].replace("\n", "") for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if limited: | |||||
k = len(refer.split()) | |||||
hyps = " ".join(hyps.split()[:k]) | |||||
logger.info((len(refer.split()),len(hyps.split()))) | |||||
resfile.write(b"Original_article:") | |||||
resfile.write("\n".join(batch_x["text"][j]).encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"Reference:") | |||||
if isinstance(refer, list): | |||||
for ref in refer: | |||||
resfile.write(ref.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b'*' * 40) | |||||
resfile.write(b"\n") | |||||
else: | |||||
resfile.write(refer.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"hypothesis:") | |||||
resfile.write(hyps.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
else: | |||||
try: | |||||
scores = utils.rouge_all(hyps, refer) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
except ValueError: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
# single example res writer | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']) | |||||
resfile.write(res.encode('utf-8')) | |||||
resfile.write(b'-' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.save_label: | |||||
import json | |||||
json.dump(pred_list, resfile) | |||||
logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time))) | |||||
return | |||||
resfile.write(b"\n") | |||||
resfile.write(b'=' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# the whole model res writer | |||||
resfile.write(b"The total testset is:") | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
resfile.write(res.encode("utf-8")) | |||||
logger.info(res) | |||||
logger.info(' | end of test | time: {:5.2f}s | ' | |||||
.format((time.time() - iter_start_time))) | |||||
# label prediction | |||||
logger.info("match_true %d, pred %d, true %d, total %d, match %d", match, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
resfile.write(res.encode('utf-8')) | |||||
logger.info("The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", len(loader), accu, precision, recall, F) | |||||
def main(): | |||||
parser = argparse.ArgumentParser(description='Transformer Model') | |||||
# Where to find data | |||||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||||
# Important settings | |||||
parser.add_argument('--mode', type=str, default='train', help='must be one of train/test') | |||||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||||
# Where to save output | |||||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||||
# Hyperparameters | |||||
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]') | |||||
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]') | |||||
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]') | |||||
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||||
parser.add_argument('--n_layers', type=int, default=12, help='Number of deeplstm layers') | |||||
parser.add_argument('--hidden_size', type=int, default=512, help='hidden size [default: 512]') | |||||
parser.add_argument('--ffn_inner_hidden_size', type=int, default=2048, help='PositionwiseFeedForward inner hidden size [default: 2048]') | |||||
parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]') | |||||
parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]') | |||||
parser.add_argument('--atten_dropout_prob', type=float, default=0.1,help='attention dropout prob [default: 0.1]') | |||||
parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]') | |||||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthnormal init for lstm [default: true]') | |||||
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||||
parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention') | |||||
# Training | |||||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||||
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization') | |||||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||||
args = parser.parse_args() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||||
torch.set_printoptions(threshold=50000) | |||||
hps = args | |||||
# File paths | |||||
DATA_FILE = args.data_path | |||||
VALID_FILE = args.valid_path | |||||
VOCAL_FILE = args.vocab_path | |||||
LOG_PATH = args.log_root | |||||
# train_log setting | |||||
if not os.path.exists(LOG_PATH): | |||||
if hps.mode == "train": | |||||
os.makedirs(LOG_PATH) | |||||
else: | |||||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
log_path = os.path.join(LOG_PATH, hps.mode + "_" + nowTime) | |||||
file_handler = logging.FileHandler(log_path) | |||||
file_handler.setFormatter(formatter) | |||||
logger.addHandler(file_handler) | |||||
logger.info("Pytorch %s", torch.__version__) | |||||
logger.info(args) | |||||
logger.info(args) | |||||
sum_loader = SummarizationLoader() | |||||
if hps.mode == 'test': | |||||
paths = {"test": DATA_FILE} | |||||
hps.recurrent_dropout_prob = 0.0 | |||||
hps.atten_dropout_prob = 0.0 | |||||
hps.ffn_dropout_prob = 0.0 | |||||
logger.info(hps) | |||||
else: | |||||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||||
vocab = dataInfo.vocabs["vocab"] | |||||
model = TransformerModel(hps, vocab) | |||||
if len(hps.gpu) > 1: | |||||
gpuid = hps.gpu.split(',') | |||||
gpuid = [int(s) for s in gpuid] | |||||
model = nn.DataParallel(model,device_ids=gpuid) | |||||
logger.info("[INFO] Use Multi-gpu: %s", hps.gpu) | |||||
if hps.cuda: | |||||
model = model.cuda() | |||||
logger.info("[INFO] Use cuda") | |||||
if hps.mode == 'train': | |||||
trainset = dataInfo.datasets["train"] | |||||
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | |||||
train_batch = DataSetIter(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler) | |||||
validset = dataInfo.datasets["valid"] | |||||
validset.set_input("text", "summary") | |||||
valid_batch = DataSetIter(batch_size=hps.batch_size, dataset=validset) | |||||
setup_training(model, train_batch, valid_batch, hps) | |||||
elif hps.mode == 'test': | |||||
logger.info("[INFO] Decoding...") | |||||
testset = dataInfo.datasets["test"] | |||||
testset.set_input("text", "summary") | |||||
test_batch = DataSetIter(batch_size=hps.batch_size, dataset=testset) | |||||
run_test(model, test_batch, hps, limited=hps.limited) | |||||
else: | |||||
logger.error("The 'mode' flag must be one of train/eval/test") | |||||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,705 @@ | |||||
#!/usr/bin/python | |||||
# -*- coding: utf-8 -*- | |||||
# __author__="Danqing Wang" | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""Train Model1: baseline model""" | |||||
import os | |||||
import sys | |||||
import time | |||||
import copy | |||||
import pickle | |||||
import datetime | |||||
import argparse | |||||
import logging | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import Variable | |||||
from rouge import Rouge | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.sampler import BucketSampler | |||||
from tools import utils | |||||
from tools.logger import * | |||||
from data.dataloader import SummarizationLoader | |||||
from model.TransformerModel import TransformerModel | |||||
def setup_training(model, train_loader, valid_loader, hps): | |||||
"""Does setup before starting training (run_training)""" | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
if hps.restore_model != 'None': | |||||
logger.info("[INFO] Restoring %s for training...", hps.restore_model) | |||||
bestmodel_file = os.path.join(train_dir, hps.restore_model) | |||||
loader = ModelLoader() | |||||
loader.load_pytorch(model, bestmodel_file) | |||||
else: | |||||
logger.info("[INFO] Create new model for training...") | |||||
try: | |||||
run_training(model, train_loader, valid_loader, hps) # this is an infinite loop until interrupted | |||||
except KeyboardInterrupt: | |||||
logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
def run_training(model, train_loader, valid_loader, hps): | |||||
"""Repeatedly runs training iterations, logging loss to screen and writing summaries""" | |||||
logger.info("[INFO] Starting run_training") | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
if not os.path.exists(train_dir): os.makedirs(train_dir) | |||||
lr = hps.lr | |||||
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98), | |||||
# eps=1e-09) | |||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
best_train_loss = None | |||||
best_train_F= None | |||||
best_loss = None | |||||
best_F = None | |||||
step_num = 0 | |||||
non_descent_cnt = 0 | |||||
for epoch in range(1, hps.n_epochs + 1): | |||||
epoch_loss = 0.0 | |||||
train_loss = 0.0 | |||||
total_example_num = 0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
epoch_start_time = time.time() | |||||
for i, (batch_x, batch_y) in enumerate(train_loader): | |||||
# if i > 10: | |||||
# break | |||||
model.train() | |||||
iter_start_time=time.time() | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
# logger.info(batch_x["text"][0]) | |||||
# logger.info(input[0,:,:]) | |||||
# logger.info(input_len[0:5,:]) | |||||
# logger.info(batch_y["summary"][0:5]) | |||||
# logger.info(label[0:5,:]) | |||||
# logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0))) | |||||
batch_size, N, seq_len = input.size() | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
input = Variable(input) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
outputs = model_outputs[Const.OUTPUT].view(-1, 2) | |||||
label = label.view(-1) | |||||
loss = criterion(outputs, label) # [batch_size, doc_max_timesteps] | |||||
input_len = input_len.float().view(-1) | |||||
loss = loss * input_len | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.sum(1).mean() | |||||
if not (np.isfinite(loss.data)).numpy(): | |||||
logger.error("train Loss is not finite. Stopping.") | |||||
logger.info(loss) | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.info(name) | |||||
logger.info(param.grad.data.sum()) | |||||
raise Exception("train Loss is not finite. Stopping.") | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
if hps.grad_clip: | |||||
torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm) | |||||
optimizer.step() | |||||
step_num += 1 | |||||
train_loss += float(loss.data) | |||||
epoch_loss += float(loss.data) | |||||
if i % 100 == 0: | |||||
# start debugger | |||||
# import pdb; pdb.set_trace() | |||||
for name, param in model.named_parameters(): | |||||
if param.requires_grad: | |||||
logger.debug(name) | |||||
logger.debug(param.grad.data.sum()) | |||||
logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | ' | |||||
.format(i, (time.time() - iter_start_time), | |||||
float(train_loss / 100))) | |||||
train_loss = 0.0 | |||||
# calculate the precision, recall and F | |||||
prediction = outputs.max(1)[1] | |||||
prediction = prediction.data | |||||
label = label.data | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += int(batch_size * N) | |||||
if hps.lr_descent: | |||||
# new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5), | |||||
# step_num * pow(hps.warmup_steps, -1.5)) | |||||
new_lr = max(5e-6, lr / (epoch + 1)) | |||||
for param_group in list(optimizer.param_groups): | |||||
param_group['lr'] = new_lr | |||||
logger.info("[INFO] The learning rate now is %f", new_lr) | |||||
epoch_avg_loss = epoch_loss / len(train_loader) | |||||
logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | ' | |||||
.format(epoch, (time.time() - epoch_start_time), | |||||
float(epoch_avg_loss))) | |||||
logger.info("[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
if not best_train_loss or epoch_avg_loss < best_train_loss: | |||||
save_file = os.path.join(train_dir, "bestmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_loss = epoch_avg_loss | |||||
elif epoch_avg_loss > best_train_loss: | |||||
logger.error("[Error] training loss does not descent. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop.pkl") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
if not best_train_F or F > best_train_F: | |||||
save_file = os.path.join(train_dir, "bestFmodel.pkl") | |||||
logger.info('[INFO] Found new best model with %.3f F score. Saving to %s', float(F), save_file) | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
best_train_F = F | |||||
best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt) | |||||
if non_descent_cnt >= 3: | |||||
logger.error("[Error] val loss does not descent for three times. Stopping supervisor...") | |||||
save_file = os.path.join(train_dir, "earlystop") | |||||
saver = ModelSaver(save_file) | |||||
saver.save_pytorch(model) | |||||
logger.info('[INFO] Saving early stop model to %s', save_file) | |||||
return | |||||
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
logger.info("[INFO] Starting eval for this model ...") | |||||
eval_dir = os.path.join(hps.save_root, "eval") # make a subdir of the root dir for eval data | |||||
if not os.path.exists(eval_dir): os.makedirs(eval_dir) | |||||
model.eval() | |||||
running_loss = 0.0 | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
total_example_num = 0 | |||||
criterion = torch.nn.CrossEntropyLoss(reduction='none') | |||||
iter_start_time = time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
# if i > 10: | |||||
# break | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input, requires_grad=False) | |||||
label = Variable(label) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||||
outputs = model_outputs[Const.OUTPUTS] | |||||
prediction = model_outputs["prediction"] | |||||
outputs = outputs.view(-1, 2) # [batch * N, 2] | |||||
label = label.view(-1) # [batch * N] | |||||
loss = criterion(outputs, label) | |||||
input_len = input_len.float().view(-1) | |||||
loss = loss * input_len | |||||
loss = loss.view(batch_size, -1) | |||||
loss = loss.sum(1).mean() | |||||
running_loss += float(loss.data) | |||||
label = label.data | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
# rouge | |||||
prediction = prediction.view(batch_size, -1) | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id] for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if sent_max_number < hps.m and len(hyps) <= 1: | |||||
logger.error("sent_max_number is too short %d, Skip!" , sent_max_number) | |||||
continue | |||||
if len(hyps) >= 1 and hyps != '.': | |||||
# logger.debug(prediction[j]) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
elif refer == "." or refer == "": | |||||
logger.error("Refer is None!") | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
logger.debug(refer) | |||||
elif hyps == "." or hyps == "": | |||||
logger.error("hyps is None!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug("prediction:") | |||||
logger.debug(prediction[j]) | |||||
logger.debug(hyps) | |||||
else: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
running_avg_loss = running_loss / len(loader) | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
logging.getLogger('global').setLevel(logging.WARNING) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0 : | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# try: | |||||
# scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# scores_all = [] | |||||
# for idx in range(len(pairs["hyps"])): | |||||
# try: | |||||
# scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0] | |||||
# scores_all.append(scores) | |||||
# except ValueError as e: | |||||
# logger.error(repr(e)) | |||||
# logger.debug("HYPS:\t%s", pairs["hyps"][idx]) | |||||
# logger.debug("REFER:\t%s", pairs["refer"][idx]) | |||||
# finally: | |||||
# logger.error("During testing, some errors happen!") | |||||
# logger.error(len(scores_all)) | |||||
# exit(1) | |||||
logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | ' | |||||
.format((time.time() - iter_start_time), | |||||
float(running_avg_loss))) | |||||
logger.info("[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", | |||||
total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
logger.info(res) | |||||
# If running_avg_loss is best so far, save this checkpoint (early stopping). | |||||
# These checkpoints will appear as bestmodel-<iteration_number> in the eval dir | |||||
if best_loss is None or running_avg_loss < best_loss: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_loss is not None: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s', float(running_avg_loss), float(best_loss), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s', float(running_avg_loss), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_loss = running_avg_loss | |||||
non_descent_cnt = 0 | |||||
else: | |||||
non_descent_cnt += 1 | |||||
if best_F is None or best_F < F: | |||||
bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel.pkl') # this is where checkpoints of best models are saved | |||||
if best_F is not None: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F), float(best_F), bestmodel_save_path) | |||||
else: | |||||
logger.info('[INFO] Found new best model with %.6f F. The original loss is None, Saving to %s', float(F), bestmodel_save_path) | |||||
saver = ModelSaver(bestmodel_save_path) | |||||
saver.save_pytorch(model) | |||||
best_F = F | |||||
return best_loss, best_F, non_descent_cnt | |||||
def run_test(model, loader, hps, limited=False): | |||||
"""Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" | |||||
test_dir = os.path.join(hps.save_root, "test") # make a subdir of the root dir for eval data | |||||
eval_dir = os.path.join(hps.save_root, "eval") | |||||
if not os.path.exists(test_dir) : os.makedirs(test_dir) | |||||
if not os.path.exists(eval_dir) : | |||||
logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir) | |||||
raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir)) | |||||
if hps.test_model == "evalbestmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved | |||||
elif hps.test_model == "evalbestFmodel": | |||||
bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "trainbestmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl') | |||||
elif hps.test_model == "trainbestFmodel": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl') | |||||
elif hps.test_model == "earlystop": | |||||
train_dir = os.path.join(hps.save_root, "train") | |||||
bestmodel_load_path = os.path.join(train_dir, 'earlystop,pkl') | |||||
else: | |||||
logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop") | |||||
logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path) | |||||
modelloader = ModelLoader() | |||||
modelloader.load_pytorch(model, bestmodel_load_path) | |||||
import datetime | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')#现在 | |||||
if hps.save_label: | |||||
log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1]) | |||||
resfile = open(log_dir, "w") | |||||
else: | |||||
log_dir = os.path.join(test_dir, nowTime) | |||||
resfile = open(log_dir, "wb") | |||||
logger.info("[INFO] Write the Evaluation into %s", log_dir) | |||||
model.eval() | |||||
match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0 | |||||
total_example_num = 0.0 | |||||
pairs = {} | |||||
pairs["hyps"] = [] | |||||
pairs["refer"] = [] | |||||
pred_list = [] | |||||
iter_start_time=time.time() | |||||
with torch.no_grad(): | |||||
for i, (batch_x, batch_y) in enumerate(loader): | |||||
input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN] | |||||
label = batch_y[Const.TARGET] | |||||
if hps.cuda: | |||||
input = input.cuda() # [batch, N, seq_len] | |||||
label = label.cuda() | |||||
input_len = input_len.cuda() | |||||
batch_size, N, _ = input.size() | |||||
input = Variable(input) | |||||
input_len = Variable(input_len, requires_grad=False) | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
prediction = model_outputs["pred"] | |||||
if hps.save_label: | |||||
pred_list.extend(model_outputs["pred_idx"].data.cpu().view(-1).tolist()) | |||||
continue | |||||
pred += prediction.sum() | |||||
true += label.sum() | |||||
match_true += ((prediction == label) & (prediction == 1)).sum() | |||||
match += (prediction == label).sum() | |||||
total_example_num += batch_size * N | |||||
for j in range(batch_size): | |||||
original_article_sents = batch_x["text"][j] | |||||
sent_max_number = len(original_article_sents) | |||||
refer = "\n".join(batch_x["summary"][j]) | |||||
hyps = "\n".join(original_article_sents[id].replace("\n", "") for id in range(len(prediction[j])) if prediction[j][id]==1 and id < sent_max_number) | |||||
if limited: | |||||
k = len(refer.split()) | |||||
hyps = " ".join(hyps.split()[:k]) | |||||
logger.info((len(refer.split()),len(hyps.split()))) | |||||
resfile.write(b"Original_article:") | |||||
resfile.write("\n".join(batch_x["text"][j]).encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"Reference:") | |||||
if isinstance(refer, list): | |||||
for ref in refer: | |||||
resfile.write(ref.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b'*' * 40) | |||||
resfile.write(b"\n") | |||||
else: | |||||
resfile.write(refer.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
resfile.write(b"hypothesis:") | |||||
resfile.write(hyps.encode('utf-8')) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
else: | |||||
try: | |||||
scores = utils.rouge_all(hyps, refer) | |||||
pairs["hyps"].append(hyps) | |||||
pairs["refer"].append(refer) | |||||
except ValueError: | |||||
logger.error("Do not select any sentences!") | |||||
logger.debug("sent_max_number:%d", sent_max_number) | |||||
logger.debug(original_article_sents) | |||||
logger.debug("label:") | |||||
logger.debug(label[j]) | |||||
continue | |||||
# single example res writer | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']) | |||||
resfile.write(res.encode('utf-8')) | |||||
resfile.write(b'-' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.save_label: | |||||
import json | |||||
json.dump(pred_list, resfile) | |||||
logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time))) | |||||
return | |||||
resfile.write(b"\n") | |||||
resfile.write(b'=' * 89) | |||||
resfile.write(b"\n") | |||||
if hps.use_pyrouge: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
if isinstance(pairs["refer"][0], list): | |||||
logger.info("Multi Reference summaries!") | |||||
scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"]) | |||||
else: | |||||
logger.info("The number of pairs is %d", len(pairs["hyps"])) | |||||
if not len(pairs["hyps"]): | |||||
logger.error("During testing, no hyps is selected!") | |||||
return | |||||
rouge = Rouge() | |||||
scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True) | |||||
# the whole model res writer | |||||
resfile.write(b"The total testset is:") | |||||
res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \ | |||||
+ "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \ | |||||
+ "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f']) | |||||
resfile.write(res.encode("utf-8")) | |||||
logger.info(res) | |||||
logger.info(' | end of test | time: {:5.2f}s | ' | |||||
.format((time.time() - iter_start_time))) | |||||
# label prediction | |||||
logger.info("match_true %d, pred %d, true %d, total %d, match %d", match, pred, true, total_example_num, match) | |||||
accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match) | |||||
res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (total_example_num / hps.doc_max_timesteps, accu, precision, recall, F) | |||||
resfile.write(res.encode('utf-8')) | |||||
logger.info("The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", len(loader), accu, precision, recall, F) | |||||
def main(): | |||||
parser = argparse.ArgumentParser(description='Transformer Model') | |||||
# Where to find data | |||||
parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.') | |||||
parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.') | |||||
parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.') | |||||
parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.') | |||||
# Important settings | |||||
parser.add_argument('--mode', type=str, default='train', help='must be one of train/test') | |||||
parser.add_argument('--restore_model', type=str , default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]') | |||||
parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]') | |||||
parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge') | |||||
# Where to save output | |||||
parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.') | |||||
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.') | |||||
# Hyperparameters | |||||
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]') | |||||
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') | |||||
parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') | |||||
parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]') | |||||
parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]') | |||||
parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding') | |||||
parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]') | |||||
parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]') | |||||
parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default:1]') | |||||
parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default:7]') | |||||
parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel') | |||||
parser.add_argument('--n_layers', type=int, default=12, help='Number of deeplstm layers') | |||||
parser.add_argument('--hidden_size', type=int, default=512, help='hidden size [default: 512]') | |||||
parser.add_argument('--ffn_inner_hidden_size', type=int, default=2048, help='PositionwiseFeedForward inner hidden size [default: 2048]') | |||||
parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]') | |||||
parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]') | |||||
parser.add_argument('--atten_dropout_prob', type=float, default=0.1,help='attention dropout prob [default: 0.1]') | |||||
parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]') | |||||
parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthnormal init for lstm [default: true]') | |||||
parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)') | |||||
parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)') | |||||
parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention') | |||||
# Training | |||||
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') | |||||
parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent') | |||||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps') | |||||
parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping') | |||||
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization') | |||||
parser.add_argument('-m', type=int, default=3, help='decode summary length') | |||||
parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length') | |||||
args = parser.parse_args() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | |||||
torch.set_printoptions(threshold=50000) | |||||
hps = args | |||||
# File paths | |||||
DATA_FILE = args.data_path | |||||
VALID_FILE = args.valid_path | |||||
VOCAL_FILE = args.vocab_path | |||||
LOG_PATH = args.log_root | |||||
# train_log setting | |||||
if not os.path.exists(LOG_PATH): | |||||
if hps.mode == "train": | |||||
os.makedirs(LOG_PATH) | |||||
else: | |||||
logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH) | |||||
raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH)) | |||||
nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | |||||
log_path = os.path.join(LOG_PATH, hps.mode + "_" + nowTime) | |||||
file_handler = logging.FileHandler(log_path) | |||||
file_handler.setFormatter(formatter) | |||||
logger.addHandler(file_handler) | |||||
logger.info("Pytorch %s", torch.__version__) | |||||
logger.info(args) | |||||
logger.info(args) | |||||
sum_loader = SummarizationLoader() | |||||
if hps.mode == 'test': | |||||
paths = {"test": DATA_FILE} | |||||
hps.recurrent_dropout_prob = 0.0 | |||||
hps.atten_dropout_prob = 0.0 | |||||
hps.ffn_dropout_prob = 0.0 | |||||
logger.info(hps) | |||||
else: | |||||
paths = {"train": DATA_FILE, "valid": VALID_FILE} | |||||
dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE)) | |||||
vocab = dataInfo.vocabs["vocab"] | |||||
model = TransformerModel(hps, vocab) | |||||
if len(hps.gpu) > 1: | |||||
gpuid = hps.gpu.split(',') | |||||
gpuid = [int(s) for s in gpuid] | |||||
model = nn.DataParallel(model,device_ids=gpuid) | |||||
logger.info("[INFO] Use Multi-gpu: %s", hps.gpu) | |||||
if hps.cuda: | |||||
model = model.cuda() | |||||
logger.info("[INFO] Use cuda") | |||||
if hps.mode == 'train': | |||||
trainset = dataInfo.datasets["train"] | |||||
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | |||||
train_batch = Batch(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler) | |||||
validset = dataInfo.datasets["valid"] | |||||
validset.set_input("text", "summary") | |||||
valid_batch = Batch(batch_size=hps.batch_size, dataset=validset) | |||||
setup_training(model, train_batch, valid_batch, hps) | |||||
elif hps.mode == 'test': | |||||
logger.info("[INFO] Decoding...") | |||||
testset = dataInfo.datasets["test"] | |||||
testset.set_input("text", "summary") | |||||
test_batch = Batch(batch_size=hps.batch_size, dataset=testset) | |||||
run_test(model, test_batch, hps, limited=hps.limited) | |||||
else: | |||||
logger.error("The 'mode' flag must be one of train/eval/test") | |||||
raise ValueError("The 'mode' flag must be one of train/eval/test") | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,103 @@ | |||||
""" Manage beam search info structure. | |||||
Heavily borrowed from OpenNMT-py. | |||||
For code in OpenNMT-py, please check the following link: | |||||
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py | |||||
""" | |||||
import torch | |||||
import numpy as np | |||||
import transformer.Constants as Constants | |||||
class Beam(): | |||||
''' Beam search ''' | |||||
def __init__(self, size, device=False): | |||||
self.size = size | |||||
self._done = False | |||||
# The score for each translation on the beam. | |||||
self.scores = torch.zeros((size,), dtype=torch.float, device=device) | |||||
self.all_scores = [] | |||||
# The backpointers at each time-step. | |||||
self.prev_ks = [] | |||||
# The outputs at each time-step. | |||||
self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] | |||||
self.next_ys[0][0] = Constants.BOS | |||||
def get_current_state(self): | |||||
"Get the outputs for the current timestep." | |||||
return self.get_tentative_hypothesis() | |||||
def get_current_origin(self): | |||||
"Get the backpointers for the current timestep." | |||||
return self.prev_ks[-1] | |||||
@property | |||||
def done(self): | |||||
return self._done | |||||
def advance(self, word_prob): | |||||
"Update beam status and check if finished or not." | |||||
num_words = word_prob.size(1) | |||||
# Sum the previous scores. | |||||
if len(self.prev_ks) > 0: | |||||
beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) | |||||
else: | |||||
beam_lk = word_prob[0] | |||||
flat_beam_lk = beam_lk.view(-1) | |||||
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort | |||||
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort | |||||
self.all_scores.append(self.scores) | |||||
self.scores = best_scores | |||||
# bestScoresId is flattened as a (beam x word) array, | |||||
# so we need to calculate which word and beam each score came from | |||||
prev_k = best_scores_id / num_words | |||||
self.prev_ks.append(prev_k) | |||||
self.next_ys.append(best_scores_id - prev_k * num_words) | |||||
# End condition is when top-of-beam is EOS. | |||||
if self.next_ys[-1][0].item() == Constants.EOS: | |||||
self._done = True | |||||
self.all_scores.append(self.scores) | |||||
return self._done | |||||
def sort_scores(self): | |||||
"Sort the scores." | |||||
return torch.sort(self.scores, 0, True) | |||||
def get_the_best_score_and_idx(self): | |||||
"Get the score of the best in the beam." | |||||
scores, ids = self.sort_scores() | |||||
return scores[1], ids[1] | |||||
def get_tentative_hypothesis(self): | |||||
"Get the decoded sequence for the current timestep." | |||||
if len(self.next_ys) == 1: | |||||
dec_seq = self.next_ys[0].unsqueeze(1) | |||||
else: | |||||
_, keys = self.sort_scores() | |||||
hyps = [self.get_hypothesis(k) for k in keys] | |||||
hyps = [[Constants.BOS] + h for h in hyps] | |||||
dec_seq = torch.LongTensor(hyps) | |||||
return dec_seq | |||||
def get_hypothesis(self, k): | |||||
""" Walk back to construct the full hypothesis. """ | |||||
hyp = [] | |||||
for j in range(len(self.prev_ks) - 1, -1, -1): | |||||
hyp.append(self.next_ys[j+1][k]) | |||||
k = self.prev_ks[j][k] | |||||
return list(map(lambda x: x.item(), hyp[::-1])) |
@@ -0,0 +1,10 @@ | |||||
PAD = 0 | |||||
UNK = 1 | |||||
BOS = 2 | |||||
EOS = 3 | |||||
PAD_WORD = '<blank>' | |||||
UNK_WORD = '<unk>' | |||||
BOS_WORD = '<s>' | |||||
EOS_WORD = '</s>' |
@@ -0,0 +1,49 @@ | |||||
''' Define the Layers ''' | |||||
import torch.nn as nn | |||||
from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward | |||||
__author__ = "Yu-Hsiang Huang" | |||||
class EncoderLayer(nn.Module): | |||||
''' Compose with two layers ''' | |||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): | |||||
super(EncoderLayer, self).__init__() | |||||
self.slf_attn = MultiHeadAttention( | |||||
n_head, d_model, d_k, d_v, dropout=dropout) | |||||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) | |||||
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): | |||||
enc_output, enc_slf_attn = self.slf_attn( | |||||
enc_input, enc_input, enc_input, mask=slf_attn_mask) | |||||
enc_output *= non_pad_mask | |||||
enc_output = self.pos_ffn(enc_output) | |||||
enc_output *= non_pad_mask | |||||
return enc_output, enc_slf_attn | |||||
class DecoderLayer(nn.Module): | |||||
''' Compose with three layers ''' | |||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): | |||||
super(DecoderLayer, self).__init__() | |||||
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) | |||||
self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) | |||||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) | |||||
def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): | |||||
dec_output, dec_slf_attn = self.slf_attn( | |||||
dec_input, dec_input, dec_input, mask=slf_attn_mask) | |||||
dec_output *= non_pad_mask | |||||
dec_output, dec_enc_attn = self.enc_attn( | |||||
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) | |||||
dec_output *= non_pad_mask | |||||
dec_output = self.pos_ffn(dec_output) | |||||
dec_output *= non_pad_mask | |||||
return dec_output, dec_slf_attn, dec_enc_attn |
@@ -0,0 +1,208 @@ | |||||
''' Define the Transformer model ''' | |||||
import torch | |||||
import torch.nn as nn | |||||
import numpy as np | |||||
import transformer.Constants as Constants | |||||
from transformer.Layers import EncoderLayer, DecoderLayer | |||||
__author__ = "Yu-Hsiang Huang" | |||||
def get_non_pad_mask(seq): | |||||
assert seq.dim() == 2 | |||||
return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) | |||||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||||
''' Sinusoid position encoding table ''' | |||||
def cal_angle(position, hid_idx): | |||||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||||
def get_posi_angle_vec(position): | |||||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||||
if padding_idx is not None: | |||||
# zero vector for padding dimension | |||||
sinusoid_table[padding_idx] = 0. | |||||
return torch.FloatTensor(sinusoid_table) | |||||
def get_attn_key_pad_mask(seq_k, seq_q): | |||||
''' For masking out the padding part of key sequence. ''' | |||||
# Expand to fit the shape of key query attention matrix. | |||||
len_q = seq_q.size(1) | |||||
padding_mask = seq_k.eq(Constants.PAD) | |||||
padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk | |||||
return padding_mask | |||||
def get_subsequent_mask(seq): | |||||
''' For masking out the subsequent info. ''' | |||||
sz_b, len_s = seq.size() | |||||
subsequent_mask = torch.triu( | |||||
torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) | |||||
subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls | |||||
return subsequent_mask | |||||
class Encoder(nn.Module): | |||||
''' A encoder model with self attention mechanism. ''' | |||||
def __init__( | |||||
self, | |||||
n_src_vocab, len_max_seq, d_word_vec, | |||||
n_layers, n_head, d_k, d_v, | |||||
d_model, d_inner, dropout=0.1): | |||||
super().__init__() | |||||
n_position = len_max_seq + 1 | |||||
self.src_word_emb = nn.Embedding( | |||||
n_src_vocab, d_word_vec, padding_idx=Constants.PAD) | |||||
self.position_enc = nn.Embedding.from_pretrained( | |||||
get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), | |||||
freeze=True) | |||||
self.layer_stack = nn.ModuleList([ | |||||
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) | |||||
for _ in range(n_layers)]) | |||||
def forward(self, src_seq, src_pos, return_attns=False): | |||||
enc_slf_attn_list = [] | |||||
# -- Prepare masks | |||||
slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) | |||||
non_pad_mask = get_non_pad_mask(src_seq) | |||||
# -- Forward | |||||
enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) | |||||
for enc_layer in self.layer_stack: | |||||
enc_output, enc_slf_attn = enc_layer( | |||||
enc_output, | |||||
non_pad_mask=non_pad_mask, | |||||
slf_attn_mask=slf_attn_mask) | |||||
if return_attns: | |||||
enc_slf_attn_list += [enc_slf_attn] | |||||
if return_attns: | |||||
return enc_output, enc_slf_attn_list | |||||
return enc_output, | |||||
class Decoder(nn.Module): | |||||
''' A decoder model with self attention mechanism. ''' | |||||
def __init__( | |||||
self, | |||||
n_tgt_vocab, len_max_seq, d_word_vec, | |||||
n_layers, n_head, d_k, d_v, | |||||
d_model, d_inner, dropout=0.1): | |||||
super().__init__() | |||||
n_position = len_max_seq + 1 | |||||
self.tgt_word_emb = nn.Embedding( | |||||
n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) | |||||
self.position_enc = nn.Embedding.from_pretrained( | |||||
get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), | |||||
freeze=True) | |||||
self.layer_stack = nn.ModuleList([ | |||||
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) | |||||
for _ in range(n_layers)]) | |||||
def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): | |||||
dec_slf_attn_list, dec_enc_attn_list = [], [] | |||||
# -- Prepare masks | |||||
non_pad_mask = get_non_pad_mask(tgt_seq) | |||||
slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) | |||||
slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) | |||||
slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) | |||||
dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) | |||||
# -- Forward | |||||
dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) | |||||
for dec_layer in self.layer_stack: | |||||
dec_output, dec_slf_attn, dec_enc_attn = dec_layer( | |||||
dec_output, enc_output, | |||||
non_pad_mask=non_pad_mask, | |||||
slf_attn_mask=slf_attn_mask, | |||||
dec_enc_attn_mask=dec_enc_attn_mask) | |||||
if return_attns: | |||||
dec_slf_attn_list += [dec_slf_attn] | |||||
dec_enc_attn_list += [dec_enc_attn] | |||||
if return_attns: | |||||
return dec_output, dec_slf_attn_list, dec_enc_attn_list | |||||
return dec_output, | |||||
class Transformer(nn.Module): | |||||
''' A sequence to sequence model with attention mechanism. ''' | |||||
def __init__( | |||||
self, | |||||
n_src_vocab, n_tgt_vocab, len_max_seq, | |||||
d_word_vec=512, d_model=512, d_inner=2048, | |||||
n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, | |||||
tgt_emb_prj_weight_sharing=True, | |||||
emb_src_tgt_weight_sharing=True): | |||||
super().__init__() | |||||
self.encoder = Encoder( | |||||
n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, | |||||
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, | |||||
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, | |||||
dropout=dropout) | |||||
self.decoder = Decoder( | |||||
n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, | |||||
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, | |||||
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, | |||||
dropout=dropout) | |||||
self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) | |||||
nn.init.xavier_normal_(self.tgt_word_prj.weight) | |||||
assert d_model == d_word_vec, \ | |||||
'To facilitate the residual connections, \ | |||||
the dimensions of all module outputs shall be the same.' | |||||
if tgt_emb_prj_weight_sharing: | |||||
# Share the weight matrix between target word embedding & the final logit dense layer | |||||
self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight | |||||
self.x_logit_scale = (d_model ** -0.5) | |||||
else: | |||||
self.x_logit_scale = 1. | |||||
if emb_src_tgt_weight_sharing: | |||||
# Share the weight matrix between source & target word embeddings | |||||
assert n_src_vocab == n_tgt_vocab, \ | |||||
"To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||||
self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight | |||||
def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): | |||||
tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] | |||||
enc_output, *_ = self.encoder(src_seq, src_pos) | |||||
dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) | |||||
seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale | |||||
return seq_logit.view(-1, seq_logit.size(2)) |
@@ -0,0 +1,28 @@ | |||||
import torch | |||||
import torch.nn as nn | |||||
import numpy as np | |||||
__author__ = "Yu-Hsiang Huang" | |||||
class ScaledDotProductAttention(nn.Module): | |||||
''' Scaled Dot-Product Attention ''' | |||||
def __init__(self, temperature, attn_dropout=0.1): | |||||
super().__init__() | |||||
self.temperature = temperature | |||||
self.dropout = nn.Dropout(attn_dropout) | |||||
self.softmax = nn.Softmax(dim=2) | |||||
def forward(self, q, k, v, mask=None): | |||||
attn = torch.bmm(q, k.transpose(1, 2)) | |||||
attn = attn / self.temperature | |||||
if mask is not None: | |||||
attn = attn.masked_fill(mask, -np.inf) | |||||
attn = self.softmax(attn) | |||||
attn = self.dropout(attn) | |||||
output = torch.bmm(attn, v) | |||||
return output, attn |
@@ -0,0 +1,35 @@ | |||||
'''A wrapper class for optimizer ''' | |||||
import numpy as np | |||||
class ScheduledOptim(): | |||||
'''A simple wrapper class for learning rate scheduling''' | |||||
def __init__(self, optimizer, d_model, n_warmup_steps): | |||||
self._optimizer = optimizer | |||||
self.n_warmup_steps = n_warmup_steps | |||||
self.n_current_steps = 0 | |||||
self.init_lr = np.power(d_model, -0.5) | |||||
def step_and_update_lr(self): | |||||
"Step with the inner optimizer" | |||||
self._update_learning_rate() | |||||
self._optimizer.step() | |||||
def zero_grad(self): | |||||
"Zero out the gradients by the inner optimizer" | |||||
self._optimizer.zero_grad() | |||||
def _get_lr_scale(self): | |||||
return np.min([ | |||||
np.power(self.n_current_steps, -0.5), | |||||
np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) | |||||
def _update_learning_rate(self): | |||||
''' Learning rate scheduling per step ''' | |||||
self.n_current_steps += 1 | |||||
lr = self.init_lr * self._get_lr_scale() | |||||
for param_group in self._optimizer.param_groups: | |||||
param_group['lr'] = lr | |||||
@@ -0,0 +1,82 @@ | |||||
''' Define the sublayers in encoder/decoder layer ''' | |||||
import numpy as np | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from transformer.Modules import ScaledDotProductAttention | |||||
__author__ = "Yu-Hsiang Huang" | |||||
class MultiHeadAttention(nn.Module): | |||||
''' Multi-Head Attention module ''' | |||||
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): | |||||
super().__init__() | |||||
self.n_head = n_head | |||||
self.d_k = d_k | |||||
self.d_v = d_v | |||||
self.w_qs = nn.Linear(d_model, n_head * d_k) | |||||
self.w_ks = nn.Linear(d_model, n_head * d_k) | |||||
self.w_vs = nn.Linear(d_model, n_head * d_v) | |||||
nn.init.xavier_normal_(self.w_qs.weight) | |||||
nn.init.xavier_normal_(self.w_ks.weight) | |||||
nn.init.xavier_normal_(self.w_vs.weight) | |||||
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) | |||||
self.layer_norm = nn.LayerNorm(d_model) | |||||
self.fc = nn.Linear(n_head * d_v, d_model) | |||||
nn.init.xavier_normal_(self.fc.weight) | |||||
self.dropout = nn.Dropout(dropout) | |||||
def forward(self, q, k, v, mask=None): | |||||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head | |||||
sz_b, len_q, _ = q.size() | |||||
sz_b, len_k, _ = k.size() | |||||
sz_b, len_v, _ = v.size() | |||||
residual = q | |||||
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) | |||||
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) | |||||
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) | |||||
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk | |||||
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk | |||||
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv | |||||
if mask is not None: | |||||
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. | |||||
output, attn = self.attention(q, k, v, mask=mask) | |||||
output = output.view(n_head, sz_b, len_q, d_v) | |||||
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) | |||||
output = self.dropout(self.fc(output)) | |||||
output = self.layer_norm(output + residual) | |||||
return output, attn | |||||
class PositionwiseFeedForward(nn.Module): | |||||
''' A two-feed-forward-layer module ''' | |||||
def __init__(self, d_in, d_hid, dropout=0.1): | |||||
super().__init__() | |||||
self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise | |||||
self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise | |||||
self.layer_norm = nn.LayerNorm(d_in) | |||||
self.dropout = nn.Dropout(dropout) | |||||
def forward(self, x): | |||||
residual = x | |||||
output = x.transpose(1, 2) | |||||
output = self.w_2(F.relu(self.w_1(output))) | |||||
output = output.transpose(1, 2) | |||||
output = self.dropout(output) | |||||
output = self.layer_norm(output + residual) | |||||
return output |
@@ -0,0 +1,166 @@ | |||||
''' This module will handle the text generation with beam search. ''' | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from transformer.Models import Transformer | |||||
from transformer.Beam import Beam | |||||
class Translator(object): | |||||
''' Load with trained model and handle the beam search ''' | |||||
def __init__(self, opt): | |||||
self.opt = opt | |||||
self.device = torch.device('cuda' if opt.cuda else 'cpu') | |||||
checkpoint = torch.load(opt.model) | |||||
model_opt = checkpoint['settings'] | |||||
self.model_opt = model_opt | |||||
model = Transformer( | |||||
model_opt.src_vocab_size, | |||||
model_opt.tgt_vocab_size, | |||||
model_opt.max_token_seq_len, | |||||
tgt_emb_prj_weight_sharing=model_opt.proj_share_weight, | |||||
emb_src_tgt_weight_sharing=model_opt.embs_share_weight, | |||||
d_k=model_opt.d_k, | |||||
d_v=model_opt.d_v, | |||||
d_model=model_opt.d_model, | |||||
d_word_vec=model_opt.d_word_vec, | |||||
d_inner=model_opt.d_inner_hid, | |||||
n_layers=model_opt.n_layers, | |||||
n_head=model_opt.n_head, | |||||
dropout=model_opt.dropout) | |||||
model.load_state_dict(checkpoint['model']) | |||||
print('[Info] Trained model state loaded.') | |||||
model.word_prob_prj = nn.LogSoftmax(dim=1) | |||||
model = model.to(self.device) | |||||
self.model = model | |||||
self.model.eval() | |||||
def translate_batch(self, src_seq, src_pos): | |||||
''' Translation work in one batch ''' | |||||
def get_inst_idx_to_tensor_position_map(inst_idx_list): | |||||
''' Indicate the position of an instance in a tensor. ''' | |||||
return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} | |||||
def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): | |||||
''' Collect tensor parts associated to active instances. ''' | |||||
_, *d_hs = beamed_tensor.size() | |||||
n_curr_active_inst = len(curr_active_inst_idx) | |||||
new_shape = (n_curr_active_inst * n_bm, *d_hs) | |||||
beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) | |||||
beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) | |||||
beamed_tensor = beamed_tensor.view(*new_shape) | |||||
return beamed_tensor | |||||
def collate_active_info( | |||||
src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): | |||||
# Sentences which are still active are collected, | |||||
# so the decoder will not run on completed sentences. | |||||
n_prev_active_inst = len(inst_idx_to_position_map) | |||||
active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] | |||||
active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) | |||||
active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) | |||||
active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) | |||||
active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) | |||||
return active_src_seq, active_src_enc, active_inst_idx_to_position_map | |||||
def beam_decode_step( | |||||
inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): | |||||
''' Decode and update beam status, and then return active beam idx ''' | |||||
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): | |||||
dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] | |||||
dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) | |||||
dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) | |||||
return dec_partial_seq | |||||
def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): | |||||
dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) | |||||
dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) | |||||
return dec_partial_pos | |||||
def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm): | |||||
dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) | |||||
dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h | |||||
word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) | |||||
word_prob = word_prob.view(n_active_inst, n_bm, -1) | |||||
return word_prob | |||||
def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): | |||||
active_inst_idx_list = [] | |||||
for inst_idx, inst_position in inst_idx_to_position_map.items(): | |||||
is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) | |||||
if not is_inst_complete: | |||||
active_inst_idx_list += [inst_idx] | |||||
return active_inst_idx_list | |||||
n_active_inst = len(inst_idx_to_position_map) | |||||
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) | |||||
dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) | |||||
word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) | |||||
# Update the beam with predicted word prob information and collect incomplete instances | |||||
active_inst_idx_list = collect_active_inst_idx_list( | |||||
inst_dec_beams, word_prob, inst_idx_to_position_map) | |||||
return active_inst_idx_list | |||||
def collect_hypothesis_and_scores(inst_dec_beams, n_best): | |||||
all_hyp, all_scores = [], [] | |||||
for inst_idx in range(len(inst_dec_beams)): | |||||
scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() | |||||
all_scores += [scores[:n_best]] | |||||
hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] | |||||
all_hyp += [hyps] | |||||
return all_hyp, all_scores | |||||
with torch.no_grad(): | |||||
#-- Encode | |||||
src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) | |||||
src_enc, *_ = self.model.encoder(src_seq, src_pos) | |||||
#-- Repeat data for beam search | |||||
n_bm = self.opt.beam_size | |||||
n_inst, len_s, d_h = src_enc.size() | |||||
src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) | |||||
src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) | |||||
#-- Prepare beams | |||||
inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] | |||||
#-- Bookkeeping for active or not | |||||
active_inst_idx_list = list(range(n_inst)) | |||||
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) | |||||
#-- Decode | |||||
for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): | |||||
active_inst_idx_list = beam_decode_step( | |||||
inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) | |||||
if not active_inst_idx_list: | |||||
break # all instances have finished their path to <EOS> | |||||
src_seq, src_enc, inst_idx_to_position_map = collate_active_info( | |||||
src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) | |||||
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) | |||||
return batch_hyp, batch_scores |
@@ -0,0 +1,13 @@ | |||||
import transformer.Constants | |||||
import transformer.Modules | |||||
import transformer.Layers | |||||
import transformer.SubLayers | |||||
import transformer.Models | |||||
import transformer.Translator | |||||
import transformer.Beam | |||||
import transformer.Optim | |||||
__all__ = [ | |||||
transformer.Constants, transformer.Modules, transformer.Layers, | |||||
transformer.SubLayers, transformer.Models, transformer.Optim, | |||||
transformer.Translator, transformer.Beam] |
@@ -0,0 +1,129 @@ | |||||
import os | |||||
import torch | |||||
import sys | |||||
from torch import nn | |||||
from fastNLP.core.callback import Callback | |||||
from fastNLP.core.utils import _get_model_device | |||||
class MyCallback(Callback): | |||||
def __init__(self, args): | |||||
super(MyCallback, self).__init__() | |||||
self.args = args | |||||
self.real_step = 0 | |||||
def on_step_end(self): | |||||
if self.step % self.update_every == 0 and self.step > 0: | |||||
self.real_step += 1 | |||||
cur_lr = self.args.max_lr * 100 * min(self.real_step ** (-0.5), self.real_step * self.args.warmup_steps**(-1.5)) | |||||
for param_group in self.optimizer.param_groups: | |||||
param_group['lr'] = cur_lr | |||||
if self.real_step % 1000 == 0: | |||||
self.pbar.write('Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step)) | |||||
def on_epoch_end(self): | |||||
self.pbar.write('Epoch {} is done !!!'.format(self.epoch)) | |||||
def _save_model(model, model_name, save_dir, only_param=False): | |||||
""" 存储不含有显卡信息的 state_dict 或 model | |||||
:param model: | |||||
:param model_name: | |||||
:param save_dir: 保存的 directory | |||||
:param only_param: | |||||
:return: | |||||
""" | |||||
model_path = os.path.join(save_dir, model_name) | |||||
if not os.path.isdir(save_dir): | |||||
os.makedirs(save_dir, exist_ok=True) | |||||
if isinstance(model, nn.DataParallel): | |||||
model = model.module | |||||
if only_param: | |||||
state_dict = model.state_dict() | |||||
for key in state_dict: | |||||
state_dict[key] = state_dict[key].cpu() | |||||
torch.save(state_dict, model_path) | |||||
else: | |||||
_model_device = _get_model_device(model) | |||||
model.cpu() | |||||
torch.save(model, model_path) | |||||
model.to(_model_device) | |||||
class SaveModelCallback(Callback): | |||||
""" | |||||
由于Trainer在训练过程中只会保存最佳的模型, 该 callback 可实现多种方式的结果存储。 | |||||
会根据训练开始的时间戳在 save_dir 下建立文件夹,在再文件夹下存放多个模型 | |||||
-save_dir | |||||
-2019-07-03-15-06-36 | |||||
-epoch0step20{metric_key}{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||||
-epoch1step40 | |||||
-2019-07-03-15-10-00 | |||||
-epoch:0step:20{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | |||||
:param int top: 保存dev表现top多少模型。-1为保存所有模型 | |||||
:param bool only_param: 是否只保存模型权重 | |||||
:param save_on_exception: 发生exception时,是否保存一份当时的模型 | |||||
""" | |||||
def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False): | |||||
super().__init__() | |||||
if not os.path.isdir(save_dir): | |||||
raise IsADirectoryError("{} is not a directory.".format(save_dir)) | |||||
self.save_dir = save_dir | |||||
if top < 0: | |||||
self.top = sys.maxsize | |||||
else: | |||||
self.top = top | |||||
self._ordered_save_models = [] # List[Tuple], Tuple[0]是metric, Tuple[1]是path。metric是依次变好的,所以从头删 | |||||
self.only_param = only_param | |||||
self.save_on_exception = save_on_exception | |||||
def on_train_begin(self): | |||||
self.save_dir = os.path.join(self.save_dir, self.trainer.start_time) | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
metric_value = list(eval_result.values())[0][metric_key] | |||||
self._save_this_model(metric_value) | |||||
def _insert_into_ordered_save_models(self, pair): | |||||
# pair:(metric_value, model_name) | |||||
# 返回save的模型pair与删除的模型pair. pair中第一个元素是metric的值,第二个元素是模型的名称 | |||||
index = -1 | |||||
for _pair in self._ordered_save_models: | |||||
if _pair[0]>=pair[0] and self.trainer.increase_better: | |||||
break | |||||
if not self.trainer.increase_better and _pair[0]<=pair[0]: | |||||
break | |||||
index += 1 | |||||
save_pair = None | |||||
if len(self._ordered_save_models)<self.top or (len(self._ordered_save_models)>=self.top and index!=-1): | |||||
save_pair = pair | |||||
self._ordered_save_models.insert(index+1, pair) | |||||
delete_pair = None | |||||
if len(self._ordered_save_models)>self.top: | |||||
delete_pair = self._ordered_save_models.pop(0) | |||||
return save_pair, delete_pair | |||||
def _save_this_model(self, metric_value): | |||||
name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) | |||||
save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name)) | |||||
if save_pair: | |||||
try: | |||||
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | |||||
except Exception as e: | |||||
print(f"The following exception:{e} happens when saves model to {self.save_dir}.") | |||||
if delete_pair: | |||||
try: | |||||
delete_model_path = os.path.join(self.save_dir, delete_pair[1]) | |||||
if os.path.exists(delete_model_path): | |||||
os.remove(delete_model_path) | |||||
except Exception as e: | |||||
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.") | |||||
def on_exception(self, exception): | |||||
if self.save_on_exception: | |||||
name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__) | |||||
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | |||||
@@ -0,0 +1,157 @@ | |||||
from time import time | |||||
from datetime import timedelta | |||||
from fastNLP.io.dataset_loader import JsonLoader | |||||
from fastNLP.modules.encoder._bert import BertTokenizer | |||||
from fastNLP.io.base_loader import DataInfo | |||||
from fastNLP.core.const import Const | |||||
class BertData(JsonLoader): | |||||
def __init__(self, max_nsents=60, max_ntokens=100, max_len=512): | |||||
fields = {'article': 'article', | |||||
'label': 'label'} | |||||
super(BertData, self).__init__(fields=fields) | |||||
self.max_nsents = max_nsents | |||||
self.max_ntokens = max_ntokens | |||||
self.max_len = max_len | |||||
self.tokenizer = BertTokenizer.from_pretrained('/path/to/uncased_L-12_H-768_A-12') | |||||
self.cls_id = self.tokenizer.vocab['[CLS]'] | |||||
self.sep_id = self.tokenizer.vocab['[SEP]'] | |||||
self.pad_id = self.tokenizer.vocab['[PAD]'] | |||||
def _load(self, paths): | |||||
dataset = super(BertData, self)._load(paths) | |||||
return dataset | |||||
def process(self, paths): | |||||
def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens): | |||||
article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']] | |||||
return article[:max_nsents] | |||||
def truncate_labels(instance): | |||||
label = list(filter(lambda x: x < len(instance['article']), instance['label'])) | |||||
return label | |||||
def bert_tokenize(instance, tokenizer, max_len, pad_value): | |||||
article = instance['article'] | |||||
article = ' [SEP] [CLS] '.join(article) | |||||
word_pieces = tokenizer.tokenize(article)[:(max_len - 2)] | |||||
word_pieces = ['[CLS]'] + word_pieces + ['[SEP]'] | |||||
token_ids = tokenizer.convert_tokens_to_ids(word_pieces) | |||||
while len(token_ids) < max_len: | |||||
token_ids.append(pad_value) | |||||
assert len(token_ids) == max_len | |||||
return token_ids | |||||
def get_seg_id(instance, max_len, sep_id): | |||||
_segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id] | |||||
segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))] | |||||
segment_id = [] | |||||
for i, length in enumerate(segs): | |||||
if i % 2 == 0: | |||||
segment_id += length * [0] | |||||
else: | |||||
segment_id += length * [1] | |||||
while len(segment_id) < max_len: | |||||
segment_id.append(0) | |||||
return segment_id | |||||
def get_cls_id(instance, cls_id): | |||||
classification_id = [i for i, idx in enumerate(instance['article']) if idx == cls_id] | |||||
return classification_id | |||||
def get_labels(instance): | |||||
labels = [0] * len(instance['cls_id']) | |||||
label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label'])) | |||||
for idx in label_idx: | |||||
labels[idx] = 1 | |||||
return labels | |||||
datasets = {} | |||||
for name in paths: | |||||
datasets[name] = self._load(paths[name]) | |||||
# remove empty samples | |||||
datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0) | |||||
# truncate articles | |||||
datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article') | |||||
# truncate labels | |||||
datasets[name].apply(truncate_labels, new_field_name='label') | |||||
# tokenize and convert tokens to id | |||||
datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article') | |||||
# get segment id | |||||
datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id') | |||||
# get classification id | |||||
datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id') | |||||
# get label | |||||
datasets[name].apply(get_labels, new_field_name='label') | |||||
# rename filed | |||||
datasets[name].rename_field('article', Const.INPUTS(0)) | |||||
datasets[name].rename_field('segment_id', Const.INPUTS(1)) | |||||
datasets[name].rename_field('cls_id', Const.INPUTS(2)) | |||||
datasets[name].rename_field('lbael', Const.TARGET) | |||||
# set input and target | |||||
datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2)) | |||||
datasets[name].set_target(Const.TARGET) | |||||
# set paddding value | |||||
datasets[name].set_pad_val('article', 0) | |||||
return DataInfo(datasets=datasets) | |||||
class BertSumLoader(JsonLoader): | |||||
def __init__(self): | |||||
fields = {'article': 'article', | |||||
'segment_id': 'segment_id', | |||||
'cls_id': 'cls_id', | |||||
'label': Const.TARGET | |||||
} | |||||
super(BertSumLoader, self).__init__(fields=fields) | |||||
def _load(self, paths): | |||||
dataset = super(BertSumLoader, self)._load(paths) | |||||
return dataset | |||||
def process(self, paths): | |||||
def get_seq_len(instance): | |||||
return len(instance['article']) | |||||
print('Start loading datasets !!!') | |||||
start = time() | |||||
# load datasets | |||||
datasets = {} | |||||
for name in paths: | |||||
datasets[name] = self._load(paths[name]) | |||||
datasets[name].apply(get_seq_len, new_field_name='seq_len') | |||||
# set input and target | |||||
datasets[name].set_input('article', 'segment_id', 'cls_id') | |||||
datasets[name].set_target(Const.TARGET) | |||||
# set padding value | |||||
datasets[name].set_pad_val('article', 0) | |||||
datasets[name].set_pad_val('segment_id', 0) | |||||
datasets[name].set_pad_val('cls_id', -1) | |||||
datasets[name].set_pad_val(Const.TARGET, 0) | |||||
print('Finished in {}'.format(timedelta(seconds=time()-start))) | |||||
return DataInfo(datasets=datasets) |
@@ -0,0 +1,178 @@ | |||||
import numpy as np | |||||
import json | |||||
from os.path import join | |||||
import torch | |||||
import logging | |||||
import tempfile | |||||
import subprocess as sp | |||||
from datetime import timedelta | |||||
from time import time | |||||
from pyrouge import Rouge155 | |||||
from pyrouge.utils import log | |||||
from fastNLP.core.losses import LossBase | |||||
from fastNLP.core.metrics import MetricBase | |||||
_ROUGE_PATH = '/path/to/RELEASE-1.5.5' | |||||
class MyBCELoss(LossBase): | |||||
def __init__(self, pred=None, target=None, mask=None): | |||||
super(MyBCELoss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target, mask=mask) | |||||
self.loss_func = torch.nn.BCELoss(reduction='none') | |||||
def get_loss(self, pred, target, mask): | |||||
loss = self.loss_func(pred, target.float()) | |||||
loss = (loss * mask.float()).sum() | |||||
return loss | |||||
class LossMetric(MetricBase): | |||||
def __init__(self, pred=None, target=None, mask=None): | |||||
super(LossMetric, self).__init__() | |||||
self._init_param_map(pred=pred, target=target, mask=mask) | |||||
self.loss_func = torch.nn.BCELoss(reduction='none') | |||||
self.avg_loss = 0.0 | |||||
self.nsamples = 0 | |||||
def evaluate(self, pred, target, mask): | |||||
batch_size = pred.size(0) | |||||
loss = self.loss_func(pred, target.float()) | |||||
loss = (loss * mask.float()).sum() | |||||
self.avg_loss += loss | |||||
self.nsamples += batch_size | |||||
def get_metric(self, reset=True): | |||||
self.avg_loss = self.avg_loss / self.nsamples | |||||
eval_result = {'loss': self.avg_loss} | |||||
if reset: | |||||
self.avg_loss = 0 | |||||
self.nsamples = 0 | |||||
return eval_result | |||||
class RougeMetric(MetricBase): | |||||
def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None): | |||||
super(RougeMetric, self).__init__() | |||||
self._init_param_map(pred=pred, target=target, mask=mask) | |||||
self.data_path = data_path | |||||
self.dec_path = dec_path | |||||
self.ref_path = ref_path | |||||
self.n_total = n_total | |||||
self.n_ext = n_ext | |||||
self.ngram_block = ngram_block | |||||
self.cur_idx = 0 | |||||
self.ext = [] | |||||
self.start = time() | |||||
@staticmethod | |||||
def eval_rouge(dec_dir, ref_dir): | |||||
assert _ROUGE_PATH is not None | |||||
log.get_global_console_logger().setLevel(logging.WARNING) | |||||
dec_pattern = '(\d+).dec' | |||||
ref_pattern = '#ID#.ref' | |||||
cmd = '-c 95 -r 1000 -n 2 -m' | |||||
with tempfile.TemporaryDirectory() as tmp_dir: | |||||
Rouge155.convert_summaries_to_rouge_format( | |||||
dec_dir, join(tmp_dir, 'dec')) | |||||
Rouge155.convert_summaries_to_rouge_format( | |||||
ref_dir, join(tmp_dir, 'ref')) | |||||
Rouge155.write_config_static( | |||||
join(tmp_dir, 'dec'), dec_pattern, | |||||
join(tmp_dir, 'ref'), ref_pattern, | |||||
join(tmp_dir, 'settings.xml'), system_id=1 | |||||
) | |||||
cmd = (join(_ROUGE_PATH, 'ROUGE-1.5.5.pl') | |||||
+ ' -e {} '.format(join(_ROUGE_PATH, 'data')) | |||||
+ cmd | |||||
+ ' -a {}'.format(join(tmp_dir, 'settings.xml'))) | |||||
output = sp.check_output(cmd.split(' '), universal_newlines=True) | |||||
R_1 = float(output.split('\n')[3].split(' ')[3]) | |||||
R_2 = float(output.split('\n')[7].split(' ')[3]) | |||||
R_L = float(output.split('\n')[11].split(' ')[3]) | |||||
print(output) | |||||
return R_1, R_2, R_L | |||||
def evaluate(self, pred, target, mask): | |||||
pred = pred + mask.float() | |||||
pred = pred.cpu().data.numpy() | |||||
ext_ids = np.argsort(-pred, 1) | |||||
for sent_id in ext_ids: | |||||
self.ext.append(sent_id) | |||||
self.cur_idx += 1 | |||||
print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( | |||||
self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start)) | |||||
), end='') | |||||
def get_metric(self, use_ngram_block=True, reset=True): | |||||
def check_n_gram(sentence, n, dic): | |||||
tokens = sentence.split(' ') | |||||
s_len = len(tokens) | |||||
for i in range(s_len): | |||||
if i + n > s_len: | |||||
break | |||||
if ' '.join(tokens[i: i + n]) in dic: | |||||
return False | |||||
return True # no n_gram overlap | |||||
# load original data | |||||
data = [] | |||||
with open(self.data_path) as f: | |||||
for line in f: | |||||
cur_data = json.loads(line) | |||||
if 'text' in cur_data: | |||||
new_data = {} | |||||
new_data['article'] = cur_data['text'] | |||||
new_data['abstract'] = cur_data['summary'] | |||||
data.append(new_data) | |||||
else: | |||||
data.append(cur_data) | |||||
# write decode sentences and references | |||||
if use_ngram_block == True: | |||||
print('\nStart {}-gram blocking !!!'.format(self.ngram_block)) | |||||
for i, ext_ids in enumerate(self.ext): | |||||
dec, ref = [], [] | |||||
if use_ngram_block == False: | |||||
n_sent = min(len(data[i]['article']), self.n_ext) | |||||
for j in range(n_sent): | |||||
idx = ext_ids[j] | |||||
dec.append(data[i]['article'][idx]) | |||||
else: | |||||
n_sent = len(ext_ids) | |||||
dic = {} | |||||
for j in range(n_sent): | |||||
sent = data[i]['article'][ext_ids[j]] | |||||
if check_n_gram(sent, self.ngram_block, dic) == True: | |||||
dec.append(sent) | |||||
# update dic | |||||
tokens = sent.split(' ') | |||||
s_len = len(tokens) | |||||
for k in range(s_len): | |||||
if k + self.ngram_block > s_len: | |||||
break | |||||
dic[' '.join(tokens[k: k + self.ngram_block])] = 1 | |||||
if len(dec) >= self.n_ext: | |||||
break | |||||
for sent in data[i]['abstract']: | |||||
ref.append(sent) | |||||
with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f: | |||||
for sent in dec: | |||||
print(sent, file=f) | |||||
with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f: | |||||
for sent in ref: | |||||
print(sent, file=f) | |||||
print('\nStart evaluating ROUGE score !!!') | |||||
R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path) | |||||
eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L':R_L} | |||||
if reset == True: | |||||
self.cur_idx = 0 | |||||
self.ext = [] | |||||
self.start = time() | |||||
return eval_result |
@@ -0,0 +1,51 @@ | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import init | |||||
from fastNLP.modules.encoder._bert import BertModel | |||||
class Classifier(nn.Module): | |||||
def __init__(self, hidden_size): | |||||
super(Classifier, self).__init__() | |||||
self.linear = nn.Linear(hidden_size, 1) | |||||
self.sigmoid = nn.Sigmoid() | |||||
def forward(self, inputs, mask_cls): | |||||
h = self.linear(inputs).squeeze(-1) # [batch_size, seq_len] | |||||
sent_scores = self.sigmoid(h) * mask_cls.float() | |||||
return sent_scores | |||||
class BertSum(nn.Module): | |||||
def __init__(self, hidden_size=768): | |||||
super(BertSum, self).__init__() | |||||
self.hidden_size = hidden_size | |||||
self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12') | |||||
self.decoder = Classifier(self.hidden_size) | |||||
def forward(self, article, segment_id, cls_id): | |||||
# print(article.device) | |||||
# print(segment_id.device) | |||||
# print(cls_id.device) | |||||
input_mask = 1 - (article == 0) | |||||
mask_cls = 1 - (cls_id == -1) | |||||
assert input_mask.size() == article.size() | |||||
assert mask_cls.size() == cls_id.size() | |||||
bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask) | |||||
bert_out = bert_out[0][-1] # last layer | |||||
sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id] | |||||
sent_emb = sent_emb * mask_cls.unsqueeze(-1).float() | |||||
assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size) # [batch_size, seq_len, hidden_size] | |||||
sent_scores = self.decoder(sent_emb, mask_cls) # [batch_size, seq_len] | |||||
assert sent_scores.size() == (article.size(0), cls_id.size(1)) | |||||
return {'pred': sent_scores, 'mask': mask_cls} |
@@ -0,0 +1,147 @@ | |||||
import sys | |||||
import argparse | |||||
import os | |||||
import json | |||||
import torch | |||||
from time import time | |||||
from datetime import timedelta | |||||
from os.path import join, exists | |||||
from torch.optim import Adam | |||||
from utils import get_data_path, get_rouge_path | |||||
from dataloader import BertSumLoader | |||||
from model import BertSum | |||||
from fastNLP.core.optimizer import AdamW | |||||
from metrics import MyBCELoss, LossMetric, RougeMetric | |||||
from fastNLP.core.sampler import BucketSampler | |||||
from callback import MyCallback, SaveModelCallback | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.core.tester import Tester | |||||
def configure_training(args): | |||||
devices = [int(gpu) for gpu in args.gpus.split(',')] | |||||
params = {} | |||||
params['label_type'] = args.label_type | |||||
params['batch_size'] = args.batch_size | |||||
params['accum_count'] = args.accum_count | |||||
params['max_lr'] = args.max_lr | |||||
params['warmup_steps'] = args.warmup_steps | |||||
params['n_epochs'] = args.n_epochs | |||||
params['valid_steps'] = args.valid_steps | |||||
return devices, params | |||||
def train_model(args): | |||||
# check if the data_path and save_path exists | |||||
data_paths = get_data_path(args.mode, args.label_type) | |||||
for name in data_paths: | |||||
assert exists(data_paths[name]) | |||||
if not exists(args.save_path): | |||||
os.makedirs(args.save_path) | |||||
# load summarization datasets | |||||
datasets = BertSumLoader().process(data_paths) | |||||
print('Information of dataset is:') | |||||
print(datasets) | |||||
train_set = datasets.datasets['train'] | |||||
valid_set = datasets.datasets['val'] | |||||
# configure training | |||||
devices, train_params = configure_training(args) | |||||
with open(join(args.save_path, 'params.json'), 'w') as f: | |||||
json.dump(train_params, f, indent=4) | |||||
print('Devices is:') | |||||
print(devices) | |||||
# configure model | |||||
model = BertSum() | |||||
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0) | |||||
callbacks = [MyCallback(args), SaveModelCallback(args.save_path)] | |||||
criterion = MyBCELoss() | |||||
val_metric = [LossMetric()] | |||||
# sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size) | |||||
trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer, | |||||
loss=criterion, batch_size=args.batch_size, # sampler=sampler, | |||||
update_every=args.accum_count, n_epochs=args.n_epochs, | |||||
print_every=100, dev_data=valid_set, metrics=val_metric, | |||||
metric_key='-loss', validate_every=args.valid_steps, | |||||
save_path=args.save_path, device=devices, callbacks=callbacks) | |||||
print('Start training with the following hyper-parameters:') | |||||
print(train_params) | |||||
trainer.train() | |||||
def test_model(args): | |||||
models = os.listdir(args.save_path) | |||||
# load dataset | |||||
data_paths = get_data_path(args.mode, args.label_type) | |||||
datasets = BertSumLoader().process(data_paths) | |||||
print('Information of dataset is:') | |||||
print(datasets) | |||||
test_set = datasets.datasets['test'] | |||||
# only need 1 gpu for testing | |||||
device = int(args.gpus) | |||||
args.batch_size = 1 | |||||
for cur_model in models: | |||||
print('Current model is {}'.format(cur_model)) | |||||
# load model | |||||
model = torch.load(join(args.save_path, cur_model)) | |||||
# configure testing | |||||
original_path, dec_path, ref_path = get_rouge_path(args.label_type) | |||||
test_metric = RougeMetric(data_path=original_path, dec_path=dec_path, | |||||
ref_path=ref_path, n_total = len(test_set)) | |||||
tester = Tester(data=test_set, model=model, metrics=[test_metric], | |||||
batch_size=args.batch_size, device=device) | |||||
tester.test() | |||||
if __name__ == '__main__': | |||||
parser = argparse.ArgumentParser( | |||||
description='training/testing of BertSum(liu et al. 2019)' | |||||
) | |||||
parser.add_argument('--mode', required=True, | |||||
help='training or testing of BertSum', type=str) | |||||
parser.add_argument('--label_type', default='greedy', | |||||
help='greedy/limit', type=str) | |||||
parser.add_argument('--save_path', required=True, | |||||
help='root of the model', type=str) | |||||
# example for gpus input: '0,1,2,3' | |||||
parser.add_argument('--gpus', required=True, | |||||
help='available gpus for training(separated by commas)', type=str) | |||||
parser.add_argument('--batch_size', default=18, | |||||
help='the training batch size', type=int) | |||||
parser.add_argument('--accum_count', default=2, | |||||
help='number of updates steps to accumulate before performing a backward/update pass.', type=int) | |||||
parser.add_argument('--max_lr', default=2e-5, | |||||
help='max learning rate for warm up', type=float) | |||||
parser.add_argument('--warmup_steps', default=10000, | |||||
help='warm up steps for training', type=int) | |||||
parser.add_argument('--n_epochs', default=10, | |||||
help='total number of training epochs', type=int) | |||||
parser.add_argument('--valid_steps', default=1000, | |||||
help='number of update steps for checkpoint and validation', type=int) | |||||
args = parser.parse_args() | |||||
if args.mode == 'train': | |||||
print('Training process of BertSum !!!') | |||||
train_model(args) | |||||
else: | |||||
print('Testing process of BertSum !!!') | |||||
test_model(args) | |||||
@@ -0,0 +1,24 @@ | |||||
import os | |||||
from os.path import exists | |||||
def get_data_path(mode, label_type): | |||||
paths = {} | |||||
if mode == 'train': | |||||
paths['train'] = 'data/' + label_type + '/bert.train.jsonl' | |||||
paths['val'] = 'data/' + label_type + '/bert.val.jsonl' | |||||
else: | |||||
paths['test'] = 'data/' + label_type + '/bert.test.jsonl' | |||||
return paths | |||||
def get_rouge_path(label_type): | |||||
if label_type == 'others': | |||||
data_path = 'data/' + label_type + '/bert.test.jsonl' | |||||
else: | |||||
data_path = 'data/' + label_type + '/test.jsonl' | |||||
dec_path = 'dec' | |||||
ref_path = 'ref' | |||||
if not exists(ref_path): | |||||
os.makedirs(ref_path) | |||||
if not exists(dec_path): | |||||
os.makedirs(dec_path) | |||||
return data_path, dec_path, ref_path |