
[bugfix] fix start_time, save&load in dist_trainer

commit d997e5c77a (tag: v0.5.5)
author: yunfan
1 changed file with 25 additions and 7 deletions

fastNLP/core/dist_trainer.py

@@ -19,6 +19,7 @@ from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
+import time

 from ._logger import logger, init_logger_dist
 from .batch import DataSetIter, BatchIter
@@ -175,8 +176,12 @@ class DistTrainer():
         self.test_manager.add_callback([cb], master=True)

         # Setup logging
         dist.barrier()
-        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
+        # synchronize start_time across processes
+        sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
+        dist.broadcast(sync_time, src=0)
+        self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')
+        # print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time))

         if self.save_path:
             self.cp_save_path = self.save_path
         else:
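This hunk is the start_time fix from the commit message: instead of every rank formatting its own clock (which can differ across machines and processes), rank 0's timestamp is broadcast so all ranks derive identical checkpoint names. A minimal standalone sketch of the same pattern, assuming a process group has already been initialized via dist.init_process_group; the function name and device argument are illustrative, not fastNLP API:

    import time
    from datetime import datetime

    import torch
    import torch.distributed as dist

    def synced_start_time(device):
        # Each rank builds a tensor from its local clock ...
        sync_time = torch.tensor(time.time(), dtype=torch.double).to(device)
        # ... but the broadcast overwrites it with rank 0's value on every rank.
        dist.broadcast(sync_time, src=0)
        # All ranks now format the identical timestamp, so any path derived
        # from start_time (e.g. a checkpoint name) agrees across processes.
        return datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')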
@@ -208,9 +213,7 @@ class DistTrainer():
         return contextlib.ExitStack()  # dummy contextmanager

     def _get_n_steps(self):
-        batch_size = self.world_size * self.batch_size_per_gpu
-        return (len(self.train_data) // batch_size + int(
-            len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs
+        return len(self.data_iterator) * self.n_epochs

     def _get_data_iter(self, dataset):
         if isinstance(dataset, DataSet):
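The removed step count had two problems: it re-derived the per-epoch batch count from the global dataset size instead of asking the per-rank iterator, and its `int(self.drop_last == 0)` factor multiplied the whole expression by zero whenever drop_last was true. Toy numbers (hypothetical, not taken from fastNLP) make the second problem visible:

    n_samples, world_size, batch_size_per_gpu, n_epochs = 1000, 4, 32, 3
    batch_size = world_size * batch_size_per_gpu  # 128
    drop_last = True

    old_steps = (n_samples // batch_size
                 + int(n_samples % batch_size != 0)) * int(drop_last == 0) * n_epochs
    print(old_steps)  # 0 -- the drop_last factor zeroes the count instead of
                      # switching the ceiling division to a floor division

The fix sidesteps the arithmetic entirely: len(self.data_iterator) is already the per-rank batch count that DistributedSampler actually produces, so multiplying it by n_epochs is correct for any drop_last setting.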
@@ -432,8 +435,9 @@ class DistTrainer():
             model_load = model_load.state_dict()
         self.model.load_state_dict(model_load)

-    def _best_save_name(self):
-        return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
+    def _best_save_name(self, auto_fix=True):
+        best_name = "best_" + "_".join([self.model.__class__.__name__, str(self.metric_key), self.start_time])
+        return best_name

     def _do_validation(self):
         with self.ddp_model.no_sync():
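Wrapping metric_key in str() guards the join: "_".join raises TypeError on non-string elements, so the old version would crash whenever metric_key was still None at save time. With hypothetical values:

    model_name = 'BertClassifier'          # hypothetical class name
    start_time = '2020-05-01-12-00-00-000000'
    metric_key = None                       # not yet determined by validation

    # Old: "_".join([..., metric_key, ...]) -> TypeError when metric_key is None.
    best_name = "best_" + "_".join([model_name, str(metric_key), start_time])
    print(best_name)  # best_BertClassifier_None_2020-05-01-12-00-00-000000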
@@ -447,12 +451,26 @@ class DistTrainer():
                 is_better = is_better[0]
             else:
                 eval_res, is_better = None, None
+            if self.metric_key is None and eval_res is not None:
+                eval_res0 = list(eval_res.values())[0]
+                self.metric_key = list(eval_res0.keys())[0]
+            # logger.info('{}, {}'.format(eval_res, is_better))
             # save better model on master node
             if is_better is not None and self.cp_save_path:
                 if is_better:
                     self.save_check_point(self._best_save_name(), only_params=False)
             dist.barrier()
+
+        if not self.is_master and self.metric_key is None:
+            # the master process picked up metric_key automatically; the other processes did not
+            prefix = 'best_' + self.model.__class__.__name__
+            suffix = self.start_time
+            fn_list = os.listdir(self.cp_save_path)
+            fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)]
+            if len(fn_list) == 1:
+                best_name = fn_list[0]
+                self.metric_key = best_name[len(prefix):-len(suffix)].strip('_')
+        # print('RANK {} metric_key {}'.format(self.rank, self.metric_key))
         self.callback_manager.on_valid_end(
             eval_res, self.metric_key, self.optimizer, is_better)
         self.ddp_model.train()
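Non-master ranks never receive eval_res, so this hunk recovers metric_key by parsing the checkpoint name the master wrote under the _best_save_name scheme (best_<Model>_<metric_key>_<start_time>); because start_time was broadcast in the earlier hunk, every rank can reconstruct the exact prefix and suffix. A standalone sketch of that parse, with hypothetical function and path names:

    import os

    def recover_metric_key(cp_save_path, model_cls_name, start_time):
        # Checkpoints are named 'best_<Model>_<metric_key>_<start_time>'.
        prefix = 'best_' + model_cls_name
        suffix = start_time
        matches = [fn for fn in os.listdir(cp_save_path)
                   if fn.startswith(prefix) and fn.endswith(suffix)]
        if len(matches) == 1:
            # Slice off prefix and suffix; strip('_') removes the leftover
            # separators, e.g. '_acc_' -> 'acc'.
            return matches[0][len(prefix):-len(suffix)].strip('_')
        return None  # nothing saved yet, or the match is ambiguous

    # e.g. recover_metric_key('checkpoints', 'BertClassifier',
    #                         '2020-05-01-12-00-00-000000') might return 'acc'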

