@@ -956,20 +956,42 @@ class EchoCallback(Callback): | |||
class TesterCallback(Callback): | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=None):\ | |||
#TODO add compare & save best | |||
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None): | |||
super(TesterCallback, self).__init__() | |||
self.tester = Tester(data, model, | |||
metrics=metrics, batch_size=batch_size, | |||
num_workers=num_workers, verbose=0) | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
self.increase_better = True | |||
if metric_key is not None: | |||
self.increase_better = False if metric_key[0] == "-" else True | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
else: | |||
self.metric_key = None | |||
self.score = None | |||
def on_validation(self): | |||
cur_socre = self.tester.test() | |||
cur_score = self.tester.test() | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( | |||
self.epoch, self.n_epochs, self.step, self.n_steps, | |||
self.tester._format_eval_results(cur_socre)) | |||
self.tester._format_eval_results(cur_score)) | |||
self.logger.info(eval_str) | |||
is_better = self.compare_better(cur_score) | |||
if is_better: | |||
self.score = cur_score | |||
return cur_score, is_better | |||
def compare_better(self, a): | |||
if self.score is None: | |||
return True | |||
k = self.metric_key | |||
is_increase = self.score[k] <= a[k] # if equal, prefer more recent results | |||
if self.increase_better: | |||
return is_increase | |||
else: | |||
return not is_increase | |||
def on_train_end(self): | |||
self.logger.info('Evaluate on training ends.') | |||
@@ -9,6 +9,7 @@ from tqdm import tqdm | |||
import logging | |||
import time | |||
from datetime import datetime, timedelta | |||
from functools import partial | |||
from .batch import DataSetIter, BatchIter | |||
from .callback import DistCallbackManager, CallbackException, TesterCallback | |||
@@ -46,8 +47,9 @@ class DistTrainer(): | |||
callbacks_all=None, callbacks_master=None, | |||
batch_size_per_gpu=8, n_epochs=1, | |||
num_data_workers=1, drop_last=False, | |||
dev_data=None, metrics=None, | |||
dev_data=None, metrics=None, metric_key=None, | |||
update_every=1, print_every=10, validate_every=-1, | |||
log_path=None, | |||
save_every=-1, save_path=None, device='auto', | |||
fp16='', backend=None, init_method=None): | |||
@@ -88,6 +90,7 @@ class DistTrainer(): | |||
self.callback_manager = DistCallbackManager( | |||
env={"trainer": self}, callbacks_all=callbacks_all, | |||
callbacks_master=callbacks_master) | |||
self.metric_key = metric_key | |||
model.to(self.device) | |||
optimizer = self._get_optimizer(optimizer) | |||
@@ -133,7 +136,8 @@ class DistTrainer(): | |||
self.cp_save_path = None | |||
# use INFO in the master, WARN for others | |||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', | |||
logging.basicConfig(filename=log_path, | |||
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', | |||
datefmt='%m/%d/%Y %H:%M:%S', | |||
level=logging.INFO if self.is_master else logging.WARN) | |||
self.logger = logging.getLogger(__name__) | |||
@@ -278,7 +282,15 @@ class DistTrainer(): | |||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)): | |||
self.callback_manager.on_validation() | |||
self.callback_manager.on_valid_begin() | |||
eval_res = self.callback_manager.on_validation() | |||
eval_res = list(filter(lambda x: x is not None, eval_res)) | |||
if len(eval_res): | |||
eval_res, is_better = list(zip(*eval_res)) | |||
else: | |||
eval_res, is_better = None, None | |||
self.callback_manager.on_valid_end( | |||
eval_res, self.metric_key, self.optimizer, is_better) | |||
dist.barrier() | |||
if self.cp_save_path and \ | |||
@@ -144,12 +144,12 @@ class TestDistTrainer(unittest.TestCase): | |||
cmd = ['python', '-m', 'torch.distributed.launch', | |||
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | |||
print(' '.join(cmd)) | |||
subprocess.check_call(cmd, timeout=60.0) | |||
subprocess.check_call(cmd) | |||
def test_normal_run(self): | |||
self.run_dist(1) | |||
def test_fp16(self): | |||
def no_test_fp16(self): | |||
self.run_dist(2) | |||
def test_callback(self): | |||