
1. Modify prepare_callbacks in callback; 2. let Dist_Trainer support find_unused_parameters, effective only on PyTorch 1.1 and above

tags/v0.4.10
yh 6 years ago
commit af55db2019
2 changed files with 10 additions and 4 deletions
  1. +2 -2 fastNLP/core/callback.py
  2. +8 -2 fastNLP/core/dist_trainer.py
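For context, find_unused_parameters=True matters when some parameters receive no gradient on a given forward pass (e.g. a conditionally skipped branch); without it, DDP on PyTorch >= 1.1 errors or hangs during backward while waiting for those gradients. A minimal sketch of such a model (BranchyModel is a hypothetical illustration, not part of fastNLP):

    import torch.nn as nn

    class BranchyModel(nn.Module):
        # Hypothetical model: head_b is only used for some inputs, so its
        # parameters may receive no gradient in a given training step.
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 8)
            self.head_a = nn.Linear(8, 2)
            self.head_b = nn.Linear(8, 2)  # skipped when use_b is False

        def forward(self, x, use_b=False):
            h = self.backbone(x)
            return self.head_b(h) if use_b else self.head_a(h)

    # Wrapping this model in DistributedDataParallel without
    # find_unused_parameters=True can make the backward pass fail on
    # PyTorch >= 1.1, because the reducer waits for head_b's gradients.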

+2 -2 fastNLP/core/callback.py

@@ -324,14 +324,14 @@ class CallbackManager(Callback):
         self._env = env
         self.callbacks = []
         if callbacks:
-            self.prepare_callbacks(callbacks)
+            self.callbacks = self.prepare_callbacks(callbacks)
 
     def prepare_callbacks(self, callbacks):
         if not callbacks:
             return []
         if isinstance(callbacks, list):
             if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                self.callbacks.extend(callbacks)
+                pass
             else:
                 obj = [not isinstance(cb, Callback) for cb in callbacks][0]
                 raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
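A caveat in the error branch above: the list comprehension produces booleans, so obj is True or False and the TypeError reports <class 'bool'> rather than the offending callback's type. A sketch of how that branch could pick out the real culprit (assuming the remainder of prepare_callbacks, truncated by this hunk, returns the validated list):

    def prepare_callbacks(self, callbacks):
        if not callbacks:
            return []
        if isinstance(callbacks, list):
            if all(isinstance(cb, Callback) for cb in callbacks):
                return callbacks
            # Report the first element that is not a Callback, not a bool
            bad = next(cb for cb in callbacks if not isinstance(cb, Callback))
            raise TypeError(f"Expect sub-classes of Callback. Got {type(bad)}")
        # (handling of non-list inputs continues below this hunk)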


+8 -2 fastNLP/core/dist_trainer.py

@@ -18,6 +18,7 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
+from pkg_resources import parse_version
 
 __all__ = [
     'get_local_rank',
@@ -103,8 +104,13 @@ class DistTrainer():
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
 
         # init DataParallel
-        self.model = DDP(model, device_ids=[self.local_rank],
-                         output_device=self.local_rank)
+        if parse_version(torch.__version__)>=parse_version('1.1'):
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank, find_unused_parameters=True)
+        else:
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank)
 
         self.optimizer = optimizer
         self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
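Distilled from the hunk above, the version gate on its own looks roughly like the sketch below (wrap_ddp is a hypothetical helper, and it assumes torch.distributed has already been initialized). As an aside, pkg_resources.parse_version is deprecated in current setuptools, with packaging.version.parse as the usual replacement, though that postdates this commit.

    import torch
    from pkg_resources import parse_version
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_ddp(model, local_rank):
        # find_unused_parameters was added to DDP in PyTorch 1.1, so older
        # versions must fall back to the plain constructor.
        kwargs = dict(device_ids=[local_rank], output_device=local_rank)
        if parse_version(torch.__version__) >= parse_version('1.1'):
            kwargs['find_unused_parameters'] = True
        return DDP(model, **kwargs)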

