@@ -87,6 +87,7 @@ from .dataset import DataSet
 from .tester import Tester
 from ._logger import logger
 from .utils import _check_fp16
+from ._parallel_utils import _model_contains_inner_module

 try:
     import fitlog
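The newly imported _model_contains_inner_module is used further down to detect wrapped models whose real parameters live under a .module attribute (e.g. nn.DataParallel or DistributedDataParallel). A minimal sketch of such a check, assuming only the standard PyTorch wrappers; this is not the actual fastNLP implementation:

    # Hedged sketch: report whether a model keeps its real parameters under
    # a .module attribute, as the standard PyTorch parallel wrappers do.
    import torch.nn as nn

    def _contains_inner_module_sketch(model: nn.Module) -> bool:
        return isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))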
@@ -914,7 +915,7 @@ class CheckPointCallback(Callback):
                 log_dir = states['fitlog_save_log_dir']
                 fitlog.set_log_dir(log_dir, new_log=True)
             except:
-                print("Fail to recovery the fitlog states.")
+                logger.error("Fail to recovery the fitlog states.")

     def on_train_begin(self):
         """
@@ -929,7 +930,10 @@ class CheckPointCallback(Callback):
         """
         if os.path.exists(os.path.expanduser(self.save_path)):
             states = torch.load(self.save_path)
-            self.model.load_state_dict(states['model'])
+            model = self.model
+            if _model_contains_inner_module(model):
+                model = model.module
+            model.load_state_dict(states['model'])
             self.optimizer.load_state_dict(states['optimizer'])
             self.trainer.epoch = states['epoch'] + 1  # saved at the end of the epoch, so resume from the next one
             self.trainer.step = states['step']
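Unwrapping before load_state_dict matters because DataParallel-style wrappers prefix every state_dict key with "module.", so a checkpoint saved from the bare model would not match the wrapper's keys. A small illustrative sketch with a hypothetical toy model (not fastNLP code):

    import torch.nn as nn

    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)
    print(list(net.state_dict().keys()))      # ['weight', 'bias']
    print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
    # Loading the saved (unprefixed) dict into wrapped.module succeeds;
    # loading it into wrapped directly fails with missing/unexpected keys.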
@@ -947,7 +951,10 @@ class CheckPointCallback(Callback):
         :return:
         """
         states = {}
-        states['model'] = {name:param.cpu() for name, param in self.model.state_dict().items()}
+        model = self.model
+        if _model_contains_inner_module(model):
+            model = model.module
+        states['model'] = {name:param.cpu() for name, param in model.state_dict().items()}
         states['optimizer'] = self.optimizer.state_dict()
         states['epoch'] = self.epoch
         states['step'] = self.step
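Copying each parameter to CPU before torch.save (the .cpu() in the dict comprehension above) keeps the checkpoint loadable even when the original GPU layout is unavailable. The same robustness can be obtained at load time; a hedged sketch, with 'checkpoint.pt' as an assumed file name:

    import torch

    # map_location remaps any CUDA tensors in the file onto the CPU,
    # so the load works on a machine without the GPUs used for training.
    states = torch.load('checkpoint.pt', map_location='cpu')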
@@ -966,12 +973,14 @@ class CheckPointCallback(Callback):
             except:
                 pass
         torch.save(states, self.save_path)
+        logger.debug("Checkpoint:{} has been saved in epoch:{}.".format(self.save_path, self.epoch))

     def on_train_end(self):
         # training finished; delete the saved checkpoint if configured to
         if self.delete_when_train_finish:
             if os.path.exists(self.save_path):
                 os.remove(self.save_path)
+                logger.debug("Checkpoint:{} has been removed.".format(self.save_path))


 class WarmupCallback(Callback):
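For reference, a hedged sketch of how the CheckPointCallback touched above is typically attached to a fastNLP Trainer. Aside from save_path and delete_when_train_finish (both visible in the diff), the constructor and Trainer arguments shown are assumptions, and the data/model objects are placeholders:

    from fastNLP import Trainer
    from fastNLP.core.callback import CheckPointCallback

    # train_data, model, optimizer and loss are assumed to be defined
    # elsewhere in the training script.
    checkpoint_cb = CheckPointCallback(save_path='~/checkpoints/model.ckpt',
                                       delete_when_train_finish=True)
    trainer = Trainer(train_data=train_data, model=model, optimizer=optimizer,
                      loss=loss, callbacks=[checkpoint_cb])
    trainer.train()
    # If the run is interrupted, re-running the same script restores the model
    # and optimizer state in on_train_begin and resumes from the next epoch.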