From b49c694c6203f7af63ac40f66ab27f93d0264973 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sat, 11 Apr 2020 16:20:42 +0800
Subject: [PATCH] [update] bugfix in dist_trainer

---
 fastNLP/__init__.py          |   6 +-
 fastNLP/core/__init__.py     |  11 ++-
 fastNLP/core/_logger.py      |   8 ++
 fastNLP/core/callback.py     |   8 +-
 fastNLP/core/dist_trainer.py | 139 +++++++++++++++++++----------------
 5 files changed, 101 insertions(+), 71 deletions(-)

diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index a9d7efe7..5f18561a 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -24,6 +24,9 @@ __all__ = [
 
     "Trainer",
     "Tester",
+
+    "DistTrainer",
+    "get_local_rank",
 
     "Callback",
     "GradientClipCallback",
@@ -75,7 +78,8 @@ __all__ = [
 
     "cache_results",
 
-    'logger'
+    'logger',
+    "init_logger_dist",
 ]
 
 __version__ = '0.5.0'
diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index 9f61ae0c..f4e42ab3 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -33,12 +33,16 @@ __all__ = [
 
     "Tester",
     "Trainer",
-    
+
+    "DistTrainer",
+    "get_local_rank",
+
     "cache_results",
     "seq_len_to_mask",
     "get_seq_len",
     "logger",
-    
+    "init_logger_dist",
+
     "Callback",
     "GradientClipCallback",
     "EarlyStopCallback",
@@ -81,7 +85,7 @@ __all__ = [
     "Sampler",
 ]
 
-from ._logger import logger
+from ._logger import logger, init_logger_dist
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
     LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \
@@ -100,3 +104,4 @@ from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
 from .collate_fn import ConcatCollateFn
+from .dist_trainer import DistTrainer, get_local_rank
diff --git a/fastNLP/core/_logger.py b/fastNLP/core/_logger.py
index 8bfea464..043a97c2 100644
--- a/fastNLP/core/_logger.py
+++ b/fastNLP/core/_logger.py
@@ -18,6 +18,7 @@ logger.set_stdout('tqdm', level='WARN')
 
 __all__ = [
     'logger',
+    'init_logger_dist'
 ]
 
 import logging
@@ -25,6 +26,7 @@
 import logging.config
 import os
 import sys
 import warnings
+from torch import distributed as dist
 
 ROOT_NAME = 'fastNLP'
@@ -169,3 +171,9 @@ def _get_logger(name=None, level='INFO'):
 
 
 logger = _init_logger(path=None, level='INFO')
+
+
+def init_logger_dist():
+    global logger
+    rank = dist.get_rank()
+    logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 04d26d9e..84c2656a 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -114,6 +114,9 @@ class Callback(object):
         self._trainer = None  # re-assigned inside Trainer
         self._disabled = False
 
+    def __repr__(self):
+        return self.__class__.__name__
+
     @property
     def trainer(self):
         r"""
@@ -1157,9 +1160,6 @@ class EchoCallback(Callback):
 class _TesterCallback(Callback):
     def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
         super(_TesterCallback, self).__init__()
-        if hasattr(model, 'module'):
-            # for data parallel model
-            model = model.module
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
                              num_workers=num_workers, verbose=0)
@@ -1183,7 +1183,7 @@
 
     @staticmethod
     def _get_score(metric_dict, key):
-        for metric in metric_dict.items():
+        for metric in metric_dict.values():
             if key in metric:
                 return metric[key]
         return None
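The _get_score change above is easy to miss: the evaluation results passed to it map metric names to dicts of scores, so iterating over .items() yields (name, scores) tuples, the "key in metric" test never matches, and the method silently returns None. A minimal standalone sketch of the fixed lookup (the metric layout shown is hypothetical, not fastNLP code):

    # Hypothetical evaluation result: metric name -> {score name: value}.
    metric_dict = {"AccuracyMetric": {"acc": 0.91}}

    def get_score(metric_dict, key):
        # .items() would yield (name, scores) tuples and `key in metric`
        # would always be False; the score dicts are the values.
        for metric in metric_dict.values():
            if key in metric:
                return metric[key]
        return None

    assert get_score(metric_dict, "acc") == 0.91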
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index f5c0f229..56d123f4 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -9,16 +9,18 @@ import os
 import time
 from datetime import datetime
+import contextlib
 
 import torch
 import torch.cuda
 import torch.distributed as dist
 import torch.optim
+from torch.serialization import default_restore_location
 from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
 
-from ._logger import logger
+from ._logger import logger, init_logger_dist
 from .batch import DataSetIter, BatchIter
 from .callback import DistCallbackManager, CallbackException
 from .callback import _TesterCallback
@@ -69,8 +71,8 @@ class DistTrainer():
                  num_workers=1, drop_last=False, dev_data=None, metrics=None,
                  metric_key=None, update_every=1, print_every=10, validate_every=-1,
-                 save_every=-1, save_path=None, device='auto',
-                 fp16='', backend=None, init_method=None, use_tqdm=True):
+                 save_path=None, device='auto',
+                 fp16='', use_tqdm=True):
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
@@ -98,20 +100,15 @@
             may run out of memory; setting batch_size=32 with update_every=4 achieves the same effect. Has no effect when optimizer is None.
         :param int print_every: update the loss shown by tqdm every this many backward passes; if use_tqdm=False, print the loss every this many backward passes instead.
         :param int validate_every: evaluate on the dev set every this many steps; if -1, evaluate at the end of every epoch. Only effective when dev_data is given.
-        :param int save_every: save the model every this many steps; if -1, save at the end of every epoch. Only effective when save_path is given.
         :param str,None save_path: directory to save the model to; created automatically if it does not exist. If None, the model is not saved. If dev_data is None, the model
             of the last iteration is saved. Both the parameters and the model structure are saved. Even with DataParallel, only the model itself is saved here.
         :param str device: the device to run on, one of 'auto', 'cuda' or 'cpu'.
         :param str fp16: optimization level for mixed-precision training, one of O1, O2 or O3; if an empty string, fp16 is not used.
-        :param backend: the distributed backend; see the pytorch documentation for details.
-        :param init_method: the distributed init method; see the pytorch documentation for details.
         :param bool use_tqdm: whether to show training progress with tqdm; if False, the loss is printed to the terminal instead.
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        if backend is None:
-            backend = 'nccl' if device == 'cuda' else 'gloo'
 
         # init distributed
         if device == 'cuda':
@@ -120,11 +117,9 @@
         else:
             self.device = torch.device(device)
 
-        dist.init_process_group(backend=backend, init_method=init_method)
         self.world_size = dist.get_world_size()
         self.rank = dist.get_rank()  # unique id for each process
 
-        self.model = model
         self.train_data = train_data
         self.batch_size_per_gpu = int(batch_size_per_gpu)
         self.n_epochs = int(n_epochs)
@@ -133,12 +128,9 @@
         self.update_every = int(update_every)
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
-        self.save_every = int(save_every)
         self.save_path = save_path
         self.losser = _prepare_losser(loss)
         self.fp16 = fp16
-        self.init_method = init_method
-        self.backend = backend
         self.local_rank = get_local_rank()
         self._forward_func = model.forward
         self.callback_manager = DistCallbackManager(
@@ -160,11 +152,12 @@
 
         # init DataParallel
         if parse_version(torch.__version__)>=parse_version('1.1'):
-            self.model = DDP(model, device_ids=[self.local_rank],
+            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                              output_device=self.local_rank, find_unused_parameters=True)
         else:
-            self.model = DDP(model, device_ids=[self.local_rank],
+            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                              output_device=self.local_rank)
+        self.model = self.ddp_model.module
 
         self.optimizer = optimizer
         self.sampler = DistributedSampler(self.train_data)
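The hunk above keeps two handles on the network: self.ddp_model is the DistributedDataParallel wrapper used for forward/backward, while self.model now points at the unwrapped module, which is what gets saved and handed to evaluation callbacks. A minimal sketch of the same pattern outside fastNLP (the wrap_model helper is illustrative only and assumes the default process group is already initialized):

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_model(model: torch.nn.Module, local_rank: int):
        # Train through the returned wrapper; save/evaluate the raw module.
        model = model.to(torch.device("cuda", local_rank))
        # find_unused_parameters needs torch >= 1.1, matching the version check above.
        ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                        find_unused_parameters=True)
        return ddp_model, ddp_model.module  # .module is the original model object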
@@ -177,18 +170,17 @@
             cb = _TesterCallback(
                 dev_data, model, metrics, batch_size=batch_size_per_gpu,
                 num_workers=num_workers)
-            self.test_manager.add_callback([cb], master=False)
+            self.test_manager.add_callback([cb], master=True)
 
         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
         if self.save_path:
-            self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
+            self.cp_save_path = self.save_path
         else:
             self.cp_save_path = None
 
-        # use INFO in the master, WARN for others
-        logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
+        init_logger_dist()
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
@@ -198,6 +190,22 @@
         self.logger.info("Training with fp16: {}, optimization level: {}".format(
             len(self.fp16) > 0, self.fp16 if self.fp16 else None))
 
+    def _maybe_no_sync(self):
+        """
+        When accumulating gradients over several mini-batches, keep the
+        gradients local and only run the all-reduce in the last backward
+        pass before an optimizer update.
+        """
+        i = self.step % self.update_every
+        if (
+                self.world_size > 1
+                and hasattr(self.ddp_model, "no_sync")
+                and i != 0
+        ):
+            return self.ddp_model.no_sync()
+        else:
+            return contextlib.ExitStack()  # dummy contextmanager
+
     def _get_n_steps(self):
         batch_size = self.world_size * self.batch_size_per_gpu
         return (len(self.train_data) // batch_size + int(
@@ -219,9 +227,9 @@
         if isinstance(optimizer, torch.optim.Optimizer):
             return optimizer
         elif isinstance(optimizer, Optimizer):
-            return optimizer.construct_from_pytorch(self.model.parameters())
+            return optimizer.construct_from_pytorch(self.ddp_model.parameters())
         elif optimizer is None:
-            return torch.optim.Adam(self.model.parameters(), lr=4e-3)
+            return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
 
@@ -252,8 +260,10 @@
         self.logger.info("###### Training epochs started ######")
         self.logger.info('Total epochs: %d'% self.n_epochs)
         self.logger.info('Total steps: %d'% self.n_steps)
-        self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
-        self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
+        self.logger.info('Num instances per GPU: %d'% self.batch_size_per_gpu)
+        self.logger.info('Num of steps per update: %d' % self.update_every)
+        self.logger.info('Total batch_size: %d'%
+                         (self.batch_size_per_gpu * dist.get_world_size() * self.update_every))
         self.logger.info('Total num of samples: %d'% len(self.train_data))
         self.logger.info("Num of callbacks for all workers: {}".format(
             len(self.callback_manager.callbacks_all)))
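The _maybe_no_sync helper added above, together with the loss / update_every scaling in the training-loop hunk further down, implements gradient accumulation: parameters are updated every update_every steps, so the effective batch size is batch_size_per_gpu * world_size * update_every, which is exactly what the corrected 'Total batch_size' log line now reports. Note that the patch adds the helper but leaves its call site commented out in the loop. A self-contained sketch of the idea, assuming a DDP-wrapped ddp_model, an optimizer and a compute_loss callable (all placeholders, not fastNLP API):

    import contextlib

    def run_epoch(ddp_model, optimizer, batches, compute_loss, update_every=4):
        # All-reduce gradients only on the step that calls optimizer.step();
        # intermediate accumulation steps stay local via DDP.no_sync().
        optimizer.zero_grad()
        for step, batch in enumerate(batches, start=1):
            is_update_step = (step % update_every == 0)
            sync_ctx = contextlib.ExitStack() if is_update_step else ddp_model.no_sync()
            with sync_ctx:
                # Scale the loss so the accumulated gradient matches one large batch.
                loss = compute_loss(ddp_model, batch) / update_every
                loss.backward()
            if is_update_step:
                optimizer.step()
                optimizer.zero_grad()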
@@ -290,7 +300,7 @@
             self.logger.info("###### Train finished ######")
             self.logger.info('Total train time: {} seconds.'.format(results['seconds']))
             if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
-                self.load_check_point('best')
+                self.load_check_point(self._best_save_name())
         finally:
             pass
         dist.barrier()
@@ -309,29 +319,32 @@
         pbar = self.pbar
         avg_loss = 0
         data_iterator = self.data_iterator
-        self.model.zero_grad()
+        self.ddp_model.zero_grad()
         for epoch in range(1, self.n_epochs + 1):
             self.epoch = epoch
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
             # early stopping
             self.callback_manager.on_epoch_begin()
             for batch_x, batch_y in data_iterator:
-                self.model.train()
                 self.step += 1
+                self.ddp_model.train()
                 _move_dict_value_to_device(batch_x, batch_y, device=self.device)
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
-                prediction = self._data_forward(self.model, batch_x)
+                prediction = self._data_forward(self.ddp_model, batch_x)
 
                 # edit prediction
                 self.callback_manager.on_loss_begin(batch_y, prediction)
                 loss = self._compute_loss(prediction, batch_y)
+                if self.update_every > 1:
+                    loss = loss / self.update_every
                 avg_loss += loss.item()
 
                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
 
+                # with self._maybe_no_sync():
                 if self.fp16:
                     with amp.scale_loss(loss, self.optimizer) as scale_loss:
                         scale_loss.backward()
@@ -355,17 +368,10 @@
                 if (self.validate_every > 0 and self.step % self.validate_every == 0):
                     self._do_validation()
 
-                if self.cp_save_path and \
-                        self.save_every > 0 and \
-                        self.step % self.save_every == 0:
-                    self.save_check_point()
-
                 # ================= mini-batch end ==================== #
 
             if self.validate_every < 0:
                 self._do_validation()
-            if self.save_every < 0 and self.cp_save_path:
-                self.save_check_point()
             # lr decay; early stopping
             self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
@@ -379,7 +385,7 @@
         """
         if self.step % self.update_every == 0:
             self.optimizer.step()
-            self.model.zero_grad()
+            self.ddp_model.zero_grad()
 
     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
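The checkpoint hunk that follows has only the master process write the unwrapped ddp_model.module, and has every process reload the state dict onto CPU before copying it into its local model, so a checkpoint written from cuda:0 loads on any rank or device. A standalone sketch of that save/load pattern (function names and the barrier placement are illustrative, not the fastNLP API):

    import torch
    import torch.distributed as dist

    def save_on_master(ddp_model, path, is_master):
        if is_master:
            # Saving the raw module keeps 'module.' prefixes out of the state dict keys.
            torch.save(ddp_model.module.state_dict(), path)
        dist.barrier()  # other ranks wait until the checkpoint file exists

    def load_everywhere(model, path):
        # map_location="cpu" plays the role of the default_restore_location
        # lambda used in the load_check_point hunk below.
        state = torch.load(path, map_location="cpu")
        model.load_state_dict(state)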
"_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) def _do_validation(self): - self.callback_manager.on_valid_begin() - # do evaluate on all nodes - eval_res = self.test_manager.on_valid_begin() - eval_res = list(filter(lambda x: x is not None, eval_res)) - if len(eval_res): - eval_res, is_better = list(zip(*eval_res)) - else: - eval_res, is_better = None, None - # save better model on master node - if self.is_master and is_better is not None and self.cp_save_path: - for i, better_flag in enumerate(is_better): - if better_flag: - # TODO to support multiple datasets to evaluate - self.save_check_point('best') - break - self.callback_manager.on_valid_end( - eval_res, self.metric_key, self.optimizer, is_better) - dist.barrier() + with self.ddp_model.no_sync(): + # 因为模型参数不更新,可以关闭同步 + self.callback_manager.on_valid_begin() + eval_res = self.test_manager.on_valid_begin() + eval_res = list(filter(lambda x: x is not None, eval_res)) + if len(eval_res): + eval_res, is_better = list(zip(*eval_res)) + eval_res = eval_res[0] + is_better = is_better[0] + else: + eval_res, is_better = None, None + # logger.info('{}, {}'.format(eval_res, is_better)) + # save better model on master node + if is_better is not None and self.cp_save_path: + if is_better: + self.save_check_point(self._best_save_name(), only_params=False) + dist.barrier() + self.callback_manager.on_valid_end( + eval_res, self.metric_key, self.optimizer, is_better) + self.ddp_model.train() def close(self): r"""关闭Trainer,销毁进程"""