|
|
@@ -32,6 +32,7 @@ from .utils import _build_args
 from .utils import _build_fp16_env
 from .utils import _get_func_signature
 from .utils import _move_dict_value_to_device
+from .sampler import Sampler
 
 __all__ = [
     'get_local_rank',
|
|
@@ -68,7 +69,7 @@ class DistTrainer():
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16=False, use_tqdm=True, **kwargs):
+                 fp16=False, use_tqdm=True, sampler=None, **kwargs):
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
|
|
@@ -101,6 +102,8 @@ class DistTrainer():
         :param str device: the device to use; may be gpu, cpu or auto
         :param bool fp16: whether to use half-precision (fp16) training.
         :param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal instead.
+        :param Sampler sampler: the sampler to use; if not specified, DistributedSampler is used by default. This argument is mainly
+            for the case where the Dataset on each rank has been modified explicitly, so every rank holds the same number of samples but not the same samples.
         :param kwargs: optional configuration parameters
             bool test_use_tqdm: whether to enable tqdm when evaluating on the dev set
             Sampler test_sampler: the sampler used during evaluation
|
|
@@ -108,6 +111,9 @@ class DistTrainer():
             bool test_use_fp16: use fp16 during testing
             bool set_grad_to_none: set grad to None instead of 0 when zeroing gradients
             GradScaler gradscaler: a custom gradient scaler
+            bool pin_memory: whether to put produced tensors into pinned memory, which may speed up data transfer; usually helps when there are many tensors or the tensors are large.
+            bool find_unused_parameters: forwarded when wrapping the model in DistributedDataParallel; unless the model really has
+                parameters that are unused in forward, this should not be needed.
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
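
For orientation, a minimal, hypothetical usage sketch of the constructor after this change; model, train_ds and dev_ds are placeholders, and import paths may differ between fastNLP versions. It only illustrates how the new sampler argument and the kwargs documented above are passed through::

    from fastNLP import CrossEntropyLoss, AccuracyMetric
    from fastNLP.core.dist_trainer import DistTrainer

    # model / train_ds / dev_ds are assumed to exist; typically launched via
    #   python -m torch.distributed.launch --nproc_per_node=2 train.py
    trainer = DistTrainer(
        train_data=train_ds, model=model, loss=CrossEntropyLoss(),
        dev_data=dev_ds, metrics=AccuracyMetric(),
        batch_size_per_gpu=16, n_epochs=3, fp16=True,
        sampler=None,                  # None -> DistributedSampler is used
        pin_memory=True,               # new optional kwarg documented above
        find_unused_parameters=False,  # forwarded to DistributedDataParallel
    )
    trainer.train()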
|
|
@@ -126,6 +132,8 @@ class DistTrainer():
         self.rank = dist.get_rank()  # unique id for each process
 
         self.train_data = train_data
+        if kwargs.get('batch_size', None):
+            batch_size_per_gpu = int(kwargs.get('batch_size'))
         self.batch_size_per_gpu = int(batch_size_per_gpu)
         self.n_epochs = int(n_epochs)
         self.num_data_workers = int(num_workers)
|
|
@@ -163,7 +171,8 @@ class DistTrainer():
         # init DataParallel
         if parse_version(torch.__version__)>=parse_version('1.1'):
             self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank, find_unused_parameters=True)
+                                 output_device=self.local_rank,
+                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
         else:
             self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                  output_device=self.local_rank)
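
The hunk above stops hard-coding find_unused_parameters=True and instead forwards the new kwarg, defaulting to False. As a point of reference, this is the plain-PyTorch wrapping it amounts to (model and local_rank are placeholders from the launch environment, not part of this diff)::

    from torch.nn.parallel import DistributedDataParallel as DDP

    ddp_model = DDP(
        model.to(f'cuda:{local_rank}'),
        device_ids=[local_rank],
        output_device=local_rank,
        # True adds an extra traversal of the autograd graph every iteration;
        # keep it False unless some parameters truly receive no gradient.
        find_unused_parameters=False,
    )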
|
|
@@ -172,7 +181,17 @@ class DistTrainer():
         optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
-            self.sampler = DistributedSampler(self.train_data)
+            if sampler is None:
+                self.sampler = DistributedSampler(self.train_data)
+            else:
+                # sampler check
+                if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
+                    raise ValueError(
+                        f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
+                elif hasattr(sampler, 'set_batch_size'):
+                    sampler.set_batch_size(batch_size_per_gpu)
+                self.sampler = sampler
+        self.pin_memory = kwargs.get('pin_memory', True)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()
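
The new branch accepts either a fastNLP Sampler or a torch.utils.data.Sampler and, when the object supports it, passes the per-GPU batch size through set_batch_size. A hypothetical per-rank sampler that would pass this type check (not part of the diff) could look like::

    import torch

    class StridedSampler(torch.utils.data.Sampler):
        """Yield every `stride`-th index starting at `offset`; one instance per
        rank when each rank is meant to see a different slice of the data."""

        def __init__(self, data_source, offset=0, stride=1):
            self.data_source = data_source
            self.offset = offset
            self.stride = stride

        def __iter__(self):
            return iter(range(self.offset, len(self.data_source), self.stride))

        def __len__(self):
            n = len(self.data_source) - self.offset
            return max(0, (n + self.stride - 1) // self.stride)

Each rank would then construct its own instance, e.g. with offset equal to its rank and stride equal to the world size, and pass it via the new sampler argument.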
|
|
@@ -191,7 +210,6 @@ class DistTrainer():
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
 
-        # Setup logging
         # synchronize start_time
         sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
|
|
@@ -233,7 +251,8 @@ class DistTrainer():
     def _get_data_iter(self, dataset):
         if isinstance(dataset, DataSet):
             return DataSetIter(dataset=dataset, batch_size=self.batch_size_per_gpu, sampler=self.sampler,
-                               num_workers=self.num_data_workers, drop_last=self.drop_last)
+                               num_workers=self.num_data_workers, drop_last=self.drop_last,
+                               pin_memory=self.pin_memory)
         elif isinstance(dataset, BatchIter):
             return dataset
         else:
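
pin_memory=True makes the iterator return batches in page-locked host memory; combined with the non_blocking copy added in the next hunk, the host-to-device transfer can overlap with computation. The generic PyTorch pattern the two changes add up to (dataset and device are placeholders)::

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=8, pin_memory=True)  # pinned host buffers
    for batch in loader:
        # non_blocking only pays off when the source tensor is pinned
        batch = batch.to(device, non_blocking=True)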
|
|
@@ -347,7 +366,7 @@ class DistTrainer():
             for batch_x, batch_y in data_iterator:
                 self.step += 1
                 self.ddp_model.train()
-                _move_dict_value_to_device(batch_x, batch_y, device=self.device)
+                _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
|
|
@@ -361,10 +380,9 @@ class DistTrainer():
 
                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
-                self.grad_scaler.scale(loss).backward()
+                self._grad_backward(loss)
                 self.callback_manager.on_backward_end()
-                if self.step % self.update_every == 0:
-                    self._update()
+                self._update()
                 self.callback_manager.on_step_end()
 
                 if self.step % self.print_every == 0:
|
|
@@ -390,7 +408,7 @@ class DistTrainer():
         self.pbar = None
         # ============ tqdm end ============== #
 
-    def _clear_grad_opt(self, optimizer):
+    def _clear_grad(self, optimizer):
         if self.set_grad_to_none:
             for group in optimizer.param_groups:
                 for p in group['params']:
|
|
@@ -399,13 +417,24 @@ class DistTrainer():
             else:
                 optimizer.zero_grad()
 
+    def _grad_backward(self, loss):
+        r"""Compute gradient with link rules.
+
+        :param loss: a scalar where back-prop starts
+
+        For PyTorch, just do "loss.backward()"
+        """
+        if (self.step-1) % self.update_every == 0:
+            self._clear_grad(self.optimizer)
+        self.grad_scaler.scale(loss).backward()
+
     def _update(self):
         r"""Perform weight update on a model.
 
         """
-        self.grad_scaler.step(self.optimizer)
-        self.grad_scaler.update()
-        self._clear_grad_opt(self.optimizer)
+        if self.step % self.update_every == 0:
+            self.grad_scaler.step(self.optimizer)
+            self.grad_scaler.update()
 
     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
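
Taken together, the new _grad_backward and the reworked _update move gradient clearing, the scaled backward pass, and the optimizer step behind the update_every counter, i.e. gradient accumulation under automatic mixed precision. A sketch of the standard torch.cuda.amp pattern they follow (model, optimizer, loader and update_every are placeholders; fastNLP presumably obtains its scaler via the _build_fp16_env helper imported at the top of this diff)::

    import torch

    scaler = torch.cuda.amp.GradScaler()
    for step, (x, y) in enumerate(loader, start=1):
        if (step - 1) % update_every == 0:       # start of an accumulation window
            optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            loss = model(x, y)                   # placeholder forward returning a loss
        scaler.scale(loss).backward()            # gradients accumulate across steps
        if step % update_every == 0:             # end of the window
            scaler.step(optimizer)
            scaler.update()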
|
|
|