From d997e5c77acedb64f32e18e0b179e8ef1b7270a9 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Mon, 13 Apr 2020 23:22:06 +0800
Subject: [PATCH] [bugfix] fix start_time, save&load in dist_trainer

---
 fastNLP/core/dist_trainer.py | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 30d68deb..726a5e60 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -19,6 +19,7 @@ from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
+import time
 
 from ._logger import logger, init_logger_dist
 from .batch import DataSetIter, BatchIter
@@ -175,8 +176,12 @@ class DistTrainer():
             self.test_manager.add_callback([cb], master=True)
 
         # Setup logging
-        dist.barrier()
-        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
+        # synchronize start_time across processes
+        sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
+        dist.broadcast(sync_time, src=0)
+        self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')
+        # print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time))
+
         if self.save_path:
             self.cp_save_path = self.save_path
         else:
@@ -208,9 +213,7 @@ class DistTrainer():
             return contextlib.ExitStack()  # dummy contextmanager
 
     def _get_n_steps(self):
-        batch_size = self.world_size * self.batch_size_per_gpu
-        return (len(self.train_data) // batch_size + int(
-            len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs
+        return len(self.data_iterator) * self.n_epochs
 
     def _get_data_iter(self, dataset):
         if isinstance(dataset, DataSet):
@@ -432,8 +435,9 @@ class DistTrainer():
             model_load = model_load.state_dict()
         self.model.load_state_dict(model_load)
 
-    def _best_save_name(self):
-        return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
+    def _best_save_name(self, auto_fix=True):
+        best_name = "best_" + "_".join([self.model.__class__.__name__, str(self.metric_key), self.start_time])
+        return best_name
 
     def _do_validation(self):
         with self.ddp_model.no_sync():
@@ -447,12 +451,26 @@ class DistTrainer():
                 is_better = is_better[0]
             else:
                 eval_res, is_better = None, None
+            if self.metric_key is None and eval_res is not None:
+                eval_res0 = list(eval_res.values())[0]
+                self.metric_key = list(eval_res0.keys())[0]
             # logger.info('{}, {}'.format(eval_res, is_better))
             # save better model on master node
             if is_better is not None and self.cp_save_path:
                 if is_better:
                     self.save_check_point(self._best_save_name(), only_params=False)
             dist.barrier()
+
+            if not self.is_master and self.metric_key is None:
+                # the master process obtained metric_key automatically, but the other processes did not
+                prefix = 'best_' + self.model.__class__.__name__
+                suffix = self.start_time
+                fn_list = os.listdir(self.cp_save_path)
+                fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)]
+                if len(fn_list) == 1:
+                    best_name = fn_list[0]
+                    self.metric_key = best_name[len(prefix):-len(suffix)].strip('_')
+            # print('RANK {} metric_key {}'.format(self.rank, self.metric_key))
             self.callback_manager.on_valid_end(
                 eval_res, self.metric_key, self.optimizer, is_better)
         self.ddp_model.train()
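
Note (reviewer): below is a standalone sketch of the start_time synchronization
the patch introduces: rank 0 broadcasts its wall-clock time so every process
formats the identical timestamp, which keeps checkpoint names consistent across
ranks. This is illustrative code, not part of fastNLP; it assumes the usual
env:// variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are set and uses
the gloo backend, and synced_start_time is a made-up helper name.

import time
from datetime import datetime

import torch
import torch.distributed as dist

def synced_start_time(device='cpu'):
    # every rank allocates the tensor, but only rank 0's value survives the broadcast
    sync_time = torch.tensor(time.time(), dtype=torch.double).to(device)
    dist.broadcast(sync_time, src=0)
    return datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')

if __name__ == '__main__':
    dist.init_process_group('gloo')
    # all ranks print the same string, so filenames built from it match
    print(dist.get_rank(), synced_start_time())
    dist.destroy_process_group()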
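
Note (reviewer): the non-master metric_key recovery relies on the filename
layout produced by _best_save_name: a fixed "best_<model>" prefix, the
metric_key in the middle, and the synchronized start_time as suffix. A minimal
round-trip check of that slicing, with made-up values ('ModelDemo', 'acc')
rather than fastNLP internals:

def best_save_name(model_name, metric_key, start_time):
    return "best_" + "_".join([model_name, str(metric_key), start_time])

def recover_metric_key(fn, model_name, start_time):
    # same slicing as the patch: drop the prefix and suffix, then the joining underscores
    prefix = 'best_' + model_name
    suffix = start_time
    return fn[len(prefix):-len(suffix)].strip('_')

start = '2020-04-13-23-22-06-000000'
name = best_save_name('ModelDemo', 'acc', start)
assert recover_metric_key(name, 'ModelDemo', start) == 'acc'

This round-trip holds as long as metric_key itself has no leading or trailing
underscores, since strip('_') only removes the joining ones at the edges.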