|
|
@@ -17,6 +17,7 @@ from fastNLP.saver.model_saver import ModelSaver
 
 logger = create_logger(__name__, "./train_test.log")
+logger.disabled = True
 
 
 class Trainer(object):
     """Operations of training a model, including data loading, gradient descent, and validation.
|
|
|
|
|
|
@@ -138,9 +139,7 @@ class Trainer(object):
         print("training epochs started " + self.start_time)
         logger.info("training epochs started " + self.start_time)
         epoch, iters = 1, 0
-        while(1):
-            if self.n_epochs != -1 and epoch > self.n_epochs:
-                break
+        while epoch <= self.n_epochs:
             logger.info("training epoch {}".format(epoch))
 
             # prepare mini-batch iterator
|
|
@@ -149,12 +148,13 @@ class Trainer(object):
             logger.info("prepared data iterator")
 
             # one forward and backward pass
-            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
+            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch,
+                                     step=iters, dev_data=dev_data)
 
             # validation
             if self.validate:
                 self.valid_model()
-            self.save_model(self._model, 'training_model_'+self.start_time)
+            self.save_model(self._model, 'training_model_' + self.start_time)
             epoch += 1
 
     def _train_step(self, data_iterator, network, **kwargs):
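The rewritten loop is equivalent for any positive n_epochs, but note one behavioural edge case visible in the removed lines: the old sentinel n_epochs == -1 meant "train until interrupted", whereas the new bounded condition performs zero epochs in that case. A runnable sketch of the new control flow:

    # Minimal sketch of the bounded epoch loop (toy values, no real training).
    n_epochs = 3
    epoch, iters = 1, 0
    while epoch <= n_epochs:
        # one pass over the training data would run here,
        # followed by optional validation and a checkpoint save
        epoch += 1
    print("finished {} epochs".format(epoch - 1))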
|
|
@@ -171,13 +171,13 @@ class Trainer(object):
 
             loss = self.get_loss(prediction, batch_y)
             self.grad_backward(loss)
-            # if torch.rand(1).item() < 0.001:
-            #     print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
-            #     for name, p in self._model.named_parameters():
-            #         if p.requires_grad:
-            #             print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
             self.update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
+            for name, param in self._model.named_parameters():
+                if param.requires_grad:
+                    self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step)
+                    self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step)
+                    self._summary_writer.add_scalar(name + "_grad_sum", param.grad.sum(), global_step=step)
 
             if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
                 end = time.time()
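The added lines replace the ad-hoc commented-out gradient dump with per-step TensorBoard scalars. A self-contained sketch of the same logging pattern with a toy model, assuming torch.utils.tensorboard (tensorboardX, which fastNLP used at the time, exposes the same add_scalar API); the model, writer, and log directory names here are stand-ins, not fastNLP's:

    import torch
    from torch.utils.tensorboard import SummaryWriter  # tensorboardX.SummaryWriter is API-compatible

    model = torch.nn.Linear(4, 2)        # stand-in for self._model
    writer = SummaryWriter("./tb_logs")  # stand-in for self._summary_writer

    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                      # populates param.grad, as grad_backward does

    step = 0
    writer.add_scalar("loss", loss.item(), global_step=step)
    for name, param in model.named_parameters():
        if param.requires_grad:
            # distribution of the weights themselves, plus the summed gradient
            writer.add_scalar(name + "_mean", param.mean(), global_step=step)
            writer.add_scalar(name + "_std", param.std(), global_step=step)
            writer.add_scalar(name + "_grad_sum", param.grad.sum(), global_step=step)
    writer.close()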
|
|
@@ -193,14 +193,14 @@ class Trainer(object):
 
     def valid_model(self):
         if self.dev_data is None:
             raise RuntimeError(
                 "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
         logger.info("validation started")
         res = self.validator.test(self._model, self.dev_data)
         if self.save_best_dev and self.best_eval_result(res):
             logger.info('save best result! {}'.format(res))
             print('save best result! {}'.format(res))
-            self.save_model(self._model, 'best_model_'+self.start_time)
+            self.save_model(self._model, 'best_model_' + self.start_time)
         return res
 
     def mode(self, model, is_test=False):
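valid_model implements a save-on-improvement policy: checkpoint only when the validation result beats the best seen so far. The comparison inside best_eval_result is fastNLP-internal, so this sketch assumes a single higher-is-better metric and a hypothetical best_eval_result helper:

    best_metric = None  # hypothetical stand-in for the trainer's best-result state

    def best_eval_result(metric):
        # Return True when `metric` beats the best seen so far (higher is better assumed).
        global best_metric
        if best_metric is None or metric > best_metric:
            best_metric = metric
            return True
        return False

    for res in [0.61, 0.70, 0.66, 0.73]:
        if best_eval_result(res):
            print('save best result! {}'.format(res))  # a real trainer would save a checkpoint here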
|
|
@@ -324,10 +324,12 @@ class Trainer(object):
     def set_validator(self, validor):
         self.validator = validor
 
 
 class SeqLabelTrainer(Trainer):
     """Trainer for Sequence Labeling
 
     """
 
     def __init__(self, **kwargs):
+        print(
+            "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
|
|
|