
[fix] distributed trainer

tags/v0.4.10
yunfan committed 5 years ago
commit cacf40366c
3 changed files with 54 additions and 16 deletions
1. fastNLP/core/callback.py (+28, -8)
2. fastNLP/core/dist_trainer.py (+24, -6)
3. test/core/test_dist_trainer.py (+2, -2)

fastNLP/core/callback.py (+28, -8)

@@ -324,15 +324,13 @@ class CallbackManager(Callback):
self._env = env
self.callbacks = []
if callbacks:
self.prepare_callbacks(callbacks)
self.callbacks += self.prepare_callbacks(callbacks)

def prepare_callbacks(self, callbacks):
if not callbacks:
return []
if isinstance(callbacks, list):
if all([isinstance(cb, Callback) for cb in callbacks]) is True:
self.callbacks.extend(callbacks)
else:
if not all([isinstance(cb, Callback) for cb in callbacks]):
obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
else:
@@ -956,20 +954,42 @@ class EchoCallback(Callback):


class TesterCallback(Callback):
def __init__(self, data, model, metrics, batch_size=16, num_workers=None):\
#TODO add compare & save best
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
super(TesterCallback, self).__init__()
self.tester = Tester(data, model,
metrics=metrics, batch_size=batch_size,
num_workers=num_workers, verbose=0)
# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
self.increase_better = True
if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
else:
self.metric_key = None
self.score = None

def on_validation(self):
cur_socre = self.tester.test()
cur_score = self.tester.test()
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
self.epoch, self.n_epochs, self.step, self.n_steps,
self.tester._format_eval_results(cur_socre))
self.tester._format_eval_results(cur_score))
self.logger.info(eval_str)
is_better = self.compare_better(cur_score)
if is_better:
self.score = cur_score
return cur_score, is_better

def compare_better(self, a):
if self.score is None:
return True
k = self.metric_key
is_increase = self.score[k] <= a[k] # if equal, prefer more recent results
if self.increase_better:
return is_increase
else:
return not is_increase

def on_train_end(self):
self.logger.info('Evaluate on training ends.')
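For readers skimming the diff: the metric_key convention that the new TesterCallback.__init__ parses above (a leading "-" means the metric should decrease, a leading "+" or no prefix means it should increase) can be restated as a small self-contained sketch. The helper name and the example keys are illustrative, not part of the commit.

# Sketch of the metric_key parsing done in TesterCallback.__init__ above.
def parse_metric_key(metric_key):
    if metric_key is None:
        return None, True          # no key given; default to "higher is better"
    increase_better = metric_key[0] != "-"
    key = metric_key[1:] if metric_key[0] in ("+", "-") else metric_key
    return key, increase_better

assert parse_metric_key("acc") == ("acc", True)      # maximize accuracy
assert parse_metric_key("-loss") == ("loss", False)  # minimize loss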


fastNLP/core/dist_trainer.py (+24, -6)

@@ -9,6 +9,7 @@ from tqdm import tqdm
import logging
import time
from datetime import datetime, timedelta
from functools import partial

from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
@@ -45,10 +46,12 @@ class DistTrainer():
callbacks_all=None, callbacks_master=None,
batch_size_per_gpu=8, n_epochs=1,
num_data_workers=1, drop_last=False,
dev_data=None, metrics=None,
dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1,
log_path=None,
save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None):
fp16='', backend=None, init_method=None,
find_unused_parameters=True):

assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
if device == 'auto':
@@ -87,6 +90,7 @@ class DistTrainer():
self.callback_manager = DistCallbackManager(
env={"trainer": self}, callbacks_all=callbacks_all,
callbacks_master=callbacks_master)
self.metric_key = metric_key

model.to(self.device)
optimizer = self._get_optimizer(optimizer)
@@ -103,8 +107,13 @@ class DistTrainer():
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)

# init DataParallel
self.model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank)
if find_unused_parameters:
# to support old version
self.model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank, find_unused_parameters=find_unused_parameters)
else:
self.model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank)
self.optimizer = optimizer
self.sampler = DistributedSampler(self.train_data)
self.data_iterator = self._get_data_iter(self.train_data)
@@ -127,7 +136,8 @@ class DistTrainer():
self.cp_save_path = None

# use INFO in the master, WARN for others
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
logging.basicConfig(filename=log_path,
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO if self.is_master else logging.WARN)
self.logger = logging.getLogger(__name__)
@@ -272,7 +282,15 @@ class DistTrainer():

if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)):
self.callback_manager.on_validation()
self.callback_manager.on_valid_begin()
eval_res = self.callback_manager.on_validation()
eval_res = list(filter(lambda x: x is not None, eval_res))
if len(eval_res):
eval_res, is_better = list(zip(*eval_res))
else:
eval_res, is_better = None, None
self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better)
dist.barrier()

if self.cp_save_path and \
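The rewritten validation step above now calls on_valid_begin, collects one (score, is_better) pair from each TesterCallback via on_validation, and forwards the unpacked results to on_valid_end. A minimal sketch of that unpacking, with made-up result values, looks like this:

# Illustration of the eval_res handling in the hunk above; the result dict is invented.
raw = [({"AccuracyMetric": {"acc": 0.91}}, True), None]   # callbacks without a result return None
raw = list(filter(lambda x: x is not None, raw))
if len(raw):
    eval_res, is_better = list(zip(*raw))   # -> (scores, ...), (flags, ...)
else:
    eval_res, is_better = None, None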


test/core/test_dist_trainer.py (+2, -2)

@@ -144,12 +144,12 @@ class TestDistTrainer(unittest.TestCase):
cmd = ['python', '-m', 'torch.distributed.launch',
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
print(' '.join(cmd))
subprocess.check_call(cmd, timeout=60.0)
subprocess.check_call(cmd)

def test_normal_run(self):
self.run_dist(1)

def test_fp16(self):
def no_test_fp16(self):
self.run_dist(2)

def test_callback(self):
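For context, run_dist above launches this test file once per GPU through torch.distributed.launch and, after this commit, no longer bounds the call with a 60-second timeout. Expanded for ngpu=2 and run_id=1 (values and path purely illustrative), the call amounts to:

# What run_dist(2) assembles, assuming path resolves to this test file and run_id == 1.
import subprocess
cmd = ['python', '-m', 'torch.distributed.launch',
       '--nproc_per_node', '2', 'test/core/test_dist_trainer.py', '--test', '1']
subprocess.check_call(cmd)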

