|
-
- from fastNLP.core.callback import Callback
- import torch
- from torch import nn
-
class OptimizerCallback(Callback):
    """Advance a learning-rate scheduler once every ``update_every`` steps.

    The optimizer is stored but not used by this class: gradient clipping,
    ``optimizer.step()`` and ``model.zero_grad()`` are disabled here, so only
    the scheduler is stepped.
    """

    def __init__(self, optimizer, scheduler, update_every=4):
        """Store the optimizer, the scheduler and the stepping interval.

        :param optimizer: kept on ``self._optimizer``; not called in this class.
        :param scheduler: object whose ``step()`` is called periodically.
        :param update_every: the scheduler is stepped when ``self.step`` is a
            multiple of this value.
        """
        super().__init__()

        self._update_every = update_every
        self.scheduler = scheduler
        self._optimizer = optimizer

    def on_backward_end(self):
        """Step the scheduler when the current step hits the update interval."""
        # ``self.step`` is presumably supplied by the Callback base class
        # (current training step) -- confirm against fastNLP.
        is_update_step = self.step % self._update_every == 0
        if is_update_step:
            # NOTE: clip_grad_norm_(model.parameters(), 5), optimizer.step()
            # and model.zero_grad() are intentionally not performed here.
            self.scheduler.step()
-
-
class DevCallback(Callback):
    """Run an extra ``tester`` at each validation and remember the best results.

    Two "best" records are kept:

    * the epoch where the tester's own metric was highest, together with the
      trainer's validation result from that same validation round;
    * the epoch where the trainer reported a better validation result
      (``is_better_eval``), together with the tester's result from that round.

    Both records are printed when training ends.
    """

    def __init__(self, tester, metric_key='u_f1'):
        """Store the tester and initialise all tracking state.

        :param tester: object providing ``test()``, ``metrics`` and
            ``_format_eval_results()``; its ``verbose`` attribute is set to 0.
        :param metric_key: key read from the first metric's result dict to
            decide whether the tester result improved.
        """
        super().__init__()
        self.tester = tester
        setattr(tester, 'verbose', 0)

        self.metric_key = metric_key

        # Best-by-tester-metric record.
        self.record_best = False
        self.best_eval_value = 0
        self.best_eval_res = None
        self.best_epoch = None

        self.best_dev_res = None  # stores the dev-set performance

        # Latest tester result, and the best-by-dev record. These used to be
        # created lazily inside the hooks, which made on_train_end raise
        # AttributeError when no improvement had ever been recorded.
        self.test_eval_res = None
        self.best_dev_res_on_dev = None
        self.best_test_res_on_dev = None
        self.dev_epoch = None

    def on_valid_begin(self):
        """Evaluate the tester, update the best-by-tester record and log it."""
        eval_res = self.tester.test()
        # Results are keyed by the metric's class name; only the first metric
        # of the tester is consulted.
        metric_name = self.tester.metrics[0].__class__.__name__
        metric_value = eval_res[metric_name][self.metric_key]
        if metric_value > self.best_eval_value:
            self.best_eval_value = metric_value
            self.best_epoch = self.trainer.epoch
            # Tells on_valid_end to capture the matching validation result.
            self.record_best = True
            self.best_eval_res = eval_res
        self.test_eval_res = eval_res
        eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \
                   self.tester._format_eval_results(eval_res)
        self.pbar.write(eval_str)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        """Pair the trainer's validation result with the tester result.

        :param eval_result: validation result handed over by the trainer.
        :param metric_key: unused here.
        :param optimizer: unused here.
        :param is_better_eval: truthy when the trainer considers this
            validation result an improvement.
        """
        if self.record_best:
            self.best_dev_res = eval_result
            self.record_best = False
        if is_better_eval:
            self.best_dev_res_on_dev = eval_result
            self.best_test_res_on_dev = self.test_eval_res
            self.dev_epoch = self.epoch

    def _format_result(self, eval_res):
        """Format a result with the tester, or return ``N/A`` if it is missing."""
        if eval_res is None:
            return "N/A"
        return self.tester._format_eval_results(eval_res)

    def on_train_end(self):
        """Print the best-by-tester and best-by-dev summaries, if recorded."""
        # A report is skipped when its epoch was never set, i.e. the
        # corresponding improvement never happened during training.
        if self.best_epoch is not None:
            print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(
                self.best_epoch,
                self._format_result(self.best_eval_res),
                self._format_result(self.best_dev_res)))
        if self.dev_epoch is not None:
            print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(
                self.dev_epoch,
                self._format_result(self.best_test_res_on_dev),
                self._format_result(self.best_dev_res_on_dev)))
|