
update parser, fix bugs varrnn & vocab

tags/v0.2.0
yunfan 5 years ago
parent commit 053249420f
6 changed files with 77 additions and 50 deletions
  1. fastNLP/core/trainer.py (+2 -2)
  2. fastNLP/core/vocabulary.py (+10 -6)
  3. fastNLP/models/biaffine_parser.py (+31 -18)
  4. fastNLP/modules/encoder/variational_rnn.py (+2 -2)
  5. reproduction/Biaffine_parser/cfg.cfg (+2 -2)
  6. reproduction/Biaffine_parser/run.py (+30 -20)

fastNLP/core/trainer.py (+2 -2)

@@ -134,8 +134,8 @@ class Trainer(object):

# main training procedure
start = time.time()
- self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))
+ self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
- print("training epochs started " + self.start_time)
+ logger.info("training epochs started " + self.start_time)
epoch, iters = 1, 0
while(1):
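Two small changes here: the run's start_time now carries seconds, so two runs launched within the same minute no longer share a name, and the start-up message appears to move from a bare print to the logger. For reference, the new timestamp format on its own (plain Python, outside fastNLP):

    from datetime import datetime

    # e.g. '2018-12-05-14-23-07'; the trailing %S keeps runs started in the
    # same minute from colliding on the same directory or log name
    start_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')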


fastNLP/core/vocabulary.py (+10 -6)

@@ -51,6 +51,12 @@ class Vocabulary(object):
self.min_freq = min_freq
self.word_count = {}
self.has_default = need_default
+ if self.has_default:
+ self.padding_label = DEFAULT_PADDING_LABEL
+ self.unknown_label = DEFAULT_UNKNOWN_LABEL
+ else:
+ self.padding_label = None
+ self.unknown_label = None
self.word2idx = None
self.idx2word = None

@@ -77,12 +83,10 @@ class Vocabulary(object):
"""
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
- self.padding_label = DEFAULT_PADDING_LABEL
- self.unknown_label = DEFAULT_UNKNOWN_LABEL
+ self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
+ self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
else:
self.word2idx = {}
- self.padding_label = None
- self.unknown_label = None

words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
if self.min_freq is not None:
@@ -135,9 +139,9 @@ class Vocabulary(object):
return self.word2idx[self.unknown_label]

def __setattr__(self, name, val):
- if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
- self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
self.__dict__[name] = val
+ if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
+ self.word2idx = None

@property
@check_build_vocab
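Taken together these hunks move the default padding and unknown labels into __init__, have build_vocab rename the default dictionary keys to whatever labels are currently set, and change __setattr__ so that reassigning unknown_label or padding_label simply drops the cached word2idx (rather than popping from a mapping that may not exist yet), leaving the @check_build_vocab machinery to rebuild it on the next access. A minimal sketch of that invalidate-and-rebuild pattern, with simplified names rather than the real fastNLP class:

    class TinyVocab:
        def __init__(self, unknown_label='<unk>'):
            self.word_count = {}          # word -> frequency, filled elsewhere
            self.word2idx = None          # built lazily
            self.unknown_label = unknown_label

        def __setattr__(self, name, val):
            self.__dict__[name] = val
            if name == 'unknown_label':
                self.__dict__['word2idx'] = None   # label changed: force a rebuild

        def _build(self):
            self.word2idx = {self.unknown_label: 0}
            for w in self.word_count:
                self.word2idx.setdefault(w, len(self.word2idx))

        def __getitem__(self, word):
            if self.word2idx is None:     # rebuilt here, like @check_build_vocab
                self._build()
            return self.word2idx.get(word, self.word2idx[self.unknown_label])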


fastNLP/models/biaffine_parser.py (+31 -18)

@@ -16,10 +16,9 @@ def mst(scores):
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
"""
length = scores.shape[0]
- min_score = -np.inf
- mask = np.zeros((length, length))
- np.fill_diagonal(mask, -np.inf)
- scores = scores + mask
+ min_score = scores.min() - 1
+ eye = np.eye(length)
+ scores = scores * (1 - eye) + min_score * eye
heads = np.argmax(scores, axis=1)
heads[0] = 0
tokens = np.arange(1, length)
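The self-arc masking in mst() switches from adding a -inf diagonal to clamping the diagonal at a finite value just below the smallest real score, presumably so that -inf entries cannot turn into nan once the MST routine starts adding and rescaling scores. The new masking as a standalone sketch (NumPy only):

    import numpy as np

    def mask_self_arcs(scores):
        """Give every self-arc a finite score strictly below all real arcs."""
        length = scores.shape[0]
        min_score = scores.min() - 1        # lower than any genuine arc score
        eye = np.eye(length)
        return scores * (1 - eye) + min_score * eye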
@@ -126,6 +125,8 @@ class GraphParser(nn.Module):
def _greedy_decoder(self, arc_matrix, seq_mask=None):
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
+ flip_mask = (seq_mask == 0).byte()
+ matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if seq_mask is not None:
heads *= seq_mask.long()
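_greedy_decoder now also fills the columns of padded positions with -inf before taking the argmax, so a real token can never pick a padding slot as its head; the existing heads *= seq_mask.long() line only zeroed the predictions made at padded positions. A sketch of the masking step, assuming arc_matrix is [batch, seq_len, seq_len] and seq_mask is [batch, seq_len] with 1 at real tokens:

    import torch

    def mask_pad_heads(arc_matrix, seq_mask):
        flip_mask = (seq_mask == 0)                             # True at padding
        masked = arc_matrix.masked_fill(flip_mask.unsqueeze(1), float('-inf'))
        return masked.argmax(dim=2)                             # heads avoid padding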
@@ -135,8 +136,15 @@ class GraphParser(nn.Module):
batch_size, seq_len, _ = arc_matrix.shape
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
ans = matrix.new_zeros(batch_size, seq_len).long()
+ lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len
+ batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
+ seq_mask[batch_idx, lens-1] = 0
for i, graph in enumerate(matrix):
- ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
+ len_i = lens[i]
+ if len_i == seq_len:
+ ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
+ else:
+ ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
if seq_mask is not None:
ans *= seq_mask.long()
return ans
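_mst_decoder becomes length-aware: each sentence is decoded on the sub-matrix of its true length (taken from seq_mask) instead of running the MST over padded rows and columns, and the seq_mask[batch_idx, lens-1] = 0 line appears to drop the final position (the EOS marker appended by the data loader) before decoding. Roughly, per batch (decode_one stands in for the numpy mst() routine):

    import torch

    def decode_batch(score_matrix, seq_mask, decode_one):
        batch_size, seq_len, _ = score_matrix.shape
        lens = seq_mask.long().sum(dim=1)                  # real tokens per sentence
        heads = score_matrix.new_zeros(batch_size, seq_len, dtype=torch.long)
        for i in range(batch_size):
            n = int(lens[i])
            sub = score_matrix[i, :n, :n].cpu().numpy()    # drop padded rows/columns
            heads[i, :n] = torch.as_tensor(decode_one(sub), device=heads.device)
        return heads * seq_mask.long()                     # zero out padded outputs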
@@ -251,17 +259,18 @@ class BiaffineParser(GraphParser):
self.normal_dropout = nn.Dropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
self.reset_parameters()
+ self.explore_p = 0.2

def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Embedding):
continue
elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.weight, 0.1)
nn.init.constant_(m.bias, 0)
else:
for p in m.parameters():
- nn.init.normal_(p, 0, 0.01)
+ nn.init.normal_(p, 0, 0.1)

def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
"""
@@ -304,8 +313,6 @@ class BiaffineParser(GraphParser):

# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
- flip_mask = (seq_mask == 0)
- arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)

# use gold or predicted arc to predict label
if gold_heads is None or not self.training:
@@ -317,8 +324,12 @@ class BiaffineParser(GraphParser):
head_pred = heads
else:
assert self.training # must be training mode
- head_pred = None
- heads = gold_heads
+ if torch.rand(1).item() < self.explore_p:
+ heads = self._greedy_decoder(arc_pred, seq_mask)
+ head_pred = heads
+ else:
+ head_pred = None
+ heads = gold_heads

batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
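Together with the new explore_p = 0.2 attribute in __init__, training now occasionally conditions the label classifier on the parser's own greedily decoded heads instead of the gold heads, a scheduled-sampling-style step that narrows the gap between training (gold heads) and inference (predicted heads). The decision point as a standalone sketch (greedy_decoder stands in for BiaffineParser._greedy_decoder):

    import torch

    def choose_heads(arc_pred, gold_heads, seq_mask, greedy_decoder, explore_p=0.2):
        """Pick the heads the label classifier will condition on during training."""
        if torch.rand(1).item() < explore_p:
            heads = greedy_decoder(arc_pred, seq_mask)   # condition on own output
            return heads, heads                          # (heads, head_pred)
        return gold_heads, None                          # teacher forcing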
@@ -333,7 +344,7 @@ class BiaffineParser(GraphParser):
Compute loss.

:param arc_pred: [batch_size, seq_len, seq_len]
- :param label_pred: [batch_size, seq_len, seq_len]
+ :param label_pred: [batch_size, seq_len, n_tags]
:param head_indices: [batch_size, seq_len]
:param head_labels: [batch_size, seq_len]
:param seq_mask: [batch_size, seq_len]
@@ -341,10 +352,13 @@ class BiaffineParser(GraphParser):
"""

batch_size, seq_len, _ = arc_pred.shape
- arc_logits = F.log_softmax(arc_pred, dim=2)
+ flip_mask = (seq_mask == 0)
+ _arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
+ _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
+ arc_logits = F.log_softmax(_arc_pred, dim=2)
label_logits = F.log_softmax(label_pred, dim=2)
- batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1)
- child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0)
+ batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
+ child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, head_indices]
label_loss = label_logits[batch_index, child_index, head_labels]
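The padding mask that used to be applied to arc_pred inside forward() moves into loss(): the scores are copied, padded head positions are filled with -inf, and log_softmax is taken over the copy, so decoding still sees finite scores while the loss assigns no probability mass to padded heads. A self-contained sketch of that step:

    import torch
    import torch.nn.functional as F

    def masked_head_log_softmax(arc_pred, seq_mask):
        """arc_pred: [N, L, L] head scores per dependent; seq_mask: [N, L], 1 = real."""
        flip_mask = (seq_mask == 0)                        # True at padded positions
        scores = arc_pred.clone()                          # leave the original intact
        scores.masked_fill_(flip_mask.unsqueeze(1), float('-inf'))
        return F.log_softmax(scores, dim=2)                # padded heads get log-prob -inf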

@@ -352,9 +366,8 @@ class BiaffineParser(GraphParser):
label_loss = label_loss[:, 1:]

float_mask = seq_mask[:, 1:].float()
- length = (seq_mask.sum() - batch_size).float()
- arc_nll = -(arc_loss*float_mask).sum() / length
- label_nll = -(label_loss*float_mask).sum() / length
+ arc_nll = -(arc_loss*float_mask).mean()
+ label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll
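The normalization also changes: instead of summing the masked log-likelihoods and dividing by the number of real tokens (seq_mask.sum() - batch_size), both terms are now a plain .mean() over every batch x sequence position, with masked positions contributing zero. The mean therefore divides by batch_size * (seq_len - 1) rather than by the token count, which rescales the loss whenever there is padding. Side by side:

    import torch

    # arc_loss: [N, L-1] log-probs of the gold head per token; float_mask: [N, L-1]
    arc_loss   = torch.randn(2, 5)
    float_mask = torch.tensor([[1., 1., 1., 0., 0.],
                               [1., 1., 1., 1., 1.]])
    n_real  = float_mask.sum()
    old_nll = -(arc_loss * float_mask).sum() / n_real   # average over real tokens only
    new_nll = -(arc_loss * float_mask).mean()           # average over all positions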



fastNLP/modules/encoder/variational_rnn.py (+2 -2)

@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module):

mask_x = input.new_ones((batch_size, self.input_size))
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions))
- mask_h = input.new_ones((batch_size, self.hidden_size))
+ mask_h_ones = input.new_ones((batch_size, self.hidden_size))
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
- nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True)

hidden_list = []
for layer in range(self.num_layers):
output_list = []
+ mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
for direction in range(self.num_directions):
input_x = input if direction == 0 else flip(input, [0])
idx = self.num_directions * layer + direction
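This is the varrnn bug from the commit title: mask_h used to be dropped out in place once, outside the layer loop, so every layer and direction reused the same recurrent-dropout mask and the shared tensor itself was mutated. A ones template (mask_h_ones) is now kept intact and a fresh mask is sampled per layer with inplace=False, while within a layer the mask still stays fixed across time steps, which is the point of variational dropout. A minimal standalone sketch:

    import torch
    import torch.nn.functional as F

    batch_size, hidden_size, num_layers, p = 4, 8, 2, 0.3
    mask_h_ones = torch.ones(batch_size, hidden_size)      # template, never mutated
    layer_masks = []
    for layer in range(num_layers):
        # new Bernoulli mask per layer; inplace=False leaves mask_h_ones intact
        mask_h = F.dropout(mask_h_ones, p=p, training=True, inplace=False)
        layer_masks.append(mask_h)
    # each layer multiplies its hidden state by its own mask at every time step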


reproduction/Biaffine_parser/cfg.cfg (+2 -2)

@@ -1,6 +1,6 @@
[train]
epochs = -1
- batch_size = 32
+ batch_size = 16
pickle_path = "./save/"
validate = true
save_best_dev = true
@@ -37,4 +37,4 @@ use_greedy_infer=false

[optim]
lr = 2e-3
- weight_decay = 0.0
+ weight_decay = 5e-5

reproduction/Biaffine_parser/run.py (+30 -20)

@@ -24,6 +24,12 @@ from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.saver.model_saver import ModelSaver

+ BOS = '<BOS>'
+ EOS = '<EOS>'
+ UNK = '<OOV>'
+ NUM = '<NUM>'
+ ENG = '<ENG>'

# not in the file's dir
if len(os.path.dirname(__file__)) != 0:
os.chdir(os.path.dirname(__file__))
@@ -97,10 +103,10 @@ class CTBDataLoader(object):
def convert(self, data):
dataset = DataSet()
for sample in data:
word_seq = ["<s>"] + sample[0] + ['</s>']
pos_seq = ["<s>"] + sample[1] + ['</s>']
word_seq = [BOS] + sample[0] + [EOS]
pos_seq = [BOS] + sample[1] + [EOS]
heads = [0] + list(map(int, sample[2])) + [0]
head_tags = ["<s>"] + sample[3] + ['</s>']
head_tags = [BOS] + sample[3] + [EOS]
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
pos_seq=TextField(pos_seq, is_target=False),
gold_heads=SeqLabelField(heads, is_target=False),
@@ -166,9 +172,9 @@ def P2(data, field, length):

def P1(data, field):
def reeng(w):
- return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
+ return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG
def renum(w):
- return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
+ return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM
for ins in data:
ori = ins[field].contents()
s = list(map(renum, map(reeng, ori)))
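Beyond the new named constants (BOS, EOS, UNK, NUM, ENG), the preprocessing now collapses purely alphabetic words and numerals onto the <ENG> and <NUM> sentinels instead of the bare strings 'ENG' and 'NUMBER', so they hit entries that actually exist in the vocabulary and embedding table. A simplified version of the normalization (the real code also leaves the BOS/EOS markers untouched):

    import re

    NUM, ENG = '<NUM>', '<ENG>'

    def normalize(word):
        if re.search(r'^[0-9]+\.?[0-9]*$', word) is not None:
            return NUM                      # 3, 3.14, 1995 -> <NUM>
        if re.search(r'^([a-zA-Z]+[\.\-]*)+$', word) is not None:
            return ENG                      # latin-alphabet tokens -> <ENG>
        return word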
@@ -211,26 +217,32 @@ class ParserEvaluator(Evaluator):

try:
data_dict = load_data(processed_datadir)
+ word_v = data_dict['word_v']
pos_v = data_dict['pos_v']
tag_v = data_dict['tag_v']
train_data = data_dict['train_data']
dev_data = data_dict['dev_data']
- test_data = data_dict['test_datas']
+ test_data = data_dict['test_data']
print('use saved pickles')

except Exception as _:
print('load raw data and preprocess')
# use pretrain embedding
+ word_v = Vocabulary(need_default=True, min_freq=2)
+ word_v.unknown_label = UNK
pos_v = Vocabulary(need_default=True)
tag_v = Vocabulary(need_default=False)
train_data = loader.load(os.path.join(datadir, train_data_name))
dev_data = loader.load(os.path.join(datadir, dev_data_name))
test_data = loader.load(os.path.join(datadir, test_data_name))
- train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
- save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
+ train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
+ datasets = (train_data, dev_data, test_data)
+ save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)

+ embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))

- embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
- word_v.unknown_label = "<OOV>"
print(len(word_v))
print(embed.size())

# Model
model_args['word_vocab_size'] = len(word_v)
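The data pipeline changes accordingly: the word vocabulary is now built from the training data itself (Vocabulary(need_default=True, min_freq=2) with its unknown label set to the <OOV> constant) and is passed to EmbedLoader.load_embedding, rather than letting the loader hand back its own vocabulary; word_v is also pickled alongside pos_v and tag_v, and the 'test_datas' key typo is fixed. A simplified, framework-free illustration of what a min_freq=2 vocabulary does:

    from collections import Counter

    def build_word_vocab(train_sentences, min_freq=2, pad='<PAD>', unk='<OOV>'):
        """Words seen fewer than min_freq times fall back to the unknown token."""
        counts = Counter(w for sent in train_sentences for w in sent)
        vocab = {pad: 0, unk: 1}
        for word, c in counts.most_common():
            if c >= min_freq:
                vocab.setdefault(word, len(vocab))
        return vocab

    vocab = build_word_vocab([['the', 'cat', 'sat'], ['the', 'dog']])
    # {'<PAD>': 0, '<OOV>': 1, 'the': 2} -- 'cat', 'sat', 'dog' map to <OOV>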
@@ -239,18 +251,14 @@ model_args['num_label'] = len(tag_v)

model = BiaffineParser(**model_args.data)
model.reset_parameters()

- datasets = (train_data, dev_data, test_data)
for ds in datasets:
# print('====='*30)
P1(ds, 'word_seq')
P2(ds, 'word_seq', 5)
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
ds.set_origin_len('word_seq')
- if train_args['use_golden_train']:
- ds.set_target(gold_heads=False)
- else:
- ds.set_target(gold_heads=None)
+ if train_args['use_golden_train']:
+ train_data.set_target(gold_heads=False)
+ else:
+ train_data.set_target(gold_heads=None)
train_args.data.pop('use_golden_train')
ignore_label = pos_v['P']

@@ -274,7 +282,7 @@ def train(path):
{'params': list(embed_params), 'lr':lr*0.1},
{'params': list(decay_params), **optim_args.data},
{'params': params}
- ], lr=lr)
+ ], lr=lr, betas=(0.9, 0.9))
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
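In train(), Adam's second moment coefficient is lowered from the default 0.999 to 0.9, the setting used in Dozat and Manning's original biaffine parser; the per-group learning rates (embeddings at lr*0.1, decayed parameters with weight decay) and the exponential LambdaLR schedule are unchanged. Minimal sketch of the optimizer call:

    import torch

    params = [torch.nn.Parameter(torch.randn(8, 8))]     # stand-in for model params
    optimizer = torch.optim.Adam(params, lr=2e-3, betas=(0.9, 0.9))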

def _update(obj):
@@ -315,7 +323,7 @@ def test(path):

# Model
model = BiaffineParser(**model_args.data)
+ model.eval()
try:
ModelLoader.load_pytorch(model, path)
print('model parameter loaded!')
@@ -324,6 +332,8 @@ def test(path):
raise

# Start training
print("Testing Train data")
tester.test(model, train_data)
print("Testing Dev data")
tester.test(model, dev_data)
print("Testing Test data")

