|
- import os
- import torch
- import sys
- from torch import nn
-
- from fastNLP.core.callback import Callback
- from fastNLP.core.utils import _get_model_device
-
class MyCallback(Callback):
    """Callback that rewrites the optimizer learning rate after each effective update.

    The rate follows ``max_lr * 100 * min(s ** -0.5, s * warmup_steps ** -1.5)``,
    where ``s`` counts effective optimizer updates: it grows linearly while
    ``s < warmup_steps`` and decays with the inverse square root of ``s`` afterwards.

    ``self.step``, ``self.update_every``, ``self.optimizer``, ``self.pbar`` and
    ``self.epoch`` are not defined here; presumably they are supplied by the
    ``Callback`` base class -- confirm against fastNLP.

    :param args: object exposing ``max_lr`` and ``warmup_steps`` attributes
    """

    def __init__(self, args):
        super().__init__()
        # Number of effective updates seen so far (one per ``update_every`` steps).
        self.real_step = 0
        self.args = args

    def on_step_end(self):
        """Recompute the learning rate and push it into every optimizer param group."""
        # Only act on steps that complete an accumulation cycle.
        if self.step % self.update_every != 0 or self.step <= 0:
            return

        self.real_step += 1
        scale = self.args.max_lr * 100
        decay_term = self.real_step ** (-0.5)
        warmup_term = self.real_step * self.args.warmup_steps ** (-1.5)
        cur_lr = scale * min(decay_term, warmup_term)

        for group in self.optimizer.param_groups:
            group['lr'] = cur_lr

        # Report the schedule every 1000 effective updates.
        if self.real_step % 1000 == 0:
            message = 'Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step)
            self.pbar.write(message)

    def on_epoch_end(self):
        """Write an end-of-epoch notice to the progress bar."""
        notice = 'Epoch {} is done !!!'.format(self.epoch)
        self.pbar.write(notice)
-
- def _save_model(model, model_name, save_dir, only_param=False):
- """ 存储不含有显卡信息的 state_dict 或 model
- :param model:
- :param model_name:
- :param save_dir: 保存的 directory
- :param only_param:
- :return:
- """
- model_path = os.path.join(save_dir, model_name)
- if not os.path.isdir(save_dir):
- os.makedirs(save_dir, exist_ok=True)
- if isinstance(model, nn.DataParallel):
- model = model.module
- if only_param:
- state_dict = model.state_dict()
- for key in state_dict:
- state_dict[key] = state_dict[key].cpu()
- torch.save(state_dict, model_path)
- else:
- _model_device = _get_model_device(model)
- model.cpu()
- torch.save(model, model_path)
- model.to(_model_device)
-
class SaveModelCallback(Callback):
    """Keep several checkpoints per training run, ranked by validation metric.

    The Trainer only keeps the single best model during training; this callback
    allows the ``top`` best models to be stored instead. A sub-directory named
    after the training start time is used inside ``save_dir`` and the
    checkpoints are written there::

        save_dir/
            2019-07-03-15-06-36/
                epoch:0_step:20_{metric_key}:{metric_value}.pt
                epoch:1_step:40_{metric_key}:{metric_value}.pt
            2019-07-03-15-10-00/
                epoch:0_step:20_{metric_key}:{metric_value}.pt

    ``self.trainer``, ``self.model``, ``self.epoch`` and ``self.step`` are not
    defined here; presumably they are supplied by the ``Callback`` base class
    -- confirm against fastNLP.

    :param str save_dir: existing directory under which the per-run directory
        and the checkpoints are placed
    :param int top: how many of the best models (by validation metric) to keep;
        any negative value keeps all of them
    :param bool only_param: whether to save only the model weights
    :param save_on_exception: whether to save the current model when an
        exception is reported to the callback
    """

    def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
        super().__init__()

        if not os.path.isdir(save_dir):
            # NOTE(review): NotADirectoryError would describe this better; the
            # type is kept so existing ``except IsADirectoryError`` callers work.
            raise IsADirectoryError("{} is not a directory.".format(save_dir))
        self.save_dir = save_dir
        if top < 0:
            self.top = sys.maxsize
        else:
            self.top = top
        # List of (metric_value, model_name) tuples ordered from worst to best,
        # so the entry to evict is always at the head.
        self._ordered_save_models = []

        self.only_param = only_param
        self.save_on_exception = save_on_exception

    def on_train_begin(self):
        """Redirect ``save_dir`` to a sub-directory named by the trainer's start time."""
        self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        """Rank and possibly save the model after a validation pass.

        :param eval_result: mapping whose first value is itself a mapping that
            contains ``metric_key``
        :param metric_key: key of the metric used for ranking
        :param optimizer: unused here
        :param is_better_eval: unused here
        """
        metric_value = list(eval_result.values())[0][metric_key]
        self._save_this_model(metric_value)

    def _insert_into_ordered_save_models(self, pair):
        """Insert ``pair`` into the ranking and report what to save and delete.

        :param pair: ``(metric_value, model_name)`` tuple
        :return: ``(save_pair, delete_pair)``; ``save_pair`` is ``pair`` when the
            model made it into the ranking (else ``None``), ``delete_pair`` is
            the evicted entry when the ranking exceeded ``top`` (else ``None``)
        """
        # ``index`` ends as the position of the last stored entry that is
        # strictly worse than ``pair``; -1 means no stored entry is worse.
        index = -1
        for _pair in self._ordered_save_models:
            if _pair[0] >= pair[0] and self.trainer.increase_better:
                break
            if not self.trainer.increase_better and _pair[0] <= pair[0]:
                break
            index += 1
        save_pair = None
        # Accept while there is room, or when the new model beats at least one
        # stored model.
        if len(self._ordered_save_models) < self.top or index != -1:
            save_pair = pair
            self._ordered_save_models.insert(index + 1, pair)
        delete_pair = None
        if len(self._ordered_save_models) > self.top:
            delete_pair = self._ordered_save_models.pop(0)
        return save_pair, delete_pair

    def _save_this_model(self, metric_value):
        """Save the current model if it ranks among the best and drop the evicted one.

        Failures while saving or deleting are printed and swallowed so that
        training is not interrupted.

        :param metric_value: validation metric value of the current model
        """
        name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
        save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
        if save_pair:
            try:
                _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
            except Exception as e:
                print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
        if delete_pair:
            try:
                delete_model_path = os.path.join(self.save_dir, delete_pair[1])
                if os.path.exists(delete_model_path):
                    os.remove(delete_model_path)
            except Exception as e:
                # Report the file that failed to delete, not the one just saved.
                print(f"Fail to delete model {delete_pair[1]} at {self.save_dir} caused by exception:{e}.")

    def on_exception(self, exception):
        """Save the current model when ``save_on_exception`` is set.

        :param exception: the reported exception; its class name becomes part
            of the checkpoint file name
        """
        if self.save_on_exception:
            name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
            _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
-
|