@@ -134,8 +134,8 @@ class Trainer(object):
# main training procedure
start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)
logger.info("training epochs started " + self.start_time)
epoch, iters = 1, 0
while(1):
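The first hunk bumps the run timestamp to second resolution and routes the startup message through the module logger instead of print. A minimal standalone sketch of the same idea, assuming a module-level logger (the logger name and the basicConfig call are illustrative, not part of the PR):

    import logging
    from datetime import datetime

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("trainer")  # hypothetical name; the PR reuses an existing `logger`

    start_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')  # seconds avoid clashes between runs
    logger.info("training epochs started " + start_time)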
@@ -51,6 +51,12 @@ class Vocabulary(object):
self.min_freq = min_freq
self.word_count = {}
self.has_default = need_default
if self.has_default:
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
else:
self.padding_label = None
self.unknown_label = None
self.word2idx = None
self.idx2word = None
@@ -77,12 +83,10 @@ class Vocabulary(object):
"""
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
else:
self.word2idx = {}
self.padding_label = None
self.unknown_label = None
words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
if self.min_freq is not None:
@@ -135,9 +139,9 @@ class Vocabulary(object):
return self.word2idx[self.unknown_label]
def __setattr__(self, name, val):
if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
self.__dict__[name] = val
if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
self.word2idx = None
@property
@check_build_vocab
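This hunk moves the default padding/unknown labels into `__init__` and rewrites `__setattr__` so that reassigning `unknown_label` or `padding_label` simply drops the cached `word2idx`; the mapping is then rebuilt lazily on the next access instead of being patched in place. A minimal sketch of that invalidation pattern, independent of fastNLP's actual Vocabulary (class name and methods are illustrative):

    class TinyVocab:
        """Illustrative only: cache a word->index map, invalidate it when special labels change."""
        def __init__(self, padding_label="<pad>", unknown_label="<unk>"):
            self.padding_label = padding_label
            self.unknown_label = unknown_label
            self.word_count = {}
            self.word2idx = None  # built lazily

        def __setattr__(self, name, val):
            self.__dict__[name] = val
            # changing a special label invalidates the cached mapping
            if name in ("padding_label", "unknown_label") and "word2idx" in self.__dict__:
                self.__dict__["word2idx"] = None

        def build(self):
            words = [self.padding_label, self.unknown_label] + sorted(self.word_count)
            self.word2idx = {w: i for i, w in enumerate(words)}

        def __getitem__(self, word):
            if self.word2idx is None:
                self.build()
            return self.word2idx.get(word, self.word2idx[self.unknown_label])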
@@ -16,10 +16,9 @@ def mst(scores):
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
"""
length = scores.shape[0]
min_score = -np.inf
mask = np.zeros((length, length))
np.fill_diagonal(mask, -np.inf)
scores = scores + mask
min_score = scores.min() - 1
eye = np.eye(length)
scores = scores * (1 - eye) + min_score * eye
heads = np.argmax(scores, axis=1)
heads[0] = 0
tokens = np.arange(1, length)
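In `mst`, the `-np.inf` diagonal mask becomes a finite `min_score` placed on the diagonal, so self-loops still never win the argmax but the score matrix stays free of inf values for the arithmetic that follows. A small NumPy-only sketch of just that masking step (random scores, purely illustrative):

    import numpy as np

    length = 5
    scores = np.random.rand(length, length)

    # push self-loop scores below every real score instead of using -inf
    min_score = scores.min() - 1
    eye = np.eye(length)
    scores = scores * (1 - eye) + min_score * eye

    heads = np.argmax(scores, axis=1)  # initial head choice per token, never itself
    heads[0] = 0                       # position 0 is the artificial ROOT
    print(heads)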
@@ -126,6 +125,8 @@ class GraphParser(nn.Module):
def _greedy_decoder(self, arc_matrix, seq_mask=None):
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
flip_mask = (seq_mask == 0).byte()
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if seq_mask is not None:
heads *= seq_mask.long()
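`_greedy_decoder` now additionally fills the columns of padded positions with `-inf` before the row-wise argmax, so no token can select a pad position as its head. A standalone sketch of the masking (shapes follow the hunk; no parser needed):

    import torch
    import numpy as np

    batch_size, seq_len = 2, 5
    arc_matrix = torch.randn(batch_size, seq_len, seq_len)
    seq_mask = torch.tensor([[1, 1, 1, 1, 0],
                             [1, 1, 1, 0, 0]])              # 1 = real token, 0 = padding

    matrix = arc_matrix + torch.diag(arc_matrix.new_full((seq_len,), -np.inf))  # forbid self-loops
    flip_mask = (seq_mask == 0)
    matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)    # forbid pads as heads
    _, heads = torch.max(matrix, dim=2)
    heads *= seq_mask.long()                                # heads of padded positions are zeroed
    print(heads)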
@@ -135,8 +136,15 @@ class GraphParser(nn.Module):
batch_size, seq_len, _ = arc_matrix.shape
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
ans = matrix.new_zeros(batch_size, seq_len).long()
lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
seq_mask[batch_idx, lens-1] = 0
for i, graph in enumerate(matrix):
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
len_i = lens[i]
if len_i == seq_len:
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
else:
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
if seq_mask is not None:
ans *= seq_mask.long()
return ans
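The batched MST decoder now derives each sentence's true length from `seq_mask`, zeroes the mask at position `lens-1` (an in-place write that drops the trailing token), and runs `mst` only on the `len_i x len_i` top-left block of each score matrix, so padding never enters the tree. A minimal sketch of the per-sentence cropping, with a stand-in `decode_one` in place of the real `mst`:

    import torch

    def decode_one(scores):
        """Stand-in for the real mst(): greedy row-wise argmax, just to keep the sketch runnable."""
        return scores.argmax(axis=1)

    batch_size, seq_len = 2, 6
    arc_matrix = torch.randn(batch_size, seq_len, seq_len)
    seq_mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                             [1, 1, 1, 1, 0, 0]])

    ans = arc_matrix.new_zeros(batch_size, seq_len).long()
    lens = seq_mask.long().sum(1)
    for i, graph in enumerate(arc_matrix):
        len_i = int(lens[i])
        if len_i == seq_len:
            ans[i] = torch.as_tensor(decode_one(graph.cpu().numpy()), device=ans.device)
        else:
            ans[i, :len_i] = torch.as_tensor(decode_one(graph[:len_i, :len_i].cpu().numpy()),
                                             device=ans.device)
    ans *= seq_mask.long()   # heads of padded positions stay 0
    print(ans)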
@@ -251,17 +259,18 @@ class BiaffineParser(GraphParser):
self.normal_dropout = nn.Dropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
self.reset_parameters()
self.explore_p = 0.2
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Embedding):
continue
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.weight, 0.1)
nn.init.constant_(m.bias, 0)
else:
for p in m.parameters():
nn.init.normal_(p, 0, 0.01)
nn.init.normal_(p, 0, 0.1)
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
"""
@@ -304,8 +313,6 @@ class BiaffineParser(GraphParser):
# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
flip_mask = (seq_mask == 0)
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
# use gold or predicted arc to predict label
if gold_heads is None or not self.training:
@@ -317,8 +324,12 @@ class BiaffineParser(GraphParser):
head_pred = heads
else:
assert self.training # must be training mode
head_pred = None
heads = gold_heads
if torch.rand(1).item() < self.explore_p:
heads = self._greedy_decoder(arc_pred, seq_mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
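The new `explore_p = 0.2` introduces scheduled-sampling-style exploration: on roughly 20% of training batches the parser decodes heads greedily from its current arc scores and conditions label prediction on those instead of the gold heads, narrowing the train/inference mismatch. A minimal sketch of the branching, detached from the model (`greedy_decode` is a stand-in for `self._greedy_decoder`):

    import torch

    explore_p = 0.2   # probability of conditioning on predicted heads during training

    def greedy_decode(arc_pred):
        """Stand-in for the parser's _greedy_decoder: row-wise argmax over head scores."""
        return arc_pred.argmax(dim=2)

    batch_size, seq_len, n_feat = 2, 5, 8
    arc_pred = torch.randn(batch_size, seq_len, seq_len)
    label_head = torch.randn(batch_size, seq_len, n_feat)   # dummy per-token label features
    gold_heads = torch.randint(0, seq_len, (batch_size, seq_len))

    training = True
    if training and torch.rand(1).item() < explore_p:
        heads = greedy_decode(arc_pred)   # explore: condition on the model's own predictions
        head_pred = heads
    else:
        heads = gold_heads                # teacher forcing: condition on gold heads
        head_pred = None

    # `heads` then gathers the head features, exactly as in the hunk above
    batch_range = torch.arange(batch_size, dtype=torch.long).unsqueeze(1)
    label_head = label_head[batch_range, heads].contiguous()
    print(label_head.shape)               # torch.Size([2, 5, 8])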
@@ -333,7 +344,7 @@ class BiaffineParser(GraphParser):
Compute loss.
:param arc_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, n_tags]
:param head_indices: [batch_size, seq_len]
:param head_labels: [batch_size, seq_len]
:param seq_mask: [batch_size, seq_len]
@@ -341,10 +352,13 @@ class BiaffineParser(GraphParser):
"""
batch_size, seq_len, _ = arc_pred.shape
arc_logits = F.log_softmax(arc_pred, dim=2)
flip_mask = (seq_mask == 0)
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
arc_logits = F.log_softmax(_arc_pred, dim=2)
label_logits = F.log_softmax(label_pred, dim=2)
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1)
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0)
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, head_indices]
label_loss = label_logits[batch_index, child_index, head_labels]
@@ -352,9 +366,8 @@ class BiaffineParser(GraphParser):
label_loss = label_loss[:, 1:]
float_mask = seq_mask[:, 1:].float()
length = (seq_mask.sum() - batch_size).float()
arc_nll = -(arc_loss*float_mask).sum() / length
label_nll = -(label_loss*float_mask).sum() / length
arc_nll = -(arc_loss*float_mask).mean()
label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll
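The loss now applies the `-inf` padding mask to a copy of `arc_pred` (instead of mutating the prediction tensor inside `forward`) before the log-softmax, and replaces the sum-divided-by-token-count normalization with a plain `.mean()` over all positions, masked ones included. A self-contained sketch of the masked arc NLL under those conventions (random scores, toy gold heads):

    import torch
    import torch.nn.functional as F
    import numpy as np

    batch_size, seq_len = 2, 5
    arc_pred = torch.randn(batch_size, seq_len, seq_len)
    head_indices = torch.zeros(batch_size, seq_len, dtype=torch.long)  # toy gold heads: attach to ROOT
    seq_mask = torch.tensor([[1, 1, 1, 1, 0],
                             [1, 1, 1, 0, 0]])

    flip_mask = (seq_mask == 0)
    _arc_pred = arc_pred.clone()                              # keep the forward-pass predictions intact
    _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)   # padded positions can never be heads
    arc_logits = F.log_softmax(_arc_pred, dim=2)

    batch_index = torch.arange(batch_size, dtype=torch.long).unsqueeze(1)
    child_index = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
    arc_loss = arc_logits[batch_index, child_index, head_indices]  # [batch_size, seq_len]

    arc_loss = arc_loss[:, 1:]                 # drop the ROOT position
    float_mask = seq_mask[:, 1:].float()
    arc_nll = -(arc_loss * float_mask).mean()  # mean over all positions, masked ones included
    print(arc_nll)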
@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module):
mask_x = input.new_ones((batch_size, self.input_size))
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions))
mask_h = input.new_ones((batch_size, self.hidden_size))
mask_h_ones = input.new_ones((batch_size, self.hidden_size))
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True)
hidden_list = []
for layer in range(self.num_layers):
output_list = []
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
for direction in range(self.num_directions):
input_x = input if direction == 0 else flip(input, [0])
idx = self.num_directions * layer + direction
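In `VarRNNBase` the hidden-state dropout mask is no longer sampled once for the whole stack: a ones template `mask_h_ones` is kept and a fresh `mask_h` is drawn (not in place) at the start of every layer, while each mask is still shared by all timesteps of that layer, i.e. variational (locked) dropout. A small sketch of per-layer locked dropout masks on a toy recurrence (shapes and rates are illustrative):

    import torch
    import torch.nn as nn

    batch_size, hidden_size, num_layers, steps = 4, 8, 2, 5
    hidden_dropout = 0.3
    training = True

    x = torch.randn(steps, batch_size, hidden_size)
    mask_h_ones = x.new_ones((batch_size, hidden_size))    # template of ones, never modified

    outputs = []
    for layer in range(num_layers):
        # fresh mask for this layer, shared by every timestep within it (locked dropout)
        mask_h = nn.functional.dropout(mask_h_ones, p=hidden_dropout,
                                       training=training, inplace=False)
        h = x.new_zeros(batch_size, hidden_size)
        for t in range(steps):
            h = torch.tanh(x[t] + h * mask_h)              # toy recurrence; the mask is the point
        outputs.append(h)
    print(torch.stack(outputs).shape)                      # torch.Size([2, 4, 8])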
@@ -1,6 +1,6 @@
[train]
epochs = -1
batch_size = 32
batch_size = 16
pickle_path = "./save/"
validate = true
save_best_dev = true
@@ -37,4 +37,4 @@ use_greedy_infer=false
[optim]
lr = 2e-3
weight_decay = 0.0
weight_decay = 5e-5
@@ -24,6 +24,12 @@ from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.saver.model_saver import ModelSaver
BOS = '<BOS>'
EOS = '<EOS>'
UNK = '<OOV>'
NUM = '<NUM>'
ENG = '<ENG>'
# not in the file's dir
if len(os.path.dirname(__file__)) != 0:
os.chdir(os.path.dirname(__file__))
@@ -97,10 +103,10 @@ class CTBDataLoader(object):
def convert(self, data):
dataset = DataSet()
for sample in data:
word_seq = ["<s>"] + sample[0] + ['</s>']
pos_seq = ["<s>"] + sample[1] + ['</s>']
word_seq = [BOS] + sample[0] + [EOS]
pos_seq = [BOS] + sample[1] + [EOS]
heads = [0] + list(map(int, sample[2])) + [0]
head_tags = ["<s>"] + sample[3] + ['</s>']
head_tags = [BOS] + sample[3] + [EOS]
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
pos_seq=TextField(pos_seq, is_target=False),
gold_heads=SeqLabelField(heads, is_target=False),
@@ -166,9 +172,9 @@ def P2(data, field, length):
def P1(data, field):
def reeng(w):
return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG
def renum(w):
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM
for ins in data:
ori = ins[field].contents()
s = list(map(renum, map(reeng, ori)))
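`P1` normalizes the word field before indexing: purely alphabetic tokens collapse to the `ENG` placeholder and numerals to `NUM`, with the BOS/EOS markers left untouched; the hunk simply swaps hard-coded strings for the module-level constants. A standalone sketch of the two normalizers (constants mirror the ones defined above):

    import re

    BOS, EOS, NUM, ENG = '<BOS>', '<EOS>', '<NUM>', '<ENG>'

    def reeng(w):
        # collapse purely alphabetic tokens (allowing '.'/'-') to the ENG placeholder
        return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG

    def renum(w):
        # collapse integer/decimal tokens to the NUM placeholder
        return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM

    words = ['<BOS>', 'ABC-123', 'Reuters', '3.14', '中国', '<EOS>']
    print([renum(reeng(w)) for w in words])
    # ['<BOS>', 'ABC-123', '<ENG>', '<NUM>', '中国', '<EOS>']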
@@ -211,26 +217,32 @@ class ParserEvaluator(Evaluator):
try:
data_dict = load_data(processed_datadir)
word_v = data_dict['word_v']
pos_v = data_dict['pos_v']
tag_v = data_dict['tag_v']
train_data = data_dict['train_data']
dev_data = data_dict['dev_data']
test_data = data_dict['test_datas']
test_data = data_dict['test_data']
print('use saved pickles')
except Exception as _:
print('load raw data and preprocess')
# use pretrain embedding
word_v = Vocabulary(need_default=True, min_freq=2)
word_v.unknown_label = UNK
pos_v = Vocabulary(need_default=True)
tag_v = Vocabulary(need_default=False)
train_data = loader.load(os.path.join(datadir, train_data_name))
dev_data = loader.load(os.path.join(datadir, dev_data_name))
test_data = loader.load(os.path.join(datadir, test_data_name))
train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
datasets = (train_data, dev_data, test_data)
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
word_v.unknown_label = "<OOV>"
print(len(word_v))
print(embed.size())
# Model
model_args['word_vocab_size'] = len(word_v)
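The pipeline now builds `word_v` from the training corpus itself (with `min_freq=2` and the `<OOV>` unknown label) and loads GloVe vectors against that vocabulary, rather than letting the embedding file define the vocabulary. A hedged sketch of the general pattern without fastNLP's `EmbedLoader`, assuming a plain word-to-index dict and a GloVe-format text file (all names and paths illustrative):

    import numpy as np

    def build_embedding(vocab, emb_path, dim):
        """Align a pretrained (GloVe-style text) embedding file to a fixed vocabulary.

        `vocab` is assumed to be a plain word -> index dict built from the training data;
        words missing from the file keep a small random vector.
        """
        table = np.random.normal(0, 0.01, (len(vocab), dim)).astype(np.float32)
        with open(emb_path, encoding='utf-8') as f:
            for line in f:
                parts = line.rstrip().split(' ')
                word, vec = parts[0], parts[1:]
                if word in vocab and len(vec) == dim:
                    table[vocab[word]] = np.asarray(vec, dtype=np.float32)
        return table

    # usage sketch (hypothetical path):
    # embed = build_embedding(word_to_index, 'glove.6B.100d.txt', 100)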
@@ -239,18 +251,14 @@ model_args['num_label'] = len(tag_v)
model = BiaffineParser(**model_args.data)
model.reset_parameters()
datasets = (train_data, dev_data, test_data)
for ds in datasets:
# print('====='*30)
P1(ds, 'word_seq')
P2(ds, 'word_seq', 5)
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
ds.set_origin_len('word_seq')
if train_args['use_golden_train']:
ds.set_target(gold_heads=False)
else:
ds.set_target(gold_heads=None)
if train_args['use_golden_train']:
train_data.set_target(gold_heads=False)
else:
train_data.set_target(gold_heads=None)
train_args.data.pop('use_golden_train')
ignore_label = pos_v['P']
@@ -274,7 +282,7 @@ def train(path):
{'params': list(embed_params), 'lr':lr*0.1},
{'params': list(decay_params), **optim_args.data},
{'params': params}
], lr=lr)
], lr=lr, betas=(0.9, 0.9))
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
def _update(obj):
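Training keeps three parameter groups (embeddings at a tenth of the base lr, a weight-decayed group, and the rest); the PR additionally fixes the optimizer's betas at (0.9, 0.9) (an Adam-style setting), and the LambdaLR decays the lr by a factor of 0.75 per 5e4 scheduler steps with a floor of 5%. A minimal sketch of the same optimizer/scheduler wiring on a toy model (group membership and sizes are illustrative):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 4))
    embed_params = list(model[0].parameters())
    other_params = list(model[1].parameters())

    lr = 2e-3
    optimizer = torch.optim.Adam([
        {'params': embed_params, 'lr': lr * 0.1},        # embeddings learn at a tenth of the base lr
        {'params': other_params, 'weight_decay': 5e-5},  # weight-decayed group
    ], lr=lr, betas=(0.9, 0.9))

    # continuous decay: multiplier reaches 0.75 after 5e4 steps, floored at 5% of the base lr
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))

    for step in range(3):
        optimizer.step()     # in real training this follows loss.backward()
        scheduler.step()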
@@ -315,7 +323,7 @@ def test(path):
# Model
model = BiaffineParser(**model_args.data)
model.eval()
try:
ModelLoader.load_pytorch(model, path)
print('model parameter loaded!')
@@ -324,6 +332,8 @@ def test(path):
raise
# Start training
print("Testing Train data")
tester.test(model, train_data)
print("Testing Dev data")
tester.test(model, dev_data)
print("Testing Test data")