From af55db201990d66b9e43a95e36e96b7a340e43e7 Mon Sep 17 00:00:00 2001
From: yh
Date: Mon, 29 Jul 2019 23:56:53 +0800
Subject: [PATCH] 1. Modify callback prepare_callbacks; 2. let DistTrainer
 support find_unused_parameters, which only takes effect on PyTorch 1.1 and
 above
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callback.py     |  4 ++--
 fastNLP/core/dist_trainer.py | 10 ++++++++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 09ff860b..85903315 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -324,14 +324,14 @@ class CallbackManager(Callback):
         self._env = env
         self.callbacks = []
         if callbacks:
-            self.prepare_callbacks(callbacks)
+            self.callbacks = self.prepare_callbacks(callbacks)
 
     def prepare_callbacks(self, callbacks):
         if not callbacks:
             return []
         if isinstance(callbacks, list):
             if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                self.callbacks.extend(callbacks)
+                pass
             else:
                 obj = [not isinstance(cb, Callback) for cb in callbacks][0]
                 raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 260b93b0..57c5f56b 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -18,6 +18,7 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
+from pkg_resources import parse_version
 
 __all__ = [
     'get_local_rank',
@@ -103,8 +104,13 @@ class DistTrainer():
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
 
         # init DataParallel
-        self.model = DDP(model, device_ids=[self.local_rank],
-                         output_device=self.local_rank)
+        if parse_version(torch.__version__) >= parse_version('1.1'):
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank, find_unused_parameters=True)
+        else:
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank)
+
         self.optimizer = optimizer
         self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
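
For reference, a minimal sketch of the same version gate in isolation. wrap_ddp
is a hypothetical helper, not part of fastNLP, and it assumes
torch.distributed.init_process_group has already been called, as DistTrainer
does before wrapping the model. The find_unused_parameters keyword is only
accepted by DDP from PyTorch 1.1 onward, which is what the gate guards against.

    # Sketch only: version-gated DDP construction, mirroring the patch above.
    # wrap_ddp is a hypothetical helper; it assumes the process group is
    # already initialized and the model has been moved to the local device.
    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from pkg_resources import parse_version

    def wrap_ddp(model, local_rank):
        # Arguments accepted by DDP on all versions that ship it.
        kwargs = dict(device_ids=[local_rank], output_device=local_rank)
        if parse_version(torch.__version__) >= parse_version('1.1'):
            # find_unused_parameters was added in PyTorch 1.1: it lets DDP
            # complete gradient reduction even when some parameters receive
            # no gradient in a given forward pass.
            kwargs['find_unused_parameters'] = True
        return DDP(model, **kwargs)

Using parse_version instead of comparing version strings directly matters
because string comparison is lexicographic: '1.10' < '1.2' as strings, while
parse_version('1.10') > parse_version('1.2') as versions.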