@@ -35,23 +35,21 @@ class SeqLabelEvaluator(Evaluator):
     def __init__(self):
         super(SeqLabelEvaluator, self).__init__()

-    def __call__(self, predict, truth):
+    def __call__(self, predict, truth, **_):
         """
         :param predict: list of dict, the network outputs from all batches.
         :param truth: list of dict, the ground truths from all batch_y.
         :return accuracy:
         """
-        truth = [item["truth"] for item in truth]
-        predict = [item["predict"] for item in predict]
         total_correct, total_count = 0., 0.
         for x, y in zip(predict, truth):
             # x = torch.tensor(x)
             y = y.to(x)  # make sure they are in the same device
-            mask = x.ge(1).long()
-            correct = torch.sum(x * mask == y * mask) - torch.sum(x.le(0))
+            mask = (y > 0)
+            correct = torch.sum(((x == y) * mask).long())
             total_correct += float(correct)
-            total_count += float(torch.sum(mask))
+            total_count += float(torch.sum(mask.long()))
         accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}
@@ -1,4 +1,5 @@
 import torch
+from collections import defaultdict

 from fastNLP.core.batch import Batch
 from fastNLP.core.metrics import Evaluator
@@ -71,17 +72,18 @@ class Tester(object):
         # turn on the testing mode; clean up the history
         self.mode(network, is_test=True)
         self.eval_history.clear()
-        output_list = []
-        truth_list = []
+        output, truths = defaultdict(list), defaultdict(list)
         data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)

         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
                 prediction = self.data_forward(network, batch_x)
-                output_list.append(prediction)
-                truth_list.append(batch_y)
-        eval_results = self.evaluate(output_list, truth_list)
+                assert isinstance(prediction, dict)
+                for k, v in prediction.items():
+                    output[k].append(v)
+                for k, v in batch_y.items():
+                    truths[k].append(v)
+        eval_results = self.evaluate(**output, **truths)
         print("[tester] {}".format(self.print_eval_results(eval_results)))
         logger.info("[tester] {}".format(self.print_eval_results(eval_results)))
         self.mode(network, is_test=False)
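Note: the Tester now accumulates batch outputs per dict key instead of keeping a list of dicts, which is what lets it invoke the evaluator with keyword arguments. The pattern in isolation (names and values are illustrative):

    from collections import defaultdict

    batch_outputs = [{"predict": "p0"}, {"predict": "p1"}]
    output = defaultdict(list)
    for batch in batch_outputs:
        for k, v in batch.items():
            output[k].append(v)
    # output == {"predict": ["p0", "p1"]}
    # evaluator(**output) is then evaluator(predict=["p0", "p1"])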
@@ -105,14 +107,10 @@ class Tester(object):
         y = network(**x)
         return y

-    def evaluate(self, predict, truth):
+    def evaluate(self, **kwargs):
         """Compute evaluation metrics.

-        :param predict: list of Tensor
-        :param truth: list of dict
-        :return eval_results: can be anything. It will be stored in self.eval_history
         """
-        return self._evaluator(predict, truth)
+        return self._evaluator(**kwargs)
     def print_eval_results(self, results):
         """Override this method to support more print formats.

@@ -47,7 +47,8 @@ class Trainer(object):
                         "valid_step": 500, "eval_sort_key": 'acc',
                         "loss": Loss(None),  # used to pass type check
                         "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
-                        "evaluator": Evaluator()
+                        "eval_batch_size": 64,
+                        "evaluator": Evaluator(),
                         }
         """
         "required_args" is the collection of arguments that users must pass to Trainer explicitly.
@@ -78,6 +79,7 @@ class Trainer(object):
         self.n_epochs = int(default_args["epochs"])
         self.batch_size = int(default_args["batch_size"])
+        self.eval_batch_size = int(default_args['eval_batch_size'])
         self.pickle_path = default_args["pickle_path"]
         self.validate = default_args["validate"]
         self.save_best_dev = default_args["save_best_dev"]
@@ -98,6 +100,8 @@ class Trainer(object):
         self._best_accuracy = 0.0
         self.eval_sort_key = default_args['eval_sort_key']
         self.validator = None
+        self.epoch = 0
+        self.step = 0

     def train(self, network, train_data, dev_data=None):
         """General Training Procedure
@@ -118,7 +122,7 @@ class Trainer(object):
         # define Tester over dev data
         self.dev_data = None
         if self.validate:
-            default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
+            default_valid_args = {"batch_size": self.eval_batch_size, "pickle_path": self.pickle_path,
                                   "use_cuda": self.use_cuda, "evaluator": self._evaluator}
             if self.validator is None:
                 self.validator = self._create_validator(default_valid_args)
@@ -139,9 +143,9 @@ class Trainer(object):
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
         print("training epochs started " + self.start_time)
         logger.info("training epochs started " + self.start_time)
-        epoch, iters = 1, 0
-        while epoch <= self.n_epochs:
-            logger.info("training epoch {}".format(epoch))
+        self.epoch, self.step = 1, 0
+        while self.epoch <= self.n_epochs:
+            logger.info("training epoch {}".format(self.epoch))

             # prepare mini-batch iterator
             data_iterator = Batch(train_data, batch_size=self.batch_size,
@@ -150,14 +154,13 @@ class Trainer(object):
             logger.info("prepared data iterator")

             # one forward and backward pass
-            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch,
-                                     step=iters, dev_data=dev_data)
+            self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, dev_data=dev_data)

             # validation
             if self.validate:
                 self.valid_model()
             self.save_model(self._model, 'training_model_' + self.start_time)
-            epoch += 1
+            self.epoch += 1

     def _train_step(self, data_iterator, network, **kwargs):
         """Training process in one epoch.
@@ -167,7 +170,6 @@ class Trainer(object):
            - start: time.time(), the starting time of this step.
            - epoch: int,
         """
-        step = kwargs['step']
         for batch_x, batch_y in data_iterator:

             prediction = self.data_forward(network, batch_x)
@@ -177,25 +179,23 @@ class Trainer(object):
             self.grad_backward(loss)
             self.update()
-            self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
+            self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
             for name, param in self._model.named_parameters():
                 if param.requires_grad:
-                    # self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step)
-                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step)
-                    # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=step)
-                    pass
-            if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
+                    self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
+                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
+                    # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
+            if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
                 end = time.time()
                 diff = timedelta(seconds=round(end - kwargs["start"]))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
-                    kwargs["epoch"], step, loss.data, diff)
+                    self.epoch, self.step, loss.data, diff)
                 print(print_output)
                 logger.info(print_output)
-            if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0:
+            if self.validate and self.valid_step > 0 and self.step > 0 and self.step % self.valid_step == 0:
                 self.valid_model()
-            step += 1
-        return step
+            self.step += 1
     def valid_model(self):
         if self.dev_data is None:
@@ -203,6 +211,8 @@ class Trainer(object):
                 "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
         logger.info("validation started")
         res = self.validator.test(self._model, self.dev_data)
+        for name, num in res.items():
+            self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
         if self.save_best_dev and self.best_eval_result(res):
             logger.info('save best result! {}'.format(res))
             print('save best result! {}'.format(res))
@@ -10,6 +10,7 @@ from fastNLP.modules.utils import initial_parameter
 from fastNLP.modules.encoder.variational_rnn import VarLSTM
 from fastNLP.modules.dropout import TimestepDropout
 from fastNLP.models.base_model import BaseModel
+from fastNLP.modules.utils import seq_mask
 def mst(scores):
     """
@@ -123,31 +124,31 @@ class GraphParser(BaseModel):
     def forward(self, x):
         raise NotImplementedError

-    def _greedy_decoder(self, arc_matrix, seq_mask=None):
+    def _greedy_decoder(self, arc_matrix, mask=None):
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-        flip_mask = (seq_mask == 0).byte()
+        flip_mask = (mask == 0).byte()
         matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         _, heads = torch.max(matrix, dim=2)
-        if seq_mask is not None:
-            heads *= seq_mask.long()
+        if mask is not None:
+            heads *= mask.long()
         return heads

-    def _mst_decoder(self, arc_matrix, seq_mask=None):
+    def _mst_decoder(self, arc_matrix, mask=None):
         batch_size, seq_len, _ = arc_matrix.shape
         matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
         ans = matrix.new_zeros(batch_size, seq_len).long()
-        lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len
+        lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len
         batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
-        seq_mask[batch_idx, lens-1] = 0
+        mask[batch_idx, lens-1] = 0
         for i, graph in enumerate(matrix):
             len_i = lens[i]
             if len_i == seq_len:
                 ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
             else:
                 ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
-        if seq_mask is not None:
-            ans *= seq_mask.long()
+        if mask is not None:
+            ans *= mask.long()
         return ans
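Note: renaming the seq_mask parameter to mask also stops it shadowing the seq_mask helper imported at the top of the file. The masking idiom both decoders rely on, in isolation (shapes and values are illustrative):

    import torch

    mask = torch.tensor([[1, 1, 1, 0]])       # [batch, seq_len]; 0 marks padding
    scores = torch.randn(1, 4, 4)             # [batch, seq_len, seq_len] arc scores
    flip_mask = (mask == 0).byte()            # 1 exactly where padding sits
    scores.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))  # padded tokens can never be heads
    heads = scores.max(dim=2)[1] * mask.long()                  # zero out heads of padded positions

Also note that _mst_decoder still writes into mask in place (mask[batch_idx, lens-1] = 0) before its mask-is-None check, so it effectively assumes a mask is always passed, as the callers in this diff do.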
@@ -191,13 +192,6 @@ class LabelBilinear(nn.Module):
         output += self.lin(torch.cat([x1, x2], dim=2))
         return output

-def len2masks(origin_len, max_len):
-    if origin_len.dim() <= 1:
-        origin_len = origin_len.unsqueeze(1)  # [batch_size, 1]
-    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device)  # [max_len,]
-    seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0))  # [batch_size, max_len]
-    return seq_mask
-
 class BiaffineParser(GraphParser):
     """Biaffine Dependency Parser implementation.
     refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
@@ -277,12 +271,12 @@ class BiaffineParser(GraphParser):
         """
         :param word_seq: [batch_size, seq_len] sequence of word's indices
         :param pos_seq: [batch_size, seq_len] sequence of POS tags' indices
-        :param seq_mask: [batch_size, seq_len] sequence of length masks
+        :param word_seq_origin_len: [batch_size] original lengths of the sequences, before padding
         :param gold_heads: [batch_size, seq_len] sequence of golden heads
         :return dict: parsing results
             arc_pred: [batch_size, seq_len, seq_len]
             label_pred: [batch_size, seq_len, seq_len]
-            seq_mask: [batch_size, seq_len]
+            mask: [batch_size, seq_len]
             head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
         """
@@ -294,7 +288,7 @@ class BiaffineParser(GraphParser):
         # print('forward {} {}'.format(batch_size, seq_len))

         # get sequence mask
-        seq_mask = len2masks(word_seq_origin_len, seq_len).long()
+        mask = seq_mask(word_seq_origin_len, seq_len).long()

         word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
@@ -327,14 +321,14 @@ class BiaffineParser(GraphParser):
         if gold_heads is None or not self.training:
             # use greedy decoding in training
             if self.training or self.use_greedy_infer:
-                heads = self._greedy_decoder(arc_pred, seq_mask)
+                heads = self._greedy_decoder(arc_pred, mask)
             else:
-                heads = self._mst_decoder(arc_pred, seq_mask)
+                heads = self._mst_decoder(arc_pred, mask)
             head_pred = heads
         else:
             assert self.training  # must be training mode
             if torch.rand(1).item() < self.explore_p:
-                heads = self._greedy_decoder(arc_pred, seq_mask)
+                heads = self._greedy_decoder(arc_pred, mask)
                 head_pred = heads
             else:
                 head_pred = None
@@ -343,12 +337,12 @@ class BiaffineParser(GraphParser):
         batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
         label_head = label_head[batch_range, heads].contiguous()
         label_pred = self.label_predictor(label_head, label_dep)  # [N, L, num_label]
-        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
+        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
         if head_pred is not None:
             res_dict['head_pred'] = head_pred
         return res_dict
-    def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_):
+    def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
         """
         Compute loss.

@@ -356,12 +350,12 @@ class BiaffineParser(GraphParser):
         :param label_pred: [batch_size, seq_len, n_tags]
         :param head_indices: [batch_size, seq_len]
         :param head_labels: [batch_size, seq_len]
-        :param seq_mask: [batch_size, seq_len]
+        :param mask: [batch_size, seq_len]
         :return: loss value
         """
         batch_size, seq_len, _ = arc_pred.shape
-        flip_mask = (seq_mask == 0)
+        flip_mask = (mask == 0)
         _arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
         _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         arc_logits = F.log_softmax(_arc_pred, dim=2)
@@ -374,7 +368,7 @@ class BiaffineParser(GraphParser):
         arc_loss = arc_loss[:, 1:]
         label_loss = label_loss[:, 1:]

-        float_mask = seq_mask[:, 1:].float()
+        float_mask = mask[:, 1:].float()
         arc_nll = -(arc_loss*float_mask).mean()
         label_nll = -(label_loss*float_mask).mean()
         return arc_nll + label_nll
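Note: the loss keeps only positions permitted by the mask and drops the artificial root/<BOS> position at index 0 before averaging. The masked-NLL arithmetic in isolation (values are illustrative; as in the code above, .mean() averages over the whole slice, with padded entries contributing zeros):

    import torch

    log_probs = torch.tensor([[-0.7, -1.4, -2.3]])  # per-token log-likelihoods [batch, seq_len]
    mask = torch.tensor([[1, 1, 0]])                # last position is padding
    loss = log_probs[:, 1:]                         # drop the root/<BOS> position
    float_mask = mask[:, 1:].float()
    nll = -(loss * float_mask).mean()               # == 0.7; the padded -2.3 is zeroed out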
@@ -4,20 +4,7 @@ import numpy as np
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder, encoder

-def seq_mask(seq_len, max_len):
-    """Create a mask for the sequences.
-
-    :param seq_len: list or torch.LongTensor
-    :param max_len: int
-    :return mask: torch.LongTensor
-    """
-    if isinstance(seq_len, list):
-        seq_len = torch.LongTensor(seq_len)
-    mask = [torch.ge(seq_len, i + 1) for i in range(max_len)]
-    mask = torch.stack(mask, 1)
-    return mask
+from fastNLP.modules.utils import seq_mask


 class SeqLabeling(BaseModel):

@@ -82,7 +69,7 @@ class SeqLabeling(BaseModel):
     def make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
         mask = seq_mask(seq_len, max_len)
-        mask = mask.byte().view(batch_size, max_len)
+        mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask
@@ -114,16 +101,20 @@ class AdvSeqLabel(SeqLabeling):
         word_emb_dim = args["word_emb_dim"]
         hidden_dim = args["rnn_hidden_units"]
         num_classes = args["num_classes"]
+        dropout = args['dropout']

         self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
-        self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=1, dropout=0.5, bidirectional=True)
+        self.norm1 = torch.nn.LayerNorm(word_emb_dim)
+        # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
+        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True)
         self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
-        self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
-        self.relu = torch.nn.ReLU()
-        self.drop = torch.nn.Dropout(0.5)
+        self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
+        # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
+        self.relu = torch.nn.LeakyReLU()
+        self.drop = torch.nn.Dropout(dropout)
         self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)

-        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
+        self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
     def forward(self, word_seq, word_seq_origin_len, truth=None):
         """
@@ -135,12 +126,10 @@ class AdvSeqLabel(SeqLabeling):
         """
         word_seq = word_seq.long()
+        word_seq_origin_len = word_seq_origin_len.long()
         self.mask = self.make_mask(word_seq, word_seq_origin_len)
-        word_seq_origin_len = word_seq_origin_len.cpu().numpy()
-        sent_len, idx_sort = np.sort(word_seq_origin_len)[::-1], np.argsort(-word_seq_origin_len)
-        idx_unsort = np.argsort(idx_sort)
-        idx_sort = torch.from_numpy(idx_sort)
-        idx_unsort = torch.from_numpy(idx_unsort)
+        sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True)
+        _, idx_unsort = torch.sort(idx_sort, descending=False)
         # word_seq_origin_len = word_seq_origin_len.long()
         truth = truth.long() if truth is not None else None
@@ -155,26 +144,28 @@ class AdvSeqLabel(SeqLabeling):
         truth = truth.cuda() if truth is not None else None

         x = self.Embedding(word_seq)
+        x = self.norm1(x)
         # [batch_size, max_len, word_emb_dim]

-        sent_variable = x.index_select(0, idx_sort)
+        sent_variable = x[idx_sort]
         sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True)

-        x = self.Rnn(sent_packed)
+        x, _ = self.Rnn(sent_packed)
         # print(x)
         # [batch_size, max_len, hidden_size * direction]

         sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]
-        x = sent_output.index_select(0, idx_unsort)
+        x = sent_output[idx_unsort]

         x = x.contiguous()
-        x = x.view(batch_size * max_len, -1)
+        # x = x.view(batch_size * max_len, -1)
         x = self.Linear1(x)
         # x = self.batch_norm(x)
+        x = self.norm2(x)
         x = self.relu(x)
         x = self.drop(x)
         x = self.Linear2(x)
-        x = x.view(batch_size, max_len, -1)
+        # x = x.view(batch_size, max_len, -1)
         # [batch_size, max_len, num_classes]
         return {"loss": self._internal_loss(x, truth) if truth is not None else None,
                 "predict": self.decode(x)}
@@ -183,41 +174,45 @@ class AdvSeqLabel(SeqLabeling):
         out = self.forward(**x)
         return {"predict": out["predict"]}

-    args = {
-        'vocab_size': 20,
-        'word_emb_dim': 100,
-        'rnn_hidden_units': 100,
-        'num_classes': 10,
-    }
-    model = AdvSeqLabel(args)
-    data = []
-    for i in range(20):
-        word_seq = torch.randint(20, (15,)).long()
-        word_seq_len = torch.LongTensor([15])
-        truth = torch.randint(10, (15,)).long()
-        data.append((word_seq, word_seq_len, truth))
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-    print(model)
-    curidx = 0
-    for i in range(1000):
-        endidx = min(len(data), curidx + 5)
-        b_word, b_len, b_truth = [], [], []
-        for word_seq, word_seq_len, truth in data[curidx: endidx]:
-            b_word.append(word_seq)
-            b_len.append(word_seq_len)
-            b_truth.append(truth)
-        word_seq = torch.stack(b_word, dim=0)
-        word_seq_len = torch.cat(b_len, dim=0)
-        truth = torch.stack(b_truth, dim=0)
-        res = model(word_seq, word_seq_len, truth)
-        loss = res['loss']
-        pred = res['predict']
-        print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        curidx = endidx
-        if curidx == len(data):
-            curidx = 0
+    def loss(self, **kwargs):
+        assert 'loss' in kwargs
+        return kwargs['loss']
+
+
+if __name__ == '__main__':
+    args = {
+        'vocab_size': 20,
+        'word_emb_dim': 100,
+        'rnn_hidden_units': 100,
+        'num_classes': 10,
+        'dropout': 0.5,  # required by the new __init__, which reads args['dropout']
+    }
+    model = AdvSeqLabel(args)
+    data = []
+    for i in range(20):
+        word_seq = torch.randint(20, (15,)).long()
+        word_seq_len = torch.LongTensor([15])
+        truth = torch.randint(10, (15,)).long()
+        data.append((word_seq, word_seq_len, truth))
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    print(model)
+    curidx = 0
+    for i in range(1000):
+        endidx = min(len(data), curidx + 5)
+        b_word, b_len, b_truth = [], [], []
+        for word_seq, word_seq_len, truth in data[curidx: endidx]:
+            b_word.append(word_seq)
+            b_len.append(word_seq_len)
+            b_truth.append(truth)
+        word_seq = torch.stack(b_word, dim=0)
+        word_seq_len = torch.cat(b_len, dim=0)
+        truth = torch.stack(b_truth, dim=0)
+        res = model(word_seq, word_seq_len, truth)
+        loss = res['loss']
+        pred = res['predict']
+        print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        curidx = endidx
+        if curidx == len(data):
+            curidx = 0
@@ -77,11 +77,13 @@ def initial_parameter(net, initial_method=None):

 def seq_mask(seq_len, max_len):
     """Create sequence mask.

-    :param seq_len: list of int, the lengths of sequences in a batch.
+    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
     :param max_len: int, the maximum sequence length in a batch.
     :return mask: torch.LongTensor, [batch_size, max_len]
     """
-    mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
-    mask = torch.stack(mask, 1)
-    return mask
+    if not isinstance(seq_len, torch.Tensor):
+        seq_len = torch.LongTensor(seq_len)
+    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
+    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
+    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
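Note: the rewritten helper replaces the per-position Python loop with a single broadcast comparison. A quick check of what it returns (values are illustrative; torch.gt yields a uint8/bool mask, so the docstring's "torch.LongTensor" is loose and callers cast with .long() or .float() as needed):

    import torch

    seq_len = torch.LongTensor([1, 3, 2]).view(-1, 1)             # [batch, 1]
    seq_range = torch.arange(0, 4, dtype=torch.long).view(1, -1)  # [1, max_len]
    torch.gt(seq_len, seq_range)
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 0, 0]], dtype=torch.uint8)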
@@ -0,0 +1,80 @@
+import sys
+import os
+
+sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
+
+from fastNLP.api.processor import *
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.core.dataset import DataSet
+from fastNLP.models.biaffine_parser import BiaffineParser
+from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
+
+import _pickle as pickle
+import torch
+
+
+def _load(path):
+    with open(path, 'rb') as f:
+        obj = pickle.load(f)
+    return obj
+
+
+def _load_all(src):
+    model_path = src
+    src = os.path.dirname(src)
+    word_v = _load(src + '/word_v.pkl')
+    pos_v = _load(src + '/pos_v.pkl')
+    tag_v = _load(src + '/tag_v.pkl')
+
+    model_args = ConfigSection()
+    ConfigLoader.load_config('cfg.cfg', {'model': model_args})
+    model_args['word_vocab_size'] = len(word_v)
+    model_args['pos_vocab_size'] = len(pos_v)
+    model_args['num_label'] = len(tag_v)
+
+    model = BiaffineParser(**model_args.data)
+    model.load_state_dict(torch.load(model_path))
+    return {
+        'word_v': word_v,
+        'pos_v': pos_v,
+        'tag_v': tag_v,
+        'model': model,
+    }
+
+
+def build(load_path, save_path):
+    BOS = '<BOS>'
+    NUM = '<NUM>'
+    _dict = _load_all(load_path)
+    word_vocab = _dict['word_v']
+    pos_vocab = _dict['pos_v']
+    tag_vocab = _dict['tag_v']
+    model = _dict['model']
+    print('load model from {}'.format(load_path))
+    word_seq = 'raw_word_seq'
+    pos_seq = 'raw_pos_seq'
+
+    # build pipeline
+    pipe = Pipeline()
+    pipe.add_processor(Num2TagProcessor(NUM, 'sentence', word_seq))
+    pipe.add_processor(PreAppendProcessor(BOS, word_seq))
+    pipe.add_processor(PreAppendProcessor(BOS, 'sent_pos', pos_seq))
+    pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq'))
+    pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq'))
+    pipe.add_processor(SeqLenProcessor(word_seq, 'word_seq_origin_len'))
+    pipe.add_processor(SetTensorProcessor({'word_seq': True, 'pos_seq': True, 'word_seq_origin_len': True}, default=False))
+    pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len'))
+    pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads'))
+    pipe.add_processor(SliceProcessor(1, None, None, 'label_pred', 'label_pred'))
+    pipe.add_processor(Index2WordProcessor(tag_vocab, 'label_pred', 'labels'))
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+    with open(save_path + '/pipeline.pkl', 'wb') as f:
+        torch.save(pipe, f)
+    print('save pipeline in {}'.format(save_path))
+
+
+import argparse
+parser = argparse.ArgumentParser(description='build pipeline for parser.')
+parser.add_argument('--src', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/save')
+parser.add_argument('--dst', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe')
+args = parser.parse_args()
+build(args.src, args.dst)
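Note: a hypothetical usage sketch of the saved artifact, assuming the pipeline is reloaded with torch.load and that calling it on a DataSet runs the registered processors in order (the load path and field names follow the defaults and processors above):

    import torch

    with open('/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe/pipeline.pkl', 'rb') as f:
        pipe = torch.load(f)
    # ds: a fastNLP DataSet whose instances carry raw 'sentence' and 'sent_pos' fields;
    # after processing, each instance is assumed to gain decoded 'heads' and 'labels'.
    # processed_ds = pipe(ds)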