@@ -73,7 +73,7 @@ class DistTrainer:
         r"""

         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
-        :param nn.modules model: the model to be trained
+        :param nn.modules, DDP model: the model to be trained
         :param optimizer: a `torch.optim.Optimizer` optimizer. If None, the Trainer uses the default optimizer Adam(model.parameters(), lr=4e-3)
         :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. If None, :class:`~fastNLP.LossInForward` is used by default
         :param list callbacks_all: callback functions used to adjust the training process, applied in all training processes.
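Note: the widened `model` annotation above means a caller may now pass either a bare `nn.Module` or a model already wrapped in `DistributedDataParallel`. A minimal sketch of the pre-wrapped calling pattern follows; the launcher environment variables, the `nn.Linear` stand-in model, and `train_data` (a fastNLP DataSet prepared elsewhere) are illustrative assumptions, not part of this diff.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from fastNLP import DistTrainer

# Assumes a torch.distributed launcher that sets LOCAL_RANK.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = nn.Linear(10, 2).to(local_rank)  # stand-in for a real model
ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)

# With the widened parameter, either `model` or `ddp_model` may be passed;
# `train_data` is a hypothetical fastNLP DataSet prepared elsewhere.
trainer = DistTrainer(train_data=train_data, model=ddp_model)
```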
@@ -146,7 +146,6 @@ class DistTrainer:
         self.losser = _prepare_losser(loss)
         self.fp16 = fp16
         self.local_rank = get_local_rank()
-        self._forward_func = model.forward
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
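Note: `get_local_rank()` is fastNLP's own helper and its implementation is not shown in this diff. For orientation, here is a sketch of the conventional lookup performed under `torch.distributed.launch`-style launchers; this is an assumption about the pattern, not the library's actual code.

```python
import argparse
import os

def get_local_rank_sketch():
    """Illustrative only: the conventional local-rank lookup."""
    # Older launchers pass --local_rank on the command line;
    # newer ones export the LOCAL_RANK environment variable.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None)
    args, _ = parser.parse_known_args()
    if args.local_rank is not None:
        return args.local_rank
    return int(os.environ.get('LOCAL_RANK', 0))
```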
@@ -154,8 +153,6 @@ class DistTrainer:
         self.metric_key = metric_key
         self.use_tqdm = use_tqdm

-        model.to(self.device)
-
         # init fp16, must before DataParallel init
         autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
         self.auto_cast = autocast
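Note: `_build_fp16_env(dummy=...)` hands back an `autocast` context and a `GradScaler` class, and the comment above records that this must happen before DDP wrapping. When `fp16` is off it must return no-op stand-ins so the training loop keeps a single code path. A sketch of that pattern; fastNLP's actual implementation may differ.

```python
from contextlib import nullcontext

def build_fp16_env_sketch(dummy=False):
    """Illustrative: (autocast, GradScaler) that degrade to no-ops when dummy."""
    if dummy:
        class _NoOpScaler:
            def scale(self, loss):
                return loss          # pass the loss through unscaled
            def step(self, optimizer):
                optimizer.step()     # plain optimizer step, no unscaling
            def update(self):
                pass                 # nothing to recalibrate
        return nullcontext, _NoOpScaler
    # Real AMP environment (available in torch >= 1.6).
    from torch.cuda.amp import autocast, GradScaler
    return autocast, GradScaler
```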
@@ -170,15 +167,22 @@ class DistTrainer:
         self.set_grad_to_none = kwargs.get('set_grad_to_none', False)

         # init DataParallel
-        if parse_version(torch.__version__)>=parse_version('1.1'):
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank,
-                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
+        if isinstance(model, DDP):
+            self.ddp_model = model
         else:
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank)
+            if parse_version(torch.__version__)>=parse_version('1.1'):
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank,
+                                     find_unused_parameters=kwargs.get('find_unused_parameters', False))
+            else:
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank)
+        self.model = self.ddp_model.module
+
+        self._forward_func = self.model.forward
+        self.model.to(self.device)

         optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
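Note: after either branch, `self.ddp_model` is the DDP wrapper and `self.ddp_model.module` is the underlying user model, so `_forward_func` and `.to(self.device)` now bind to the unwrapped module regardless of what the caller passed. The invariant, as a standalone illustrative helper:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model: nn.Module) -> nn.Module:
    """Return the user's module whether or not it is DDP-wrapped."""
    return model.module if isinstance(model, DDP) else model
```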
@@ -207,7 +211,7 @@ class DistTrainer:
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
-                dev_data, model, metrics,
+                dev_data, self.model, metrics,
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
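Note: passing `self.model` (the unwrapped module) instead of the raw `model` argument matters once callers can hand in a DDP wrapper: per the comment above, evaluation runs only in the master process, and forwarding through the DDP wrapper on a single rank could block on collectives the other ranks never join. A sketch of the run-eval-only-on-master guard; `eval_fn` is a hypothetical stand-in for fastNLP's Tester.

```python
import torch.distributed as dist

def run_eval_on_master(eval_fn, model, dev_data):
    """Illustrative: evaluate only in the rank-0 (master) process."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        return eval_fn(model, dev_data)  # eval_fn is a hypothetical stand-in
    return None  # non-master ranks skip evaluation entirely
```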
@@ -343,6 +347,7 @@ class DistTrainer:
             avg_loss = 0
             data_iterator = self.data_iterator
             self.ddp_model.zero_grad()
+            self.batch_per_epoch = self.data_iterator.num_batches
             for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
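Note: the added `self.batch_per_epoch` publishes the per-epoch batch count (read from the iterator's `num_batches`) before the epoch loop starts, e.g. for progress reporting. For a plain PyTorch `DataLoader` the equivalent figure is simply `len(loader)`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(dataset, batch_size=4)

# len(loader) is the number of batches per epoch: ceil(10 / 4) == 3.
assert len(loader) == 3
```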