|
|
@@ -55,7 +55,7 @@ def get_local_rank(): |
|
|
|
raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py"')
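# The command the error message refers to is run once per node, e.g. with 4 GPUs:
#
#   python -m torch.distributed.launch --nproc_per_node=4 train_script.py
#
# (newer PyTorch releases also provide `torchrun` as an equivalent front end).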
|
|
|
|
|
|
|
|
|
|
|
class DistTrainer(): |
|
|
|
class DistTrainer: |
|
|
|
r""" |
|
|
|
Distributed Trainer with support for distributed training and mixed-precision training. For the underlying implementation details, please refer to the official PyTorch documentation.
|
|
|
|
|
|
@@ -110,7 +110,7 @@ class DistTrainer(): |
|
|
|
int dev_batch_size: batch size to use during evaluation
|
|
|
bool test_use_fp16: whether to use fp16 during testing
|
|
|
bool set_grad_to_none: when calling zero_grad, set gradients to None instead of 0
|
|
|
GradScaler gradscaler: a custom gradient scaler
|
|
|
GradScaler grad_scaler: a custom gradient scaler
|
|
|
bool pin_memory: whether to put the produced tensors into pinned memory, which may speed up data transfer. The gain usually shows up when there are many tensors or the tensors are large.
|
|
|
bool find_unused_parameters: forwarded when the model is wrapped in DistributedDataParallel; unless the model really has
|
|
|
parameters that are unused in forward, this argument should not be needed.
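# A minimal sketch of what the `find_unused_parameters` kwarg feeds into, assuming the
# process group has already been initialized (e.g. via torch.distributed.init_process_group);
# only the public torch.nn.parallel.DistributedDataParallel API is used here, not
# DistTrainer internals.
import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_model(model: torch.nn.Module, local_rank: int) -> DistributedDataParallel:
    # find_unused_parameters=True lets DDP tolerate parameters that receive no gradient in a
    # given forward pass, at the cost of an extra traversal of the autograd graph per step.
    return DistributedDataParallel(
        model.to(local_rank),
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,
    )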
|
|
@@ -132,6 +132,7 @@ class DistTrainer(): |
|
|
|
self.rank = dist.get_rank() # unique id for each process |
|
|
|
|
|
|
|
self.train_data = train_data |
|
|
|
self.kwargs = kwargs |
|
|
|
if kwargs.get('batch_size', None): |
|
|
|
batch_size_per_gpu = int(kwargs.get('batch_size')) |
|
|
|
self.batch_size_per_gpu = int(batch_size_per_gpu) |
|
|
@@ -158,15 +159,15 @@ class DistTrainer(): |
|
|
|
# init fp16, must be done before DataParallel init
|
|
|
autocast, GradScaler = _build_fp16_env(dummy=not self.fp16) |
|
|
|
self.auto_cast = autocast |
|
|
|
user_grad_scaler = getattr(kwargs, 'gradscaler', None) |
|
|
|
user_grad_scaler = kwargs.get('grad_scaler', None) |
|
|
|
if user_grad_scaler is not None: |
|
|
|
assert self.fp16, "must set fp16=True to enable gradscaler" |
|
|
|
assert self.fp16, "must set fp16=True to enable grad_scaler" |
|
|
|
grad_scaler = user_grad_scaler |
|
|
|
else: |
|
|
|
grad_scaler = GradScaler() |
|
|
|
self.grad_scaler = grad_scaler |
|
|
|
|
|
|
|
self.set_grad_to_none = getattr(kwargs, 'set_grad_to_none', True) |
|
|
|
self.set_grad_to_none = kwargs.get('set_grad_to_none', False) |
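# The wiring above (auto_cast, grad_scaler, set_grad_to_none) follows the standard
# torch.cuda.amp recipe. A minimal stand-alone sketch of that recipe, independent of
# DistTrainer; model, optimizer, batch and targets are placeholders:
import torch
from torch.cuda.amp import autocast, GradScaler

def amp_step(model, optimizer, batch, targets, scaler: GradScaler, set_grad_to_none=True):
    with autocast():                                    # run forward in mixed precision
        loss = torch.nn.functional.cross_entropy(model(batch), targets)
    scaler.scale(loss).backward()                       # scale loss to avoid fp16 underflow
    scaler.step(optimizer)                              # unscale grads, skip step on inf/nan
    scaler.update()                                     # adjust the scale factor
    optimizer.zero_grad(set_to_none=set_grad_to_none)   # mirrors the set_grad_to_none kwarg
    return loss.detach()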
|
|
|
|
|
|
|
# init DataParallel |
|
|
|
if parse_version(torch.__version__)>=parse_version('1.1'): |
|
|
@@ -191,7 +192,8 @@ class DistTrainer(): |
|
|
|
elif hasattr(sampler, 'set_batch_size'): |
|
|
|
sampler.set_batch_size(batch_size_per_gpu) |
|
|
|
self.sampler = sampler |
|
|
|
self.pin_memory = kwargs.get('pin_memory', True) |
|
|
|
# concerning issue from https://github.com/pytorch/pytorch/issues/57273 |
|
|
|
self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True) |
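# pin_memory is ultimately just a DataLoader flag; a stand-alone sketch of the same
# version-gated default (the 1.9 exclusion works around
# https://github.com/pytorch/pytorch/issues/57273). `dataset` is a placeholder for any
# torch.utils.data.Dataset, and importing parse_version from pkg_resources is an assumption.
import torch
from torch.utils.data import DataLoader
from pkg_resources import parse_version

def make_loader(dataset, batch_size: int) -> DataLoader:
    pin = parse_version(torch.__version__) != parse_version('1.9')
    return DataLoader(dataset, batch_size=batch_size, pin_memory=pin)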
|
|
|
self.data_iterator = self._get_data_iter(self.train_data) |
|
|
|
self.batch_size = self.world_size * self.batch_size_per_gpu |
|
|
|
self.n_steps = self._get_n_steps() |
|
|
@@ -199,7 +201,6 @@ class DistTrainer(): |
|
|
|
self.dev_data = dev_data |
|
|
|
self.metrics = metrics |
|
|
|
self.test_use_tqdm = True |
|
|
|
self.kwargs = kwargs |
|
|
|
self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm) |
|
|
|
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu) |
|
|
|
|
|
|
@@ -229,22 +230,6 @@ class DistTrainer(): |
|
|
|
self.logger.info("Num of processes: {}".format(self.world_size)) |
|
|
|
self.logger.info("Use device: {}".format(device)) |
|
|
|
|
|
|
|
def _maybe_no_sync(self): |
|
|
|
""" |
|
|
|
Whenever *samples* contains more than one mini-batch, we |
|
|
|
want to accumulate gradients locally and only call |
|
|
|
all-reduce in the last backwards pass. |
|
|
|
""" |
|
|
|
i = self.step % self.update_every |
|
|
|
if ( |
|
|
|
self.world_size > 1 |
|
|
|
and hasattr(self.ddp_model, "no_sync") |
|
|
|
and i != 0 |
|
|
|
): |
|
|
|
return self.ddp_model.no_sync() |
|
|
|
else: |
|
|
|
return contextlib.ExitStack() # dummy contextmanager |
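# The method above (and the inlined no_sync/ExitStack choice in the training loop further
# down) implements plain DDP gradient accumulation. A self-contained sketch of the same
# pattern; ddp_model, optimizer and loader are placeholders, and update_every is the number
# of micro-batches per optimizer step.
import contextlib
import torch

def train_with_accumulation(ddp_model, optimizer, loader, update_every: int):
    for step, (x, y) in enumerate(loader, start=1):
        sync_now = (step % update_every == 0)
        # Suppress the gradient all-reduce on non-update micro-batches; gradients just
        # accumulate locally and are reduced together on the syncing pass.
        ctx = contextlib.ExitStack() if sync_now else ddp_model.no_sync()
        with ctx:
            loss = torch.nn.functional.mse_loss(ddp_model(x), y) / update_every
            loss.backward()
        if sync_now:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)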
|
|
|
|
|
|
|
def _get_n_steps(self): |
|
|
|
return len(self.data_iterator) * self.n_epochs |
|
|
|
|
|
|
@@ -365,37 +350,42 @@ class DistTrainer(): |
|
|
|
self.callback_manager.on_epoch_begin() |
|
|
|
for batch_x, batch_y in data_iterator: |
|
|
|
self.step += 1 |
|
|
|
self.ddp_model.train() |
|
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory) |
|
|
|
indices = data_iterator.get_batch_indices() |
|
|
|
# negative sampling; replace unknown; re-weight batch_y |
|
|
|
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) |
|
|
|
with self.auto_cast(): |
|
|
|
prediction = self._data_forward(self.ddp_model, batch_x) |
|
|
|
# edit prediction |
|
|
|
self.callback_manager.on_loss_begin(batch_y, prediction) |
|
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
|
|
|
|
|
avg_loss += loss.detach() |
|
|
|
|
|
|
|
# Is loss NaN or inf? requires_grad = False |
|
|
|
self.callback_manager.on_backward_begin(loss) |
|
|
|
self._grad_backward(loss) |
|
|
|
self.callback_manager.on_backward_end() |
|
|
|
self._update() |
|
|
|
self.callback_manager.on_step_end() |
|
|
|
|
|
|
|
if self.step % self.print_every == 0: |
|
|
|
avg_loss = float(avg_loss) / self.print_every |
|
|
|
print_output = "loss:{:<6.5f}".format(avg_loss) |
|
|
|
pbar.update(self.print_every) |
|
|
|
pbar.set_postfix_str(print_output) |
|
|
|
avg_loss = 0 |
|
|
|
|
|
|
|
self.callback_manager.on_batch_end() |
|
|
|
|
|
|
|
if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks): |
|
|
|
self._do_validation() |
|
|
|
if self.step%self.update_every!=0: |
|
|
|
no_sync = self.ddp_model.no_sync |
|
|
|
else: |
|
|
|
no_sync = contextlib.ExitStack |
|
|
|
with no_sync(): |
|
|
|
self.ddp_model.train() |
|
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory) |
|
|
|
indices = data_iterator.get_batch_indices() |
|
|
|
# negative sampling; replace unknown; re-weight batch_y |
|
|
|
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) |
|
|
|
with self.auto_cast(): |
|
|
|
prediction = self._data_forward(self.ddp_model, batch_x) |
|
|
|
# edit prediction |
|
|
|
self.callback_manager.on_loss_begin(batch_y, prediction) |
|
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
|
|
|
|
|
avg_loss += loss.detach() |
|
|
|
|
|
|
|
# Is loss NaN or inf? requires_grad = False |
|
|
|
self.callback_manager.on_backward_begin(loss) |
|
|
|
self._grad_backward(loss) |
|
|
|
self.callback_manager.on_backward_end() |
|
|
|
self._update() |
|
|
|
self.callback_manager.on_step_end() |
|
|
|
|
|
|
|
if self.step % self.print_every == 0: |
|
|
|
avg_loss = float(avg_loss) / self.print_every |
|
|
|
print_output = "loss:{:<6.5f}".format(avg_loss) |
|
|
|
pbar.update(self.print_every) |
|
|
|
pbar.set_postfix_str(print_output) |
|
|
|
avg_loss = 0 |
|
|
|
|
|
|
|
self.callback_manager.on_batch_end() |
|
|
|
|
|
|
|
if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks): |
|
|
|
self._do_validation() |
|
|
|
|
|
|
|
# ================= mini-batch end ==================== # |
|
|
|
if self.validate_every < 0 and len(self.test_manager.callbacks): |
|
|
|