diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index b4eb95df..7162ee69 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -87,6 +87,7 @@ from .dataset import DataSet from .tester import Tester from ._logger import logger from .utils import _check_fp16 +from ._parallel_utils import _model_contains_inner_module try: import fitlog @@ -914,7 +915,7 @@ class CheckPointCallback(Callback): log_dir = states['fitlog_save_log_dir'] fitlog.set_log_dir(log_dir, new_log=True) except: - print("Fail to recovery the fitlog states.") + logger.error("Fail to recovery the fitlog states.") def on_train_begin(self): """ @@ -929,7 +930,10 @@ class CheckPointCallback(Callback): """ if os.path.exists(os.path.expanduser(self.save_path)): states = torch.load(self.save_path) - self.model.load_state_dict(states['model']) + model = self.model + if _model_contains_inner_module(model): + model = model.module + model.load_state_dict(states['model']) self.optimizer.load_state_dict(states['optimizer']) self.trainer.epoch = states['epoch'] + 1 # 因为是结束储存的,所以需要从下一个epoch开始 self.trainer.step = states['step'] @@ -947,7 +951,10 @@ class CheckPointCallback(Callback): :return: """ states = {} - states['model'] = {name:param.cpu() for name, param in self.model.state_dict().items()} + model = self.model + if _model_contains_inner_module(model): + model = model.module + states['model'] = {name:param.cpu() for name, param in model.state_dict().items()} states['optimizer'] = self.optimizer.state_dict() states['epoch'] = self.epoch states['step'] = self.step @@ -966,12 +973,14 @@ class CheckPointCallback(Callback): except: pass torch.save(states, self.save_path) + logger.debug("Checkpoint:{} has been saved in epoch:{}.".format(self.save_path, self.epoch)) def on_train_end(self): # 训练结束,根据情况删除保存的内容 if self.delete_when_train_finish: if os.path.exists(self.save_path): os.remove(self.save_path) + logger.debug("Checkpoint:{} has been removed.".format(self.save_path)) class WarmupCallback(Callback): diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 27f266fa..ba9ec850 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -19,6 +19,8 @@ import torch.nn as nn from typing import List from ._logger import logger from prettytable import PrettyTable +from ._parallel_utils import _model_contains_inner_module + try: from apex import amp except: @@ -179,7 +181,7 @@ def _save_model(model, model_name, save_dir, only_param=False): model_path = os.path.join(save_dir, model_name) if not os.path.isdir(save_dir): os.makedirs(save_dir, exist_ok=True) - if isinstance(model, nn.DataParallel): + if _model_contains_inner_module(model): model = model.module if only_param: state_dict = model.state_dict()