diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 09ff860b..85903315 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -324,14 +324,14 @@ class CallbackManager(Callback):
         self._env = env
         self.callbacks = []
         if callbacks:
-            self.prepare_callbacks(callbacks)
+            self.callbacks = self.prepare_callbacks(callbacks)
 
     def prepare_callbacks(self, callbacks):
         if not callbacks:
             return []
         if isinstance(callbacks, list):
             if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                self.callbacks.extend(callbacks)
+                pass
             else:
                 obj = [not isinstance(cb, Callback) for cb in callbacks][0]
                 raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 260b93b0..57c5f56b 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -18,6 +18,7 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
+from pkg_resources import parse_version
 
 __all__ = [
     'get_local_rank',
@@ -103,8 +104,13 @@ class DistTrainer():
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
 
         # init DataParallel
-        self.model = DDP(model, device_ids=[self.local_rank],
-                         output_device=self.local_rank)
+        if parse_version(torch.__version__)>=parse_version('1.1'):
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank, find_unused_parameters=True)
+        else:
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank)
+
         self.optimizer = optimizer
         self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
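
For context, a minimal standalone sketch of the version gate the `dist_trainer.py` hunk introduces: `find_unused_parameters` was only added to `DistributedDataParallel` in PyTorch 1.1, so passing it on older versions raises a `TypeError`, and the kwarg must be gated on the installed version. The helper name `wrap_ddp` and the `model`/`local_rank` arguments below are illustrative, not part of the patch; the sketch assumes `torch.distributed.init_process_group()` has already been called.

```python
import torch
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model, local_rank):
    # Hypothetical helper, not part of the patch. PyTorch >= 1.1 accepts
    # find_unused_parameters, which lets DDP skip the gradient all-reduce
    # for parameters that received no gradient in a forward pass; older
    # versions reject the unknown keyword argument.
    if parse_version(torch.__version__) >= parse_version('1.1'):
        return DDP(model, device_ids=[local_rank],
                   output_device=local_rank, find_unused_parameters=True)
    return DDP(model, device_ids=[local_rank], output_device=local_rank)
```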