
[bugfix] fix start_time, save&load in dist_trainer

tags/v0.5.5
yunfan 5 years ago
commit d997e5c77a
1 changed file with 25 additions and 7 deletions
  fastNLP/core/dist_trainer.py  +25  -7

@@ -19,6 +19,7 @@ from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
+import time

 from ._logger import logger, init_logger_dist
 from .batch import DataSetIter, BatchIter
@@ -175,8 +176,12 @@ class DistTrainer():
 self.test_manager.add_callback([cb], master=True)

 # Setup logging
+dist.barrier()
-self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
+# synchronize start_time across all processes
+sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
+dist.broadcast(sync_time, src=0)
+self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')
+# print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time))

 if self.save_path:
     self.cp_save_path = self.save_path
 else:
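
Note on this hunk: the old line called datetime.now() on every process independently, so ranks that reached it at different moments produced different start_time strings, and therefore different checkpoint names. The barrier-plus-broadcast pattern makes all ranks adopt rank 0's clock. A minimal standalone sketch of the same pattern, assuming torch.distributed is already initialized and each rank has its own device:

    import time
    from datetime import datetime

    import torch
    import torch.distributed as dist

    def synced_start_time(device):
        # Every rank first fills the tensor with its own wall clock...
        # (.to(device) matters for the NCCL backend, which needs GPU tensors)
        sync_time = torch.tensor(time.time(), dtype=torch.double).to(device)
        # ...then broadcast overwrites it with rank 0's value on all ranks,
        # so the formatted string comes out identical everywhere.
        dist.broadcast(sync_time, src=0)
        return datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')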
@@ -208,9 +213,7 @@ class DistTrainer():
 return contextlib.ExitStack()  # dummy contextmanager

 def _get_n_steps(self):
-    batch_size = self.world_size * self.batch_size_per_gpu
-    return (len(self.train_data) // batch_size + int(
-        len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs
+    return len(self.data_iterator) * self.n_epochs

 def _get_data_iter(self, dataset):
     if isinstance(dataset, DataSet):
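
The deleted arithmetic also carried a bug: the trailing `* int(self.drop_last == 0)` zeroes the entire step count whenever drop_last is enabled, instead of merely dropping the final partial batch. Delegating to the iterator's own length sidesteps this, since a map-style DataLoader already reports ceil or floor of len(dataset) / batch_size depending on drop_last. A quick illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = TensorDataset(torch.arange(10))
    # len() of a DataLoader already accounts for drop_last:
    print(len(DataLoader(data, batch_size=4, drop_last=False)))  # 3 == ceil(10 / 4)
    print(len(DataLoader(data, batch_size=4, drop_last=True)))   # 2 == floor(10 / 4)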
@@ -432,8 +435,9 @@ class DistTrainer():
     model_load = model_load.state_dict()
 self.model.load_state_dict(model_load)

-def _best_save_name(self):
-    return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
+def _best_save_name(self, auto_fix=True):
+    best_name = "best_" + "_".join([self.model.__class__.__name__, str(self.metric_key), self.start_time])
+    return best_name

 def _do_validation(self):
     with self.ddp_model.no_sync():
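
Wrapping metric_key in str() matters here because _best_save_name can now run while metric_key is still None (it is only filled in during validation, per the next hunk), and "_".join raises a TypeError on a non-string element. A two-line illustration:

    # "_".join(['Model', None, 'time'])        # TypeError: sequence item 1: expected str instance
    print("_".join(['Model', str(None), 'time']))  # 'Model_None_time'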
@@ -447,12 +451,26 @@ class DistTrainer():
         is_better = is_better[0]
 else:
     eval_res, is_better = None, None
+if self.metric_key is None and eval_res is not None:
+    eval_res0 = list(eval_res.values())[0]
+    self.metric_key = list(eval_res0.keys())[0]
 # logger.info('{}, {}'.format(eval_res, is_better))
 # save better model on master node
 if is_better is not None and self.cp_save_path:
     if is_better:
         self.save_check_point(self._best_save_name(), only_params=False)
 dist.barrier()

+if not self.is_master and self.metric_key is None:
+    # the master process picked up metric_key automatically; the other processes have not
+    prefix = 'best_' + self.model.__class__.__name__
+    suffix = self.start_time
+    fn_list = os.listdir(self.cp_save_path)
+    fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)]
+    if len(fn_list) == 1:
+        best_name = fn_list[0]
+        self.metric_key = best_name[len(prefix):-len(suffix)].strip('_')
+    # print('RANK {} metric_key {}'.format(self.rank, self.metric_key))
 self.callback_manager.on_valid_end(
     eval_res, self.metric_key, self.optimizer, is_better)
 self.ddp_model.train()
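
The recovery block on non-master ranks works because _best_save_name builds filenames as "best_" + model class name + "_" + metric_key + "_" + start_time, and both the prefix and the (now synchronized) start_time suffix are known on every rank, so the metric key can be sliced back out of the checkpoint name. A standalone sketch with hypothetical values:

    # Hypothetical model name and timestamp, for illustration only.
    prefix = 'best_' + 'BertClassifier'
    suffix = '2020-05-01-12-00-00-000000'       # the synchronized start_time
    fname = 'best_BertClassifier_acc_' + suffix
    # Slice off the known prefix and suffix; what remains is the metric key.
    metric_key = fname[len(prefix):-len(suffix)].strip('_')
    assert metric_key == 'acc'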

