From ec9fd32d6070330c8b8a6499113ee8d5abf91b21 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 10 Nov 2018 18:49:22 +0800
Subject: [PATCH] improve trainer: log mean and std of model params, and sum of gradients

---
 fastNLP/core/trainer.py                     | 28 +++++++++++----------
 fastNLP/modules/decoder/CRF.py              |  2 +-
 reproduction/chinese_word_segment/cws.cfg   |  4 +--
 reproduction/pos_tag_model/pos_tag.cfg      |  4 +--
 reproduction/pos_tag_model/train_pos_tag.py |  7 +++++-
 5 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index d1881297..a8f0e3c2 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -17,6 +17,7 @@ from fastNLP.saver.model_saver import ModelSaver
 logger = create_logger(__name__, "./train_test.log")
 logger.disabled = True
 
+
 class Trainer(object):
     """Operations of training a model, including data loading, gradient descent, and validation.
 
@@ -138,9 +139,7 @@ class Trainer(object):
         print("training epochs started " + self.start_time)
         logger.info("training epochs started " + self.start_time)
         epoch, iters = 1, 0
-        while(1):
-            if self.n_epochs != -1 and epoch > self.n_epochs:
-                break
+        while epoch <= self.n_epochs:
             logger.info("training epoch {}".format(epoch))
 
             # prepare mini-batch iterator
@@ -149,12 +148,13 @@ class Trainer(object):
             logger.info("prepared data iterator")
 
             # one forward and backward pass
-            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
+            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch,
+                                     step=iters, dev_data=dev_data)
 
             # validation
             if self.validate:
                 self.valid_model()
-            self.save_model(self._model, 'training_model_'+self.start_time)
+            self.save_model(self._model, 'training_model_' + self.start_time)
             epoch += 1
 
     def _train_step(self, data_iterator, network, **kwargs):
@@ -171,13 +171,13 @@ class Trainer(object):
             loss = self.get_loss(prediction, batch_y)
             self.grad_backward(loss)
-            # if torch.rand(1).item() < 0.001:
-            #     print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
-            #     for name, p in self._model.named_parameters():
-            #         if p.requires_grad:
-            #             print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
             self.update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
+            for name, param in self._model.named_parameters():
+                if param.requires_grad:
+                    self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step)
+                    self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step)
+                    self._summary_writer.add_scalar(name + "_grad_sum", param.grad.sum(), global_step=step)
 
             if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
                 end = time.time()
@@ -193,14 +193,14 @@ class Trainer(object):
 
     def valid_model(self):
         if self.dev_data is None:
-                raise RuntimeError(
-                    "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
+            raise RuntimeError(
+                "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
         logger.info("validation started")
         res = self.validator.test(self._model, self.dev_data)
         if self.save_best_dev and self.best_eval_result(res):
             logger.info('save best result! {}'.format(res))
             print('save best result! {}'.format(res))
-            self.save_model(self._model, 'best_model_'+self.start_time)
+            self.save_model(self._model, 'best_model_' + self.start_time)
         return res
 
     def mode(self, model, is_test=False):
@@ -324,10 +324,12 @@ class Trainer(object):
     def set_validator(self, validor):
         self.validator = validor
 
+
 class SeqLabelTrainer(Trainer):
     """Trainer for Sequence Labeling
 
     """
+
     def __init__(self, **kwargs):
         print(
             "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py
index e24f4d27..30279a61 100644
--- a/fastNLP/modules/decoder/CRF.py
+++ b/fastNLP/modules/decoder/CRF.py
@@ -3,6 +3,7 @@ from torch import nn
 
 from fastNLP.modules.utils import initial_parameter
 
+
 def log_sum_exp(x, dim=-1):
     max_value, _ = x.max(dim=dim, keepdim=True)
     res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
@@ -91,7 +92,6 @@ class ConditionalRandomField(nn.Module):
         st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
         last_idx = mask.long().sum(0) - 1
         ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
-        print(score.size(), st_scores.size(), ed_scores.size())
         score += st_scores + ed_scores
         # return [B,]
         return score
diff --git a/reproduction/chinese_word_segment/cws.cfg b/reproduction/chinese_word_segment/cws.cfg
index 033d3967..d2263353 100644
--- a/reproduction/chinese_word_segment/cws.cfg
+++ b/reproduction/chinese_word_segment/cws.cfg
@@ -1,6 +1,6 @@
 [train]
-epochs = 30
-batch_size = 64
+epochs = 40
+batch_size = 8
 pickle_path = "./save/"
 validate = true
 save_best_dev = true
diff --git a/reproduction/pos_tag_model/pos_tag.cfg b/reproduction/pos_tag_model/pos_tag.cfg
index 2e1f37b6..2a08f6da 100644
--- a/reproduction/pos_tag_model/pos_tag.cfg
+++ b/reproduction/pos_tag_model/pos_tag.cfg
@@ -1,6 +1,6 @@
 [train]
-epochs = 5
-batch_size = 2
+epochs = 20
+batch_size = 32
 pickle_path = "./save/"
 validate = false
 save_best_dev = true
diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py
index 027358ef..8936bac8 100644
--- a/reproduction/pos_tag_model/train_pos_tag.py
+++ b/reproduction/pos_tag_model/train_pos_tag.py
@@ -6,6 +6,7 @@ from fastNLP.api.pipeline import Pipeline
 from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
+from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.trainer import Trainer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader
@@ -63,7 +64,11 @@ def train():
     model = AdvSeqLabel(model_param)
 
     # call trainer to train
-    trainer = Trainer(**train_param.data)
+    trainer = Trainer(epochs=train_param["epochs"],
+                      batch_size=train_param["batch_size"],
+                      validate=False,
+                      optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
+                      )
     trainer.train(model, dataset)
 
     # save model & pipeline
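
The per-parameter logging added to Trainer._train_step can be tried outside fastNLP with a few lines of plain PyTorch. The snippet below is a minimal standalone sketch, not part of the patch: the toy model, log directory, and random data are illustrative assumptions, and tensorboardX is assumed to be the writer behind self._summary_writer; only the scalar tags mirror the patch.

    import torch
    import torch.nn as nn
    from tensorboardX import SummaryWriter  # assumed writer implementation

    model = nn.Linear(4, 2)                 # hypothetical toy model
    writer = SummaryWriter("./tb_sketch")   # hypothetical log directory

    # one forward/backward pass so that gradients exist
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()

    step = 0
    writer.add_scalar("loss", loss.item(), global_step=step)
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            # same tags as the patch: parameter statistics plus the gradient sum
            writer.add_scalar(name + "_mean", param.mean().item(), global_step=step)
            writer.add_scalar(name + "_std", param.std().item(), global_step=step)
            writer.add_scalar(name + "_grad_sum", param.grad.sum().item(), global_step=step)
    writer.close()

Run `tensorboard --logdir ./tb_sketch` to inspect the resulting scalar curves; inside the Trainer the same calls are made once per training step with the global step counter.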