|
@@ -19,6 +19,7 @@ from pkg_resources import parse_version |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
from tqdm import tqdm |
|
|
from tqdm import tqdm |
|
|
|
|
|
import time |
|
|
|
|
|
|
|
|
from ._logger import logger, init_logger_dist |
|
|
from ._logger import logger, init_logger_dist |
|
|
from .batch import DataSetIter, BatchIter |
|
|
from .batch import DataSetIter, BatchIter |
|
@@ -175,8 +176,12 @@ class DistTrainer(): |
|
|
self.test_manager.add_callback([cb], master=True) |
|
|
self.test_manager.add_callback([cb], master=True) |
|
|
|
|
|
|
|
|
# Setup logging |
|
|
# Setup logging |
|
|
dist.barrier() |
|
|
|
|
|
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') |
|
|
|
|
|
|
|
|
# Synchronize start_time across all processes |
|
|
|
|
|
sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device) |
|
|
|
|
|
dist.broadcast(sync_time, src=0) |
|
|
|
|
|
self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f') |
|
|
|
|
|
# print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time)) |
|
|
|
|
|
|
|
|
if self.save_path: |
|
|
if self.save_path: |
|
|
self.cp_save_path = self.save_path |
|
|
self.cp_save_path = self.save_path |
|
|
else: |
|
|
else: |
|
@@ -208,9 +213,7 @@ class DistTrainer(): |
|
|
return contextlib.ExitStack() # dummy contextmanager |
|
|
return contextlib.ExitStack() # dummy contextmanager |
|
|
|
|
|
|
|
|
def _get_n_steps(self): |
|
|
def _get_n_steps(self): |
|
|
batch_size = self.world_size * self.batch_size_per_gpu |
|
|
|
|
|
return (len(self.train_data) // batch_size + int( |
|
|
|
|
|
len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs |
|
|
|
|
|
|
|
|
return len(self.data_iterator) * self.n_epochs |
|
|
|
|
|
|
|
|
def _get_data_iter(self, dataset): |
|
|
def _get_data_iter(self, dataset): |
|
|
if isinstance(dataset, DataSet): |
|
|
if isinstance(dataset, DataSet): |
|
@@ -432,8 +435,9 @@ class DistTrainer(): |
|
|
model_load = model_load.state_dict() |
|
|
model_load = model_load.state_dict() |
|
|
self.model.load_state_dict(model_load) |
|
|
self.model.load_state_dict(model_load) |
|
|
|
|
|
|
|
|
def _best_save_name(self): |
|
|
|
|
|
return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) |
|
|
|
|
|
|
|
|
def _best_save_name(self, auto_fix=True): |
|
|
|
|
|
best_name = "best_" + "_".join([self.model.__class__.__name__, str(self.metric_key), self.start_time]) |
|
|
|
|
|
return best_name |
|
|
|
|
|
|
|
|
def _do_validation(self): |
|
|
def _do_validation(self): |
|
|
with self.ddp_model.no_sync(): |
|
|
with self.ddp_model.no_sync(): |
|
@@ -447,12 +451,26 @@ class DistTrainer(): |
|
|
is_better = is_better[0] |
|
|
is_better = is_better[0] |
|
|
else: |
|
|
else: |
|
|
eval_res, is_better = None, None |
|
|
eval_res, is_better = None, None |
|
|
|
|
|
if self.metric_key is None and eval_res is not None: |
|
|
|
|
|
eval_res0 = list(eval_res.values())[0] |
|
|
|
|
|
self.metric_key = list(eval_res0.keys())[0] |
|
|
# logger.info('{}, {}'.format(eval_res, is_better)) |
|
|
# logger.info('{}, {}'.format(eval_res, is_better)) |
|
|
# save better model on master node |
|
|
# save better model on master node |
|
|
if is_better is not None and self.cp_save_path: |
|
|
if is_better is not None and self.cp_save_path: |
|
|
if is_better: |
|
|
if is_better: |
|
|
self.save_check_point(self._best_save_name(), only_params=False) |
|
|
self.save_check_point(self._best_save_name(), only_params=False) |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
|
|
|
if not self.is_master and self.metric_key is None: |
|
|
|
|
|
# The master process derived metric_key automatically, but the other processes did not |
|
|
|
|
|
prefix = 'best_' + self.model.__class__.__name__ |
|
|
|
|
|
suffix = self.start_time |
|
|
|
|
|
fn_list = os.listdir(self.cp_save_path) |
|
|
|
|
|
fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)] |
|
|
|
|
|
if len(fn_list) == 1: |
|
|
|
|
|
best_name = fn_list[0] |
|
|
|
|
|
self.metric_key = best_name[len(prefix):-len(suffix)].strip('_') |
|
|
|
|
|
# print('RANK {} metric_key {}'.format(self.rank, self.metric_key)) |
|
|
self.callback_manager.on_valid_end( |
|
|
self.callback_manager.on_valid_end( |
|
|
eval_res, self.metric_key, self.optimizer, is_better) |
|
|
eval_res, self.metric_key, self.optimizer, is_better) |
|
|
self.ddp_model.train() |
|
|
self.ddp_model.train() |
|
|