|
@@ -17,7 +17,8 @@ from tqdm import tqdm |
|
|
|
|
|
|
|
|
from ._logger import logger |
|
|
from ._logger import logger |
|
|
from .batch import DataSetIter, BatchIter |
|
|
from .batch import DataSetIter, BatchIter |
|
|
from .callback import DistCallbackManager, CallbackException, _TesterCallback |
|
|
|
|
|
|
|
|
from .callback import DistCallbackManager, CallbackException |
|
|
|
|
|
from .callback import _TesterCallback |
|
|
from .dataset import DataSet |
|
|
from .dataset import DataSet |
|
|
from .losses import _prepare_losser |
|
|
from .losses import _prepare_losser |
|
|
from .optimizer import Optimizer |
|
|
from .optimizer import Optimizer |
|
@@ -174,13 +175,13 @@ class DistTrainer(): |
|
|
cb = _TesterCallback( |
|
|
cb = _TesterCallback( |
|
|
dev_data, model, metrics, |
|
|
dev_data, model, metrics, |
|
|
batch_size=batch_size_per_gpu, num_workers=num_workers) |
|
|
batch_size=batch_size_per_gpu, num_workers=num_workers) |
|
|
self.test_manager.add_callback([cb], master=True) |
|
|
|
|
|
|
|
|
self.test_manager.add_callback([cb], master=False) |
|
|
|
|
|
|
|
|
# Setup logging |
|
|
# Setup logging |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') |
|
|
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') |
|
|
if self.save_path: |
|
|
if self.save_path: |
|
|
self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time) |
|
|
|
|
|
|
|
|
self.cp_save_path = os.path.join(self.save_path, 'checkpoints') |
|
|
else: |
|
|
else: |
|
|
self.cp_save_path = None |
|
|
self.cp_save_path = None |
|
|
|
|
|
|
|
@@ -286,11 +287,11 @@ class DistTrainer(): |
|
|
results['seconds'] = round(time.time() - start_time, 2) |
|
|
results['seconds'] = round(time.time() - start_time, 2) |
|
|
self.logger.info("###### Train finished ######") |
|
|
self.logger.info("###### Train finished ######") |
|
|
self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) |
|
|
self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) |
|
|
if load_best_model: |
|
|
|
|
|
self.load_check_point('best_{}'.format(self.metric_key)) |
|
|
|
|
|
|
|
|
if load_best_model and self.cp_save_path and len(self.test_manager.callbacks): |
|
|
|
|
|
self.load_check_point('best') |
|
|
finally: |
|
|
finally: |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
dist.barrier() |
|
|
return results |
|
|
return results |
|
|
|
|
|
|
|
|
def _train(self): |
|
|
def _train(self): |
|
@@ -417,29 +418,29 @@ class DistTrainer(): |
|
|
def load_check_point(self, name): |
|
|
def load_check_point(self, name): |
|
|
path = os.path.join(self.cp_save_path, name) |
|
|
path = os.path.join(self.cp_save_path, name) |
|
|
self.logger.info('reload best model from %s', path) |
|
|
self.logger.info('reload best model from %s', path) |
|
|
model_load = torch.load(path) |
|
|
|
|
|
|
|
|
model_load = torch.load(path, map_location='cpu') |
|
|
if not isinstance(model_load, dict): |
|
|
if not isinstance(model_load, dict): |
|
|
model_load = model_load.state_dict() |
|
|
model_load = model_load.state_dict() |
|
|
self.model.load_state_dict(model_load) |
|
|
|
|
|
|
|
|
self.model.module.load_state_dict(model_load) |
|
|
|
|
|
|
|
|
def _do_validation(self): |
|
|
def _do_validation(self): |
|
|
self.callback_manager.on_valid_begin() |
|
|
self.callback_manager.on_valid_begin() |
|
|
|
|
|
# do evaluate on all nodes |
|
|
eval_res = self.test_manager.on_valid_begin() |
|
|
eval_res = self.test_manager.on_valid_begin() |
|
|
eval_res = list(filter(lambda x: x is not None, eval_res)) |
|
|
eval_res = list(filter(lambda x: x is not None, eval_res)) |
|
|
if len(eval_res): |
|
|
if len(eval_res): |
|
|
eval_res, is_better = list(zip(*eval_res)) |
|
|
eval_res, is_better = list(zip(*eval_res)) |
|
|
else: |
|
|
else: |
|
|
eval_res, is_better = None, None |
|
|
eval_res, is_better = None, None |
|
|
|
|
|
# save better model on master node |
|
|
|
|
|
if self.is_master and is_better is not None and self.cp_save_path: |
|
|
|
|
|
for i, better_flag in enumerate(is_better): |
|
|
|
|
|
if better_flag: |
|
|
|
|
|
# TODO to support multiple datasets to evaluate |
|
|
|
|
|
self.save_check_point('best') |
|
|
|
|
|
break |
|
|
self.callback_manager.on_valid_end( |
|
|
self.callback_manager.on_valid_end( |
|
|
eval_res, self.metric_key, self.optimizer, is_better) |
|
|
eval_res, self.metric_key, self.optimizer, is_better) |
|
|
|
|
|
|
|
|
# save better model |
|
|
|
|
|
for i, better_flag in enumerate(is_better): |
|
|
|
|
|
if better_flag: |
|
|
|
|
|
# TODO to support multiple datasets to evaluate |
|
|
|
|
|
name = 'best_{}'.format(self.metric_key) |
|
|
|
|
|
self.save_check_point(name) |
|
|
|
|
|
break |
|
|
|
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
def close(self): |
|
|
def close(self): |
|
|