@@ -134,8 +134,8 @@ class Trainer(object):
         # main training procedure
         start = time.time()
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
-        print("training epochs started " + self.start_time)
         logger.info("training epochs started " + self.start_time)
         epoch, iters = 1, 0
         while(1):
@@ -51,6 +51,12 @@ class Vocabulary(object):
         self.min_freq = min_freq
         self.word_count = {}
         self.has_default = need_default
+        if self.has_default:
+            self.padding_label = DEFAULT_PADDING_LABEL
+            self.unknown_label = DEFAULT_UNKNOWN_LABEL
+        else:
+            self.padding_label = None
+            self.unknown_label = None
         self.word2idx = None
         self.idx2word = None
@@ -77,12 +83,10 @@ class Vocabulary(object):
         """
         if self.has_default:
             self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
-            self.padding_label = DEFAULT_PADDING_LABEL
-            self.unknown_label = DEFAULT_UNKNOWN_LABEL
+            self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
+            self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
         else:
             self.word2idx = {}
-            self.padding_label = None
-            self.unknown_label = None
         words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
         if self.min_freq is not None:
@@ -135,9 +139,9 @@ class Vocabulary(object):
         return self.word2idx[self.unknown_label]

     def __setattr__(self, name, val):
-        if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
-            self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
         self.__dict__[name] = val
+        if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
+            self.word2idx = None

     @property
     @check_build_vocab
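With the default labels moved into `__init__`, renaming `unknown_label` or `padding_label` after construction now simply drops the cached index so it is rebuilt (and the default entries remapped) on next use. A minimal sketch of that lazy-invalidation pattern — this is not the fastNLP class, names and details are illustrative:

```python
class TinyVocab:
    def __init__(self):
        self.unknown_label = "<unk>"
        self.padding_label = "<pad>"
        self.word2idx = None                      # built lazily

    def __setattr__(self, name, val):
        self.__dict__[name] = val
        # renaming a special label invalidates the cached index
        if name in ("unknown_label", "padding_label") and self.__dict__.get("word2idx") is not None:
            self.__dict__["word2idx"] = None

    def build(self):
        if self.word2idx is None:
            self.word2idx = {self.padding_label: 0, self.unknown_label: 1}
        return self.word2idx

v = TinyVocab()
print(v.build())             # {'<pad>': 0, '<unk>': 1}
v.unknown_label = "<OOV>"    # drops the stale index
print(v.build())             # {'<pad>': 0, '<OOV>': 1}
```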
@@ -16,10 +16,9 @@ def mst(scores):
    https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
    """
    length = scores.shape[0]
-    min_score = -np.inf
-    mask = np.zeros((length, length))
-    np.fill_diagonal(mask, -np.inf)
-    scores = scores + mask
+    min_score = scores.min() - 1
+    eye = np.eye(length)
+    scores = scores * (1 - eye) + min_score * eye
    heads = np.argmax(scores, axis=1)
    heads[0] = 0
    tokens = np.arange(1, length)
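The diagonal is now forced below every real arc score with a finite `min_score` instead of an additive `-np.inf` mask, presumably so no infinities flow into the later Chu-Liu/Edmonds arithmetic, where they can surface as `nan`. A small numpy illustration of the substitution (toy scores, same assumption that a token must never pick itself as head):

```python
import numpy as np

scores = np.array([[0.9, 0.1, 0.3],
                   [0.2, 0.8, 0.5],
                   [0.4, 0.6, 0.7]])

length = scores.shape[0]
min_score = scores.min() - 1              # strictly worse than any real arc
eye = np.eye(length)
masked = scores * (1 - eye) + min_score * eye

print(masked)                             # diagonal replaced by -0.9, everything finite
print(np.argmax(masked, axis=1))          # no token selects itself as head
```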
@@ -126,6 +125,8 @@ class GraphParser(nn.Module):
     def _greedy_decoder(self, arc_matrix, seq_mask=None):
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
+        flip_mask = (seq_mask == 0).byte()
+        matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         _, heads = torch.max(matrix, dim=2)
         if seq_mask is not None:
             heads *= seq_mask.long()
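The greedy decoder now also blanks out padded columns, so a real token can never choose a PAD position as its head (the `.byte()` cast is for the uint8 masks older PyTorch expected). A hedged stand-alone sketch of the same masking step, using a modern bool mask and illustrative shapes:

```python
import torch

arc_matrix = torch.randn(2, 4, 4)               # [batch, seq_len, seq_len]
seq_mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])          # 1 = real token, 0 = PAD

matrix = arc_matrix.clone()
matrix += torch.diag(matrix.new_full((4,), float("-inf")))   # forbid self-arcs
flip_mask = seq_mask == 0
matrix.masked_fill_(flip_mask.unsqueeze(1), float("-inf"))   # forbid PAD heads

heads = matrix.argmax(dim=2) * seq_mask.long()               # PAD rows forced to 0
print(heads)
```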
@@ -135,8 +136,15 @@ class GraphParser(nn.Module):
         batch_size, seq_len, _ = arc_matrix.shape
         matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
         ans = matrix.new_zeros(batch_size, seq_len).long()
+        lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
+        seq_mask[batch_idx, lens-1] = 0
         for i, graph in enumerate(matrix):
-            ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
+            len_i = lens[i]
+            if len_i == seq_len:
+                ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
+            else:
+                ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
         if seq_mask is not None:
             ans *= seq_mask.long()
         return ans
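MST decoding now runs Chu-Liu/Edmonds only on the `len_i × len_i` submatrix of each sentence instead of the full padded matrix, and the last real position (presumably the appended EOS marker) is zeroed out of the mask beforehand so its head prediction is discarded. A small sketch of just that length bookkeeping, assuming a 0/1 `seq_mask`:

```python
import torch

seq_mask = torch.tensor([[1, 1, 1, 1, 0],
                         [1, 1, 1, 0, 0]])       # 1 = real token (incl. BOS/EOS)

lens = seq_mask.long().sum(1)                    # real lengths: [4, 3]
batch_idx = torch.arange(seq_mask.size(0), dtype=torch.long)
seq_mask[batch_idx, lens - 1] = 0                # drop the trailing EOS position
# note: this mutates seq_mask in place, as in the diff

print(seq_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])
# each score matrix is then cut to graph[:len_i, :len_i] before calling mst()
```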
@@ -251,17 +259,18 @@ class BiaffineParser(GraphParser):
         self.normal_dropout = nn.Dropout(p=dropout)
         self.use_greedy_infer = use_greedy_infer
         self.reset_parameters()
+        self.explore_p = 0.2

     def reset_parameters(self):
         for m in self.modules():
             if isinstance(m, nn.Embedding):
                 continue
             elif isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.weight, 0.1)
                 nn.init.constant_(m.bias, 0)
             else:
                 for p in m.parameters():
-                    nn.init.normal_(p, 0, 0.01)
+                    nn.init.normal_(p, 0, 0.1)

     def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         """
@@ -304,8 +313,6 @@ class BiaffineParser(GraphParser):
         # biaffine arc classifier
         arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]
-        flip_mask = (seq_mask == 0)
-        arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)

         # use gold or predicted arc to predict label
         if gold_heads is None or not self.training:
@@ -317,8 +324,12 @@ class BiaffineParser(GraphParser):
             head_pred = heads
         else:
             assert self.training  # must be training mode
-            head_pred = None
-            heads = gold_heads
+            if torch.rand(1).item() < self.explore_p:
+                heads = self._greedy_decoder(arc_pred, seq_mask)
+                head_pred = heads
+            else:
+                head_pred = None
+                heads = gold_heads

         batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
         label_head = label_head[batch_range, heads].contiguous()
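With `explore_p = 0.2`, training no longer always conditions the label classifier on gold heads: roughly one batch in five uses the parser's own greedy predictions instead, which reduces the mismatch between teacher-forced training and decoding on predicted trees. A minimal sketch of that coin flip — names follow the diff, the decoder argument is a stand-in and the inference path is simplified:

```python
import torch

explore_p = 0.2

def pick_heads(arc_pred, gold_heads, seq_mask, greedy_decoder, training=True):
    """Occasionally condition label prediction on predicted heads instead of gold ones."""
    if not training or gold_heads is None:
        heads = greedy_decoder(arc_pred, seq_mask)   # (real code may use MST here)
        return heads, heads
    if torch.rand(1).item() < explore_p:
        heads = greedy_decoder(arc_pred, seq_mask)   # explore: use own predictions
        return heads, heads
    return gold_heads, None                          # exploit: teacher forcing
```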
@@ -333,7 +344,7 @@ class BiaffineParser(GraphParser):
         Compute loss.

         :param arc_pred: [batch_size, seq_len, seq_len]
-        :param label_pred: [batch_size, seq_len, seq_len]
+        :param label_pred: [batch_size, seq_len, n_tags]
         :param head_indices: [batch_size, seq_len]
         :param head_labels: [batch_size, seq_len]
         :param seq_mask: [batch_size, seq_len]
@@ -341,10 +352,13 @@ class BiaffineParser(GraphParser):
         """
         batch_size, seq_len, _ = arc_pred.shape
-        arc_logits = F.log_softmax(arc_pred, dim=2)
+        flip_mask = (seq_mask == 0)
+        _arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
+        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
+        arc_logits = F.log_softmax(_arc_pred, dim=2)
         label_logits = F.log_softmax(label_pred, dim=2)
-        batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1)
-        child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0)
+        batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
+        child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
         arc_loss = arc_logits[batch_index, child_index, head_indices]
         label_loss = label_logits[batch_index, child_index, head_labels]
@@ -352,9 +366,8 @@ class BiaffineParser(GraphParser):
         label_loss = label_loss[:, 1:]
         float_mask = seq_mask[:, 1:].float()
-        length = (seq_mask.sum() - batch_size).float()
-        arc_nll = -(arc_loss*float_mask).sum() / length
-        label_nll = -(label_loss*float_mask).sum() / length
+        arc_nll = -(arc_loss*float_mask).mean()
+        label_nll = -(label_loss*float_mask).mean()
         return arc_nll + label_nll
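Two things change in `loss`: the `-inf` fill that used to happen in-place on `arc_pred` inside `forward` is now applied to a copy inside `loss` (so the logits handed to the decoders stay untouched, and the log-softmax only spreads mass over real candidate heads), and the negative log-likelihoods are averaged with `.mean()` instead of being summed and divided by the true token count. Note that `.mean()` divides by `batch_size * (seq_len - 1)` including padded positions, so the loss scale now depends on how much padding a batch contains. A tiny comparison, assuming a 0/1 mask:

```python
import torch

arc_loss = torch.tensor([[-0.5, -1.0, -2.0],
                         [-0.3, -4.0, -6.0]])     # log-probs of gold heads, [B, L-1]
float_mask = torch.tensor([[1., 1., 1.],
                           [1., 0., 0.]])         # real non-root tokens

masked = arc_loss * float_mask
nll_sum  = -masked.sum() / float_mask.sum()       # old: average over real tokens -> 0.95
nll_mean = -masked.mean()                         # new: average over all positions -> ~0.633
print(nll_sum.item(), nll_mean.item())
```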
@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module):
         mask_x = input.new_ones((batch_size, self.input_size))
         mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions))
-        mask_h = input.new_ones((batch_size, self.hidden_size))
+        mask_h_ones = input.new_ones((batch_size, self.hidden_size))
         nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
         nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
-        nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True)

         hidden_list = []
         for layer in range(self.num_layers):
             output_list = []
+            mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
             for direction in range(self.num_directions):
                 input_x = input if direction == 0 else flip(input, [0])
                 idx = self.num_directions * layer + direction
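The hidden-state dropout mask is no longer a single in-place mask reused by every layer: `mask_h_ones` stays all-ones, and a fresh mask is sampled from it (non-inplace) at the top of each layer, so every layer gets its own variational dropout mask while the mask is still shared across time steps within that layer. A small sketch of per-layer mask sampling, with made-up sizes:

```python
import torch
import torch.nn as nn

batch_size, hidden_size, num_layers = 2, 4, 3
hidden_dropout, training = 0.5, True

mask_h_ones = torch.ones(batch_size, hidden_size)
for layer in range(num_layers):
    # new mask per layer, reused at every time step of that layer
    mask_h = nn.functional.dropout(mask_h_ones, p=hidden_dropout,
                                   training=training, inplace=False)
    print(f"layer {layer}:", mask_h)
```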
@@ -1,6 +1,6 @@
 [train]
 epochs = -1
-batch_size = 32
+batch_size = 16
 pickle_path = "./save/"
 validate = true
 save_best_dev = true
@@ -37,4 +37,4 @@ use_greedy_infer=false
 [optim]
 lr = 2e-3
-weight_decay = 0.0
+weight_decay = 5e-5
@@ -24,6 +24,12 @@ from fastNLP.loader.embed_loader import EmbedLoader
 from fastNLP.models.biaffine_parser import BiaffineParser
 from fastNLP.saver.model_saver import ModelSaver

+BOS = '<BOS>'
+EOS = '<EOS>'
+UNK = '<OOV>'
+NUM = '<NUM>'
+ENG = '<ENG>'
+
 # not in the file's dir
 if len(os.path.dirname(__file__)) != 0:
     os.chdir(os.path.dirname(__file__))
@@ -97,10 +103,10 @@ class CTBDataLoader(object):
     def convert(self, data):
         dataset = DataSet()
         for sample in data:
-            word_seq = ["<s>"] + sample[0] + ['</s>']
-            pos_seq = ["<s>"] + sample[1] + ['</s>']
+            word_seq = [BOS] + sample[0] + [EOS]
+            pos_seq = [BOS] + sample[1] + [EOS]
             heads = [0] + list(map(int, sample[2])) + [0]
-            head_tags = ["<s>"] + sample[3] + ['</s>']
+            head_tags = [BOS] + sample[3] + [EOS]
             dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
                                     pos_seq=TextField(pos_seq, is_target=False),
                                     gold_heads=SeqLabelField(heads, is_target=False),
@@ -166,9 +172,9 @@ def P2(data, field, length):

 def P1(data, field):
     def reeng(w):
-        return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
+        return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG
     def renum(w):
-        return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
+        return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM
     for ins in data:
         ori = ins[field].contents()
         s = list(map(renum, map(reeng, ori)))
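`P1` normalizes rare surface forms before indexing: alphabetic tokens (optionally containing '.' or '-') collapse to the `ENG` placeholder and numerals to `NUM`, while the sentence boundary markers are left alone. A quick standalone check of the two regexes, with illustrative tokens:

```python
import re

BOS, EOS, NUM, ENG = '<BOS>', '<EOS>', '<NUM>', '<ENG>'

def reeng(w):
    return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG

def renum(w):
    return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM

tokens = [BOS, '苹果', 'iPhone-X', 'U.S.', '3.14', '2018', EOS]
print([renum(reeng(w)) for w in tokens])
# ['<BOS>', '苹果', '<ENG>', '<ENG>', '<NUM>', '<NUM>', '<EOS>']
```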
@@ -211,26 +217,32 @@ class ParserEvaluator(Evaluator):
 try:
     data_dict = load_data(processed_datadir)
+    word_v = data_dict['word_v']
     pos_v = data_dict['pos_v']
     tag_v = data_dict['tag_v']
     train_data = data_dict['train_data']
     dev_data = data_dict['dev_data']
-    test_data = data_dict['test_datas']
+    test_data = data_dict['test_data']
     print('use saved pickles')
 except Exception as _:
     print('load raw data and preprocess')
     # use pretrain embedding
+    word_v = Vocabulary(need_default=True, min_freq=2)
+    word_v.unknown_label = UNK
     pos_v = Vocabulary(need_default=True)
     tag_v = Vocabulary(need_default=False)
     train_data = loader.load(os.path.join(datadir, train_data_name))
     dev_data = loader.load(os.path.join(datadir, dev_data_name))
     test_data = loader.load(os.path.join(datadir, test_data_name))
-    train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
-    save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
+    train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
+    datasets = (train_data, dev_data, test_data)
+    save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)

+embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
-embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
-word_v.unknown_label = "<OOV>"
+print(len(word_v))
+print(embed.size())

 # Model
 model_args['word_vocab_size'] = len(word_v)
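The word vocabulary is now built from the training corpus itself (`min_freq=2`, with the unknown token renamed to the `UNK` constant `<OOV>`), and the pretrained GloVe vectors are loaded against that fixed vocabulary, instead of taking the vocabulary straight from the embedding file as before. A hedged sketch of the general idea, with a made-up `pretrained` dict standing in for the GloVe file and all names illustrative:

```python
import torch
from collections import Counter

def build_embedding(train_tokens, pretrained, dim, min_freq=2, unk='<OOV>', pad='<PAD>'):
    counts = Counter(train_tokens)
    vocab = [pad, unk] + sorted(w for w, c in counts.items() if c >= min_freq)
    word2idx = {w: i for i, w in enumerate(vocab)}
    emb = torch.randn(len(vocab), dim) * 0.01      # random init for words missing from GloVe
    for w, i in word2idx.items():
        if w in pretrained:                        # copy the pretrained vector when available
            emb[i] = torch.tensor(pretrained[w])
    return word2idx, emb

pretrained = {'the': [0.1, 0.2], 'cat': [0.3, 0.4]}    # stand-in for an embedding file
word2idx, emb = build_embedding(['the', 'the', 'cat', 'cat', 'sat'], pretrained, dim=2)
print(word2idx, emb.shape)
```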
@@ -239,18 +251,14 @@ model_args['num_label'] = len(tag_v)

 model = BiaffineParser(**model_args.data)
 model.reset_parameters()

 datasets = (train_data, dev_data, test_data)
 for ds in datasets:
-    # print('====='*30)
-    P1(ds, 'word_seq')
-    P2(ds, 'word_seq', 5)
     ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
     ds.set_origin_len('word_seq')
-    if train_args['use_golden_train']:
-        ds.set_target(gold_heads=False)
-    else:
-        ds.set_target(gold_heads=None)
+if train_args['use_golden_train']:
+    train_data.set_target(gold_heads=False)
+else:
+    train_data.set_target(gold_heads=None)
 train_args.data.pop('use_golden_train')

 ignore_label = pos_v['P']
@@ -274,7 +282,7 @@ def train(path):
             {'params': list(embed_params), 'lr':lr*0.1},
             {'params': list(decay_params), **optim_args.data},
             {'params': params}
-        ], lr=lr)
+        ], lr=lr, betas=(0.9, 0.9))
         obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))

 def _update(obj):
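The optimizer keeps its three parameter groups (embeddings at `lr*0.1`, weight-decayed parameters from the `[optim]` section, and the rest) but now sets Adam's `betas=(0.9, 0.9)`, as in Dozat & Manning's biaffine parser; the smaller beta2 makes the second-moment estimate track recent gradients more closely. A stripped-down sketch of the grouping, with a toy two-module model standing in for the parser:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Embedding(100, 8), nn.Linear(8, 8))
lr = 2e-3
embed_params = list(model[0].parameters())
decay_params = list(model[1].parameters())

optimizer = torch.optim.Adam([
    {'params': embed_params, 'lr': lr * 0.1},          # slower updates for embeddings
    {'params': decay_params, 'weight_decay': 5e-5},    # value from the [optim] config above
    ], lr=lr, betas=(0.9, 0.9))                        # beta2 lowered from the 0.999 default

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
```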
@@ -315,7 +323,7 @@ def test(path):
     # Model
     model = BiaffineParser(**model_args.data)
+    model.eval()
     try:
         ModelLoader.load_pytorch(model, path)
         print('model parameter loaded!')
@@ -324,6 +332,8 @@ def test(path):
         raise

     # Start training
+    print("Testing Train data")
+    tester.test(model, train_data)
     print("Testing Dev data")
     tester.test(model, dev_data)
     print("Testing Test data")