@@ -62,7 +62,7 @@ class ExtCNNDMPipe(Pipe):
         db.set_input(Const.INPUT, Const.INPUT_LEN)
         db.set_target(Const.TARGET, Const.INPUT_LEN)
-        print("[INFO] Load existing vocab from %s!" % self.vocab_path)
+        # print("[INFO] Load existing vocab from %s!" % self.vocab_path)
         word_list = []
         with open(self.vocab_path, 'r', encoding='utf8') as vocab_f:
             cnt = 2  # pad and unk
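
The loop above relies on a plain-text vocab file with one "word<TAB>count" entry per line, most frequent first; counting starts at 2 because [PAD] and [UNK] occupy the first two ids. A minimal standalone sketch of that convention (the helper name and file layout are illustrative assumptions, not part of this change):

def load_capped_word_list(vocab_path, vocab_size):
    word_list = []
    cnt = 2  # ids 0 and 1 are reserved for [PAD] and [UNK]
    with open(vocab_path, 'r', encoding='utf8') as vocab_f:
        for line in vocab_f:
            word_list.append(line.split("\t")[0])
            cnt += 1
            # stop once the running total (including the two reserved ids) exceeds vocab_size
            if cnt > vocab_size:
                break
    return word_list
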
@@ -1,188 +0,0 @@
-import pickle
-import numpy as np
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.data_bundle import DataBundle
-from fastNLP.io.dataset_loader import JsonLoader
-from fastNLP.core.const import Const
-from tools.logger import *
-WORD_PAD = "[PAD]"
-WORD_UNK = "[UNK]"
-DOMAIN_UNK = "X"
-TAG_UNK = "X"
-class SummarizationLoader(JsonLoader):
-    """
-    Load a summarization dataset. The loaded DataSet contains the fields::
-        text: list(str), document
-        summary: list(str), summary
-        text_wd: list(list(str)), tokenized document
-        summary_wd: list(list(str)), tokenized summary
-        labels: list(int)
-        flatten_label: list(int), 0 or 1, flattened labels
-        domain: str, optional
-        tag: list(str), optional
-    Data sources: CNN_DailyMail, Newsroom, DUC
-    """
-    def __init__(self):
-        super(SummarizationLoader, self).__init__()
-    def _load(self, path):
-        ds = super(SummarizationLoader, self)._load(path)
-        def _lower_text(text_list):
-            return [text.lower() for text in text_list]
-        def _split_list(text_list):
-            return [text.split() for text in text_list]
-        def _convert_label(label, sent_len):
-            np_label = np.zeros(sent_len, dtype=int)
-            if label != []:
-                np_label[np.array(label)] = 1
-            return np_label.tolist()
-        ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
-        ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
-        ds.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
-        ds.apply(lambda x: _split_list(x['summary']), new_field_name='summary_wd')
-        ds.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
-        return ds
-    def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
-        """
-        :param paths: dict  path for each dataset
-        :param vocab_size: int  max_size for vocab
-        :param vocab_path: str  vocab path
-        :param sent_max_len: int  max token number of the sentence
-        :param doc_max_timesteps: int  max sentence number of the document
-        :param domain: bool  build vocab for publication, use 'X' for unknown
-        :param tag: bool  build vocab for tag, use 'X' for unknown
-        :param load_vocab_file: bool  build vocab (False) or load vocab (True)
-        :return: DataBundle
-            datasets: dict  keys correspond to the paths dict
-            vocabs: dict  key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
-            embeddings: optional
-        """
-        def _pad_sent(text_wd):
-            pad_text_wd = []
-            for sent_wd in text_wd:
-                if len(sent_wd) < sent_max_len:
-                    pad_num = sent_max_len - len(sent_wd)
-                    sent_wd.extend([WORD_PAD] * pad_num)
-                else:
-                    sent_wd = sent_wd[:sent_max_len]
-                pad_text_wd.append(sent_wd)
-            return pad_text_wd
-        def _token_mask(text_wd):
-            token_mask_list = []
-            for sent_wd in text_wd:
-                token_num = len(sent_wd)
-                if token_num < sent_max_len:
-                    mask = [1] * token_num + [0] * (sent_max_len - token_num)
-                else:
-                    mask = [1] * sent_max_len
-                token_mask_list.append(mask)
-            return token_mask_list
-        def _pad_label(label):
-            text_len = len(label)
-            if text_len < doc_max_timesteps:
-                pad_label = label + [0] * (doc_max_timesteps - text_len)
-            else:
-                pad_label = label[:doc_max_timesteps]
-            return pad_label
-        def _pad_doc(text_wd):
-            text_len = len(text_wd)
-            if text_len < doc_max_timesteps:
-                padding = [WORD_PAD] * sent_max_len
-                pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
-            else:
-                pad_text = text_wd[:doc_max_timesteps]
-            return pad_text
-        def _sent_mask(text_wd):
-            text_len = len(text_wd)
-            if text_len < doc_max_timesteps:
-                sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
-            else:
-                sent_mask = [1] * doc_max_timesteps
-            return sent_mask
-        datasets = {}
-        train_ds = None
-        for key, value in paths.items():
-            ds = self.load(value)
-            # pad sent
-            ds.apply(lambda x: _pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
-            ds.apply(lambda x: _token_mask(x["text_wd"]), new_field_name="pad_token_mask")
-            # pad document
-            ds.apply(lambda x: _pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
-            ds.apply(lambda x: _sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
-            ds.apply(lambda x: _pad_label(x["flatten_label"]), new_field_name="pad_label")
-            # rename field
-            ds.rename_field("pad_text", Const.INPUT)
-            ds.rename_field("seq_len", Const.INPUT_LEN)
-            ds.rename_field("pad_label", Const.TARGET)
-            # set input and target
-            ds.set_input(Const.INPUT, Const.INPUT_LEN)
-            ds.set_target(Const.TARGET, Const.INPUT_LEN)
-            datasets[key] = ds
-            if "train" in key:
-                train_ds = datasets[key]
-        vocab_dict = {}
-        if load_vocab_file == False:
-            logger.info("[INFO] Build new vocab from training dataset!")
-            if train_ds == None:
-                raise ValueError("Lack train file to build vocabulary!")
-            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
-            vocabs.from_dataset(train_ds, field_name=["text_wd", "summary_wd"])
-            vocab_dict["vocab"] = vocabs
-        else:
-            logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
-            word_list = []
-            with open(vocab_path, 'r', encoding='utf8') as vocab_f:
-                cnt = 2  # pad and unk
-                for line in vocab_f:
-                    pieces = line.split("\t")
-                    word_list.append(pieces[0])
-                    cnt += 1
-                    if cnt > vocab_size:
-                        break
-            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
-            vocabs.add_word_lst(word_list)
-            vocabs.build_vocab()
-            vocab_dict["vocab"] = vocabs
-        if domain == True:
-            domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
-            domaindict.from_dataset(train_ds, field_name="publication")
-            vocab_dict["domain"] = domaindict
-        if tag == True:
-            tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
-            tagdict.from_dataset(train_ds, field_name="tag")
-            vocab_dict["tag"] = tagdict
-        for ds in datasets.values():
-            vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
-        return DataBundle(vocabs=vocab_dict, datasets=datasets)
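
This deleted SummarizationLoader is superseded by fastNLP's built-in ExtCNNDMPipe (patched in the first hunk and wired up in train.py below). A rough before/after sketch of the call sites; the paths and keyword values here are illustrative, not taken from the repo:

from fastNLP.io.pipe.summarization import ExtCNNDMPipe

# before (the loader removed above):
# db = SummarizationLoader().process(paths={"train": "train.jsonl", "valid": "valid.jsonl"},
#                                    vocab_size=100000, vocab_path="vocab", sent_max_len=100,
#                                    doc_max_timesteps=50, load_vocab_file=True)

# after (matching the usage in train.py below):
dbPipe = ExtCNNDMPipe(vocab_size=100000, vocab_path="vocab",
                      sent_max_len=100, doc_max_timesteps=50)
db = dbPipe.process_from_file({"train": "train.jsonl", "valid": "valid.jsonl"})
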
@@ -94,6 +94,8 @@ class Encoder(nn.Module):
         if self._hps.cuda:
             input_pos = input_pos.cuda()
         enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
+        # print(enc_embed_input.size())
+        # print(enc_pos_embed_input.size())
         enc_conv_input = enc_embed_input + enc_pos_embed_input
         enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
         enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
@@ -17,11 +17,12 @@ class SummarizationModel(nn.Module):
         """
         :param hps: hyperparameters for the model
-        :param vocab: vocab object
+        :param embed: word embedding
         """
         super(SummarizationModel, self).__init__()
         self._hps = hps
+        self.Train = (hps.mode == 'train')
         # sentence encoder
         self.encoder = Encoder(hps, embed)
@@ -45,18 +46,19 @@ class SummarizationModel(nn.Module):
         self.wh = nn.Linear(self.d_v, 2)
-    def forward(self, input, input_len, Train):
+    def forward(self, words, seq_len):
         """
         :param input: [batch_size, N, seq_len], word idx long tensor
         :param input_len: [batch_size, N], 1 for sentence and 0 for padding
-        :param Train: True for train and False for eval and test
-        :param return_atten: True or False to return multi-head attention output self.output_slf_attn
         :return:
             p_sent: [batch_size, N, 2]
             output_slf_attn: (option) [n_head, batch_size, N, N]
         """
+        input = words
+        input_len = seq_len
         # -- Sentence Encoder
         self.sent_embedding = self.encoder(input)  # [batch, N, Co * kernel_sizes]
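
Context for the signature change: fastNLP's Trainer matches batch fields to forward() parameters by name, so the parameter names must equal the DataSet field names exposed via set_input(). A small sketch of the mapping, with the constant values as defined in fastNLP.core.const:

from fastNLP.core.const import Const

# ds.set_input(Const.INPUT, Const.INPUT_LEN) exposes these two fields, and the
# Trainer then passes them to forward() by parameter name.
assert Const.INPUT == "words"        # -> def forward(self, words, ...)
assert Const.INPUT_LEN == "seq_len"  # -> ..., seq_len); doubles as the sentence mask
# The removed Train argument becomes the self.Train attribute, set once in
# __init__ from hps.mode.
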
@@ -67,7 +69,7 @@ class SummarizationModel(nn.Module):
         self.inputs[0] = self.sent_embedding.permute(1, 0, 2)  # [N, batch, Co * kernel_sizes]
         self.input_masks[0] = input_len.permute(1, 0).unsqueeze(2)
-        self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train)  # [batch, N, hidden_size]
+        self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train=self.Train)  # [batch, N, hidden_size]
         # -- Prepare masks
         batch_size, N = input_len.size()
@@ -21,7 +21,7 @@ import torch
 import torch.nn.functional as F
 from fastNLP.core.losses import LossBase
-from tools.logger import *
+from fastNLP.core._logger import logger
 class MyCrossEntropyLoss(LossBase):
     def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
@@ -20,14 +20,60 @@ from __future__ import division
 import torch
+import torch.nn.functional as F
 from rouge import Rouge
 from fastNLP.core.const import Const
 from fastNLP.core.metrics import MetricBase
-from tools.logger import *
+# from tools.logger import *
+from fastNLP.core._logger import logger
 from tools.utils import pyrouge_score_all, pyrouge_score_all_multi
+class LossMetric(MetricBase):
+    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
+        super().__init__()
+        self._init_param_map(pred=pred, target=target, mask=mask)
+        self.padding_idx = padding_idx
+        self.reduce = reduce
+        self.loss = 0.0
+        self.iteration = 0
+    def evaluate(self, pred, target, mask):
+        """
+        :param pred: [batch, N, 2]
+        :param target: [batch, N]
+        :param mask: [batch, N]
+        :return:
+        """
+        batch, N, _ = pred.size()
+        pred = pred.view(-1, 2)
+        target = target.view(-1)
+        loss = F.cross_entropy(input=pred, target=target,
+                               ignore_index=self.padding_idx, reduction=self.reduce)
+        loss = loss.view(batch, -1)
+        loss = loss.masked_fill(mask.eq(0), 0)
+        loss = loss.sum(1).mean()
+        self.loss += loss
+        self.iteration += 1
+    def get_metric(self, reset=True):
+        epoch_avg_loss = self.loss / self.iteration
+        if reset:
+            self.loss = 0.0
+            self.iteration = 0
+        metric = {"loss": -epoch_avg_loss}
+        logger.info(metric)
+        return metric
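
A standalone sketch of the masked loss that evaluate() accumulates, assuming reduce='none' is passed in (as run_training below does; with the default 'mean', F.cross_entropy returns a scalar and the .view(batch, -1) reshape would fail). The tensor sizes are made up for illustration:

import torch
import torch.nn.functional as F

pred = torch.randn(2, 4, 2)                        # [batch, N, 2] sentence logits
target = torch.randint(0, 2, (2, 4))               # [batch, N] 0/1 labels
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # [batch, N] 1 = real sentence

loss = F.cross_entropy(pred.view(-1, 2), target.view(-1), reduction='none')
loss = loss.view(2, -1).masked_fill(mask.eq(0), 0)  # zero out padded sentences
print(loss.sum(1).mean())                           # sum per document, mean over batch
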
 class LabelFMetric(MetricBase):
     def __init__(self, pred=None, target=None):
         super().__init__()
@@ -51,7 +51,7 @@ class TransformerModel(nn.Module):
             ffn_inner_hidden_size: FFN hidden size
             atten_dropout_prob: dropout size
             doc_max_timesteps: max sentence number of the document
-        :param vocab:
+        :param embed: word embedding
         """
         super(TransformerModel, self).__init__()
@@ -28,7 +28,7 @@ from fastNLP.core.const import Const
 from fastNLP.io.model_io import ModelSaver
 from fastNLP.core.callback import Callback, EarlyStopError
-from tools.logger import *
+from fastNLP.core._logger import logger
 class TrainCallback(Callback):
     def __init__(self, hps, patience=3, quit_all=True):
@@ -36,6 +36,9 @@ class TrainCallback(Callback):
         self._hps = hps
         self.patience = patience
         self.wait = 0
+        self.train_loss = 0.0
+        self.prev_train_avg_loss = 1000.0
+        self.train_dir = os.path.join(self._hps.save_root, "train")
         if type(quit_all) != bool:
             raise ValueError("In KeyBoardInterrupt, quit_all argument must be a bool.")
@@ -43,20 +46,7 @@ class TrainCallback(Callback):
     def on_epoch_begin(self):
         self.epoch_start_time = time.time()
-    # def on_loss_begin(self, batch_y, predict_y):
-    #     """
-    #
-    #     :param batch_y: dict
-    #         input_len: [batch, N]
-    #     :param predict_y: dict
-    #         p_sent: [batch, N, 2]
-    #     :return:
-    #     """
-    #     input_len = batch_y[Const.INPUT_LEN]
-    #     batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100)
-    #     # predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1)
-    #     # logger.debug(predict_y["p_sent"][0:5,:,:])
-        self.model.Train = True
     def on_backward_begin(self, loss):
         """
@@ -72,19 +62,34 @@ class TrainCallback(Callback):
                 logger.info(name)
                 logger.info(param.grad.data.sum())
             raise Exception("train Loss is not finite. Stopping.")
+        self.train_loss += loss.data
     def on_backward_end(self):
         if self._hps.grad_clip:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm)
+        torch.cuda.empty_cache()
     def on_epoch_end(self):
-        logger.info(' | end of epoch {:3d} | time: {:5.2f}s | '
-                    .format(self.epoch, (time.time() - self.epoch_start_time)))
+        epoch_avg_loss = self.train_loss / self.n_steps
+        logger.info(' | end of epoch {:3d} | time: {:5.2f}s | train loss: {:5.6f}'
+                    .format(self.epoch, (time.time() - self.epoch_start_time), epoch_avg_loss))
+        if self.prev_train_avg_loss < epoch_avg_loss:
+            save_file = os.path.join(self.train_dir, "earlystop.pkl")
+            self.save_model(save_file)
+        else:
+            self.prev_train_avg_loss = epoch_avg_loss
+        self.train_loss = 0.0
+        # save epoch
+        save_file = os.path.join(self.train_dir, "epoch_%d.pkl" % self.epoch)
+        self.save_model(save_file)
     def on_valid_begin(self):
         self.valid_start_time = time.time()
-        self.model.Train = False
     def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         logger.info(' | end of valid {:3d} | time: {:5.2f}s | '
@@ -95,9 +100,7 @@ class TrainCallback(Callback):
             if self.wait == self.patience:
                 train_dir = os.path.join(self._hps.save_root, "train")
                 save_file = os.path.join(train_dir, "earlystop.pkl")
-                saver = ModelSaver(save_file)
-                saver.save_pytorch(self.model)
-                logger.info('[INFO] Saving early stop model to %s', save_file)
+                self.save_model(save_file)
                 raise EarlyStopError("Early stopping raised.")
             else:
                 self.wait += 1
@@ -111,14 +114,12 @@ class TrainCallback(Callback):
                 param_group['lr'] = new_lr
             logger.info("[INFO] The learning rate now is %f", new_lr)
     def on_exception(self, exception):
         if isinstance(exception, KeyboardInterrupt):
             logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
-            train_dir = os.path.join(self._hps.save_root, "train")
-            save_file = os.path.join(train_dir, "earlystop.pkl")
-            saver = ModelSaver(save_file)
-            saver.save_pytorch(self.model)
-            logger.info('[INFO] Saving early stop model to %s', save_file)
+            save_file = os.path.join(self.train_dir, "earlystop.pkl")
+            self.save_model(save_file)
             if self.quit_all is True:
                 sys.exit(0)  # exit the program directly
@@ -127,6 +128,11 @@ class TrainCallback(Callback):
             else:
                 raise exception  # re-raise unrecognized exceptions
+    def save_model(self, save_file):
+        saver = ModelSaver(save_file)
+        saver.save_pytorch(self.model)
+        logger.info('[INFO] Saving model to %s', save_file)
@@ -1,562 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import *
-import torch.nn.init as init
-import data
-from tools.logger import *
-from transformer.Models import get_sinusoid_encoding_table
-class Encoder(nn.Module):
-    def __init__(self, hps, vocab):
-        super(Encoder, self).__init__()
-        self._hps = hps
-        self._vocab = vocab
-        self.sent_max_len = hps.sent_max_len
-        vocab_size = len(vocab)
-        logger.info("[INFO] Vocabulary size is %d", vocab_size)
-        embed_size = hps.word_emb_dim
-        sent_max_len = hps.sent_max_len
-        input_channels = 1
-        out_channels = hps.output_channel
-        min_kernel_size = hps.min_kernel_size
-        max_kernel_size = hps.max_kernel_size
-        width = embed_size
-        # word embedding
-        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=vocab.word2id('[PAD]'))
-        if hps.word_embedding:
-            word2vec = data.Word_Embedding(hps.embedding_path, vocab)
-            word_vecs = word2vec.load_my_vecs(embed_size)
-            # pretrained_weight = word2vec.add_unknown_words_by_zero(word_vecs, embed_size)
-            pretrained_weight = word2vec.add_unknown_words_by_avg(word_vecs, embed_size)
-            pretrained_weight = np.array(pretrained_weight)
-            self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
-            self.embed.weight.requires_grad = hps.embed_train
-        # position embedding
-        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
-        # cnn
-        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
-        logger.info("[INFO] Initing W for CNN.......")
-        for conv in self.convs:
-            init_weight_value = 6.0
-            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
-            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
-            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
-    def calculate_fan_in_and_fan_out(tensor):
-        dimensions = tensor.ndimension()
-        if dimensions < 2:
-            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-        if dimensions == 2:  # Linear
-            fan_in = tensor.size(1)
-            fan_out = tensor.size(0)
-        else:
-            num_input_fmaps = tensor.size(1)
-            num_output_fmaps = tensor.size(0)
-            receptive_field_size = 1
-            if tensor.dim() > 2:
-                receptive_field_size = tensor[0][0].numel()
-            fan_in = num_input_fmaps * receptive_field_size
-            fan_out = num_output_fmaps * receptive_field_size
-        return fan_in, fan_out
-    def forward(self, input):
-        # input: a batch of Example object [batch_size, N, seq_len]
-        vocab = self._vocab
-        batch_size, N, _ = input.size()
-        input = input.view(-1, input.size(2))  # [batch_size*N, L]
-        input_sent_len = ((input != vocab.word2id('[PAD]')).sum(dim=1)).int()  # [batch_size*N, 1]
-        enc_embed_input = self.embed(input)  # [batch_size*N, L, D]
-        input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
-        if self._hps.cuda:
-            input_pos = input_pos.cuda()
-        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
-        enc_conv_input = enc_embed_input + enc_pos_embed_input
-        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
-        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
-        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
-        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
-        sent_embedding = sent_embedding.view(batch_size, N, -1)
-        return sent_embedding
-class DomainEncoder(Encoder):
-    def __init__(self, hps, vocab, domaindict):
-        super(DomainEncoder, self).__init__(hps, vocab)
-        # domain embedding
-        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
-        self.domain_embedding.weight.requires_grad = True
-    def forward(self, input, domain):
-        """
-        :param input: [batch_size, N, seq_len], N sentence number, seq_len token number
-        :param domain: [batch_size]
-        :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
-        """
-        batch_size, N, _ = input.size()
-        sent_embedding = super().forward(input)
-        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
-        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
-        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
-        return sent_embedding
-class MultiDomainEncoder(Encoder):
-    def __init__(self, hps, vocab, domaindict):
-        super(MultiDomainEncoder, self).__init__(hps, vocab)
-        self.domain_size = domaindict.size()
-        # domain embedding
-        self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim)
-        self.domain_embedding.weight.requires_grad = True
-    def forward(self, input, domain):
-        """
-        :param input: [batch_size, N, seq_len], N sentence number, seq_len token number
-        :param domain: [batch_size, domain_size]
-        :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
-        """
-        batch_size, N, _ = input.size()
-        # logger.info(domain[:5, :])
-        sent_embedding = super().forward(input)
-        domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1)
-        domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1)  # [batch * domain_size]
-        enc_domain_input = self.domain_embedding(domain_padding)  # [batch * domain_size, D]
-        enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float()  # [batch, domain_size, D]
-        # logger.info(enc_domain_input[:5, :])  # [batch, domain_size, D]
-        enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1)  # [batch, D]
-        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
-        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
-        return sent_embedding
-class BertEncoder(nn.Module):
-    def __init__(self, hps):
-        super(BertEncoder, self).__init__()
-        from pytorch_pretrained_bert.modeling import BertModel
-        self._hps = hps
-        self.sent_max_len = hps.sent_max_len
-        self._cuda = hps.cuda
-        embed_size = hps.word_emb_dim
-        sent_max_len = hps.sent_max_len
-        input_channels = 1
-        out_channels = hps.output_channel
-        min_kernel_size = hps.min_kernel_size
-        max_kernel_size = hps.max_kernel_size
-        width = embed_size
-        # word embedding
-        self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16")
-        self._bert.eval()
-        for p in self._bert.parameters():
-            p.requires_grad = False
-        self.word_embedding_proj = nn.Linear(4096, embed_size)
-        # position embedding
-        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
-        # cnn
-        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
-        logger.info("[INFO] Initing W for CNN.......")
-        for conv in self.convs:
-            init_weight_value = 6.0
-            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
-            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
-            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
-    def calculate_fan_in_and_fan_out(tensor):
-        dimensions = tensor.ndimension()
-        if dimensions < 2:
-            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-        if dimensions == 2:  # Linear
-            fan_in = tensor.size(1)
-            fan_out = tensor.size(0)
-        else:
-            num_input_fmaps = tensor.size(1)
-            num_output_fmaps = tensor.size(0)
-            receptive_field_size = 1
-            if tensor.dim() > 2:
-                receptive_field_size = tensor[0][0].numel()
-            fan_in = num_input_fmaps * receptive_field_size
-            fan_out = num_output_fmaps * receptive_field_size
-        return fan_in, fan_out
-    def pad_encoder_input(self, input_list):
-        """
-        :param input_list: N [seq_len, hidden_state]
-        :return: enc_sent_input_pad: list, N [max_len, hidden_state]
-        """
-        max_len = self.sent_max_len
-        enc_sent_input_pad = []
-        _, hidden_size = input_list[0].size()
-        for i in range(len(input_list)):
-            article_words = input_list[i]  # [seq_len, hidden_size]
-            seq_len = article_words.size(0)
-            if seq_len > max_len:
-                pad_words = article_words[:max_len, :]
-            else:
-                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
-                pad_words = torch.cat([article_words, pad_tensor], dim=0)
-            enc_sent_input_pad.append(pad_words)
-        return enc_sent_input_pad
-    def forward(self, inputs, input_masks, enc_sent_len):
-        """
-        :param inputs: a batch of Example object [batch_size, doc_len=512]
-        :param input_masks: 0 or 1, [batch, doc_len=512]
-        :param enc_sent_len: sentence original length [batch, N]
-        :return:
-        """
-        # Use Bert to get word embedding
-        batch_size, N = enc_sent_len.size()
-        input_pad_list = []
-        for i in range(batch_size):
-            tokens_id = inputs[i]
-            input_mask = input_masks[i]
-            sent_len = enc_sent_len[i]
-            input_ids = tokens_id.unsqueeze(0)
-            input_mask = input_mask.unsqueeze(0)
-            out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask)
-            out = torch.cat(out[-4:], dim=-1).squeeze(0)  # [doc_len=512, hidden_state=4096]
-            _, hidden_size = out.size()
-            # restore the sentence
-            last_end = 1
-            enc_sent_input = []
-            for length in sent_len:
-                if length != 0 and last_end < 511:
-                    enc_sent_input.append(out[last_end: min(511, last_end + length), :])
-                    last_end += length
-                else:
-                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
-                    enc_sent_input.append(pad_tensor)
-            # pad the sentence
-            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # [N, seq_len, hidden_state=4096]
-            input_pad_list.append(torch.stack(enc_sent_input_pad))
-        input_pad = torch.stack(input_pad_list)
-        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
-        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]
-        enc_embed_input = self.word_embedding_proj(input_pad)  # [batch_size * N, L, D]
-        sent_pos_list = []
-        for sentlen in enc_sent_len:
-            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
-            for k in range(self.sent_max_len - sentlen):
-                sent_pos.append(0)
-            sent_pos_list.append(sent_pos)
-        input_pos = torch.Tensor(sent_pos_list).long()
-        if self._hps.cuda:
-            input_pos = input_pos.cuda()
-        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
-        enc_conv_input = enc_embed_input + enc_pos_embed_input
-        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
-        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
-        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
-        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
-        sent_embedding = sent_embedding.view(batch_size, N, -1)
-        return sent_embedding
-class BertTagEncoder(BertEncoder):
-    def __init__(self, hps, domaindict):
-        super(BertTagEncoder, self).__init__(hps)
-        # domain embedding
-        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
-        self.domain_embedding.weight.requires_grad = True
-    def forward(self, inputs, input_masks, enc_sent_len, domain):
-        sent_embedding = super().forward(inputs, input_masks, enc_sent_len)
-        batch_size, N = enc_sent_len.size()
-        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
-        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
-        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
-        return sent_embedding
-class ELMoEndoer(nn.Module):
-    def __init__(self, hps):
-        super(ELMoEndoer, self).__init__()
-        self._hps = hps
-        self.sent_max_len = hps.sent_max_len
-        from allennlp.modules.elmo import Elmo
-        elmo_dim = 1024
-        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
-        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
-        # elmo_dim = 512
-        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
-        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
-        embed_size = hps.word_emb_dim
-        sent_max_len = hps.sent_max_len
-        input_channels = 1
-        out_channels = hps.output_channel
-        min_kernel_size = hps.min_kernel_size
-        max_kernel_size = hps.max_kernel_size
-        width = embed_size
-        # elmo embedding
-        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
-        self.embed_proj = nn.Linear(elmo_dim, embed_size)
-        # position embedding
-        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
-        # cnn
-        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
-        logger.info("[INFO] Initing W for CNN.......")
-        for conv in self.convs:
-            init_weight_value = 6.0
-            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
-            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
-            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
-    def calculate_fan_in_and_fan_out(tensor):
-        dimensions = tensor.ndimension()
-        if dimensions < 2:
-            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-        if dimensions == 2:  # Linear
-            fan_in = tensor.size(1)
-            fan_out = tensor.size(0)
-        else:
-            num_input_fmaps = tensor.size(1)
-            num_output_fmaps = tensor.size(0)
-            receptive_field_size = 1
-            if tensor.dim() > 2:
-                receptive_field_size = tensor[0][0].numel()
-            fan_in = num_input_fmaps * receptive_field_size
-            fan_out = num_output_fmaps * receptive_field_size
-        return fan_in, fan_out
-    def forward(self, input):
-        # input: a batch of Example object [batch_size, N, seq_len, character_len]
-        batch_size, N, seq_len, _ = input.size()
-        input = input.view(batch_size * N, seq_len, -1)  # [batch_size*N, seq_len, character_len]
-        input_sent_len = ((input.sum(-1) != 0).sum(dim=1)).int()  # [batch_size*N, 1]
-        logger.debug(input_sent_len.view(batch_size, -1))
-        enc_embed_input = self.elmo(input)['elmo_representations'][0]  # [batch_size*N, L, D]
-        enc_embed_input = self.embed_proj(enc_embed_input)
-        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
-        sent_pos_list = []
-        for sentlen in input_sent_len:
-            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
-            for k in range(self.sent_max_len - sentlen):
-                sent_pos.append(0)
-            sent_pos_list.append(sent_pos)
-        input_pos = torch.Tensor(sent_pos_list).long()
-        if self._hps.cuda:
-            input_pos = input_pos.cuda()
-        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
-        enc_conv_input = enc_embed_input + enc_pos_embed_input
-        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
-        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
-        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
-        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
-        sent_embedding = sent_embedding.view(batch_size, N, -1)
-        return sent_embedding
-class ELMoEndoer2(nn.Module):
-    def __init__(self, hps):
-        super(ELMoEndoer2, self).__init__()
-        self._hps = hps
-        self._cuda = hps.cuda
-        self.sent_max_len = hps.sent_max_len
-        from allennlp.modules.elmo import Elmo
-        elmo_dim = 1024
-        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
-        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
-        # elmo_dim = 512
-        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
-        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
-        embed_size = hps.word_emb_dim
-        sent_max_len = hps.sent_max_len
-        input_channels = 1
-        out_channels = hps.output_channel
-        min_kernel_size = hps.min_kernel_size
-        max_kernel_size = hps.max_kernel_size
-        width = embed_size
-        # elmo embedding
-        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
-        self.embed_proj = nn.Linear(elmo_dim, embed_size)
-        # position embedding
-        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
-        # cnn
-        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
-        logger.info("[INFO] Initing W for CNN.......")
-        for conv in self.convs:
-            init_weight_value = 6.0
-            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
-            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
-            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
-    def calculate_fan_in_and_fan_out(tensor):
-        dimensions = tensor.ndimension()
-        if dimensions < 2:
-            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
-        if dimensions == 2:  # Linear
-            fan_in = tensor.size(1)
-            fan_out = tensor.size(0)
-        else:
-            num_input_fmaps = tensor.size(1)
-            num_output_fmaps = tensor.size(0)
-            receptive_field_size = 1
-            if tensor.dim() > 2:
-                receptive_field_size = tensor[0][0].numel()
-            fan_in = num_input_fmaps * receptive_field_size
-            fan_out = num_output_fmaps * receptive_field_size
-        return fan_in, fan_out
-    def pad_encoder_input(self, input_list):
-        """
-        :param input_list: N [seq_len, hidden_state]
-        :return: enc_sent_input_pad: list, N [max_len, hidden_state]
-        """
-        max_len = self.sent_max_len
-        enc_sent_input_pad = []
-        _, hidden_size = input_list[0].size()
-        for i in range(len(input_list)):
-            article_words = input_list[i]  # [seq_len, hidden_size]
-            seq_len = article_words.size(0)
-            if seq_len > max_len:
-                pad_words = article_words[:max_len, :]
-            else:
-                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
-                pad_words = torch.cat([article_words, pad_tensor], dim=0)
-            enc_sent_input_pad.append(pad_words)
-        return enc_sent_input_pad
-    def forward(self, inputs, input_masks, enc_sent_len):
-        """
-        :param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50]
-        :param input_masks: 0 or 1, [batch, doc_len=512]
-        :param enc_sent_len: sentence original length [batch, N]
-        :return:
-            sent_embedding: [batch, N, D]
-        """
-        # Use ELMo to get word embedding
-        batch_size, N = enc_sent_len.size()
-        input_pad_list = []
-        elmo_output = self.elmo(inputs)['elmo_representations'][0]  # [batch_size, 512, D]
-        elmo_output = elmo_output * input_masks.unsqueeze(-1).float()
-        # print("END elmo")
-        for i in range(batch_size):
-            sent_len = enc_sent_len[i]  # [1, N]
-            out = elmo_output[i]
-            _, hidden_size = out.size()
-            # restore the sentence
-            last_end = 0
-            enc_sent_input = []
-            for length in sent_len:
-                if length != 0 and last_end < 512:
-                    enc_sent_input.append(out[last_end: min(512, last_end + length), :])
-                    last_end += length
-                else:
-                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
-                    enc_sent_input.append(pad_tensor)
-            # pad the sentence
-            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # [N, seq_len, hidden_state=4096]
-            input_pad_list.append(torch.stack(enc_sent_input_pad))  # batch * [N, max_len, hidden_state]
-        input_pad = torch.stack(input_pad_list)
-        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
-        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]
-        enc_embed_input = self.embed_proj(input_pad)  # [batch_size * N, L, D]
-        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
-        sent_pos_list = []
-        for sentlen in enc_sent_len:
-            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
-            for k in range(self.sent_max_len - sentlen):
-                sent_pos.append(0)
-            sent_pos_list.append(sent_pos)
-        input_pos = torch.Tensor(sent_pos_list).long()
-        if self._hps.cuda:
-            input_pos = input_pos.cuda()
-        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
-        enc_conv_input = enc_embed_input + enc_pos_embed_input
-        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
-        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
-        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
-        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
-        sent_embedding = sent_embedding.view(batch_size, N, -1)
-        return sent_embedding
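
Every encoder in this deleted file shares the same convolve-and-pool block over a [batch*N, 1, L, D] input. A shape walkthrough with made-up sizes (an illustrative sketch, not code from the repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, L, D, Co = 2, 3, 20, 8, 5               # batch, sentences, tokens, emb dim, out channels
convs = nn.ModuleList([nn.Conv2d(1, Co, (h, D)) for h in range(2, 4)])  # kernel heights 2..3

x = torch.randn(B * N, L, D).unsqueeze(1)     # (B*N, Ci=1, L, D)
feats = [F.relu(c(x)).squeeze(3) for c in convs]                 # each (B*N, Co, L-h+1)
pooled = [F.max_pool1d(f, f.size(2)).squeeze(2) for f in feats]  # each (B*N, Co)
sent = torch.cat(pooled, 1).view(B, N, -1)    # (B, N, Co * num_kernel_sizes)
print(sent.shape)                             # torch.Size([2, 3, 10])
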
@@ -21,6 +21,7 @@
 import os
 import sys
 import json
+import shutil
 import argparse
 import datetime
@@ -32,20 +33,25 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/')
+from fastNLP.core._logger import logger
+# from fastNLP.core._logger import _init_logger
 from fastNLP.core.const import Const
 from fastNLP.core.trainer import Trainer, Tester
 from fastNLP.io.pipe.summarization import ExtCNNDMPipe
 from fastNLP.io.model_io import ModelLoader, ModelSaver
 from fastNLP.io.embed_loader import EmbedLoader
-from tools.logger import *
+# from tools.logger import *
 # from model.TransformerModel import TransformerModel
 from model.TForiginal import TransformerModel
-from model.Metric import LabelFMetric, FastRougeMetric, PyRougeMetric
+from model.LSTMModel import SummarizationModel
+from model.Metric import LossMetric, LabelFMetric, FastRougeMetric, PyRougeMetric
 from model.Loss import MyCrossEntropyLoss
 from tools.Callback import TrainCallback
 def setup_training(model, train_loader, valid_loader, hps):
     """Does setup before starting training (run_training)"""
@@ -60,32 +66,23 @@ def setup_training(model, train_loader, valid_loader, hps):
     else:
         logger.info("[INFO] Create new model for training...")
-    try:
-        run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted
-    except KeyboardInterrupt:
-        logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
-        save_file = os.path.join(train_dir, "earlystop.pkl")
-        saver = ModelSaver(save_file)
-        saver.save_pytorch(model)
-        logger.info('[INFO] Saving early stop model to %s', save_file)
+    run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted
 def run_training(model, train_loader, valid_loader, hps):
-    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
     logger.info("[INFO] Starting run_training")
     train_dir = os.path.join(hps.save_root, "train")
-    if not os.path.exists(train_dir): os.makedirs(train_dir)
+    if os.path.exists(train_dir): shutil.rmtree(train_dir)
+    os.makedirs(train_dir)
     eval_dir = os.path.join(hps.save_root, "eval")  # make a subdir of the root dir for eval data
     if not os.path.exists(eval_dir): os.makedirs(eval_dir)
-    lr = hps.lr
-    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
+    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=hps.lr)
     criterion = MyCrossEntropyLoss(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none')
-    # criterion = torch.nn.CrossEntropyLoss(reduce="none")
     trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion,
-                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
-                      metric_key="f", validate_every=-1, save_path=eval_dir,
+                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LossMetric(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none'), LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
+                      metric_key="loss", validate_every=-1, save_path=eval_dir,
                       callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False)
     train_info = trainer.train(load_best_model=True)
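
A note on metric_key="loss": fastNLP's Trainer treats a larger value of the chosen metric as better when picking checkpoints, which is why LossMetric.get_metric() above reports the negated average loss rather than the raw one. A tiny sketch of the comparison under that assumption:

best = {"loss": -0.71}     # from LossMetric.get_metric() at an earlier validation
current = {"loss": -0.64}  # a later validation with lower raw loss (0.64 < 0.71)
print(current["loss"] > best["loss"])  # True: the lower-loss checkpoint wins
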
@@ -98,8 +95,8 @@ def run_training(model, train_loader, valid_loader, hps):
     saver.save_pytorch(model)
     logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path)
-def run_test(model, loader, hps, limited=False):
-    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
+def run_test(model, loader, hps):
     test_dir = os.path.join(hps.save_root, "test")  # make a subdir of the root dir for eval data
     eval_dir = os.path.join(hps.save_root, "eval")
     if not os.path.exists(test_dir): os.makedirs(test_dir)
@@ -113,8 +110,8 @@ def run_test(model, loader, hps, limited=False):
         train_dir = os.path.join(hps.save_root, "train")
         bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
     else:
-        logger.error("No such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
-        raise ValueError("No such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
+        logger.error("No such model! Must be one of evalbestmodel/earlystop")
+        raise ValueError("No such model! Must be one of evalbestmodel/earlystop")
     logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path)
     modelloader = ModelLoader()
@@ -174,13 +171,11 @@ def main():
     # Training
     parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
     parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent')
-    parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps')
     parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping')
     parser.add_argument('--max_grad_norm', type=float, default=10, help='for gradient clipping max gradient normalization')
     # test
     parser.add_argument('-m', type=int, default=3, help='decode summary length')
-    parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length')
     parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]')
     parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge')
@@ -195,21 +190,22 @@ def main():
     VOCAL_FILE = args.vocab_path
     LOG_PATH = args.log_root
-    # train_log setting
+    # # train_log setting
     if not os.path.exists(LOG_PATH):
         if args.mode == "train":
             os.makedirs(LOG_PATH)
         else:
+            logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH)
             raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH))
     nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
     log_path = os.path.join(LOG_PATH, args.mode + "_" + nowTime)
-    file_handler = logging.FileHandler(log_path)
-    file_handler.setFormatter(formatter)
-    logger.addHandler(file_handler)
+    # logger = _init_logger(path=log_path)
+    # file_handler = logging.FileHandler(log_path)
+    # file_handler.setFormatter(formatter)
+    # logger.addHandler(file_handler)
     logger.info("Pytorch %s", torch.__version__)
+    # dataset
     hps = args
     dbPipe = ExtCNNDMPipe(vocab_size=hps.vocab_size,
                           vocab_path=VOCAL_FILE,
@@ -225,6 +221,8 @@ def main():
     paths = {"train": DATA_FILE, "valid": VALID_FILE}
     db = dbPipe.process_from_file(paths)
+    # embedding
     if args.embedding == "glove":
         vocab = db.get_vocab("vocab")
         embed = torch.nn.Embedding(len(vocab), hps.word_emb_dim)
@@ -237,19 +235,24 @@ def main():
         logger.error("[ERROR] embedding To Be Continued!")
         sys.exit(1)
+    # model
     if args.sentence_encoder == "transformer" and args.sentence_decoder == "SeqLab":
         model_param = json.load(open("config/transformer.config", "rb"))
         hps.__dict__.update(model_param)
         model = TransformerModel(hps, embed)
+    elif args.sentence_encoder == "deeplstm" and args.sentence_decoder == "SeqLab":
+        model_param = json.load(open("config/deeplstm.config", "rb"))
+        hps.__dict__.update(model_param)
+        model = SummarizationModel(hps, embed)
     else:
         logger.error("[ERROR] Model To Be Continued!")
         sys.exit(1)
-    logger.info(hps)
     if hps.cuda:
         model = model.cuda()
         logger.info("[INFO] Use cuda")
+    logger.info(hps)
     if hps.mode == 'train':
         db.get_dataset("valid").set_target("text", "summary")
         setup_training(model, db.get_dataset("train"), db.get_dataset("valid"), hps)