|
|
@@ -18,6 +18,7 @@ from .optimizer import Optimizer |
|
|
|
from .utils import _build_args |
|
|
|
from .utils import _move_dict_value_to_device |
|
|
|
from .utils import _get_func_signature |
|
|
|
from pkg_resources import parse_version |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
'get_local_rank', |
|
|
@@ -103,8 +104,13 @@ class DistTrainer(): |
|
|
|
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) |
|
|
|
|
|
|
|
# init DataParallel |
|
|
|
self.model = DDP(model, device_ids=[self.local_rank], |
|
|
|
output_device=self.local_rank) |
|
|
|
if parse_version(torch.__version__)>=parse_version('1.1'): |
|
|
|
self.model = DDP(model, device_ids=[self.local_rank], |
|
|
|
output_device=self.local_rank, find_unused_parameters=True) |
|
|
|
else: |
|
|
|
self.model = DDP(model, device_ids=[self.local_rank], |
|
|
|
output_device=self.local_rank) |
|
|
|
|
|
|
|
self.optimizer = optimizer |
|
|
|
self.sampler = DistributedSampler(self.train_data) |
|
|
|
self.data_iterator = self._get_data_iter(self.train_data) |
|
|
|