Browse Source

move model to device in DistTrainer

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
e1ed6f16e4
1 changed files with 1 additions and 2 deletions
  1. +1
    -2
      fastNLP/core/dist_trainer.py

+ 1
- 2
fastNLP/core/dist_trainer.py View File

@@ -165,11 +165,11 @@ class DistTrainer:
self.grad_scaler = grad_scaler

self.set_grad_to_none = kwargs.get('set_grad_to_none', False)

# init DataParallel
if isinstance(model, DDP):
self.ddp_model = model
else:
model.to(self.device)
if parse_version(torch.__version__)>=parse_version('1.1'):
self.ddp_model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank,
@@ -182,7 +182,6 @@ class DistTrainer:
self._forward_func = self.model.forward
self.model.to(self.device)


optimizer = self._get_optimizer(optimizer)
self.optimizer = optimizer
if isinstance(self.train_data, DataSet):


Loading…
Cancel
Save