|
|
@@ -165,11 +165,11 @@ class DistTrainer: |
|
|
|
self.grad_scaler = grad_scaler |
|
|
|
|
|
|
|
self.set_grad_to_none = kwargs.get('set_grad_to_none', False) |
|
|
|
|
|
|
|
# init DataParallel |
|
|
|
if isinstance(model, DDP): |
|
|
|
self.ddp_model = model |
|
|
|
else: |
|
|
|
model.to(self.device) |
|
|
|
if parse_version(torch.__version__)>=parse_version('1.1'): |
|
|
|
self.ddp_model = DDP(model, device_ids=[self.local_rank], |
|
|
|
output_device=self.local_rank, |
|
|
@@ -182,7 +182,6 @@ class DistTrainer: |
|
|
|
self._forward_func = self.model.forward |
|
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
optimizer = self._get_optimizer(optimizer) |
|
|
|
self.optimizer = optimizer |
|
|
|
if isinstance(self.train_data, DataSet): |
|
|
|