From 972185dc6cb7d5b9f4c51cb8a71e1a0fb8969ec7 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 25 Sep 2021 22:53:05 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0DistTrainer=E4=B8=AD=E7=9A=84?=
 =?UTF-8?q?batch=5Fper=5Fepoch=E5=B1=9E=E6=80=A7?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dist_trainer.py | 27 ++++++++++++++++-----------
 fastNLP/io/data_bundle.py    |  2 +-
 fastNLP/io/pipe/pipe.py      |  6 +++---
 3 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 16af482a..1faf2d1b 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -73,7 +73,7 @@ class DistTrainer:
         r"""
 
         :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。
-        :param nn.modules model: 待训练的模型
+        :param nn.modules, DDP model: 待训练的模型
         :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器
         :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward`
         :param list callbacks_all: 用于在train过程中起调节作用的回调函数,作用于所有训练进程中。
@@ -146,7 +146,6 @@ class DistTrainer:
         self.losser = _prepare_losser(loss)
         self.fp16 = fp16
         self.local_rank = get_local_rank()
-        self._forward_func = model.forward
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
@@ -154,8 +153,6 @@ class DistTrainer:
         self.metric_key = metric_key
         self.use_tqdm = use_tqdm
 
-        model.to(self.device)
-
         # init fp16, must before DataParallel init
         autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
         self.auto_cast = autocast
@@ -170,15 +167,22 @@ class DistTrainer:
         self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
 
         # init DataParallel
-        if parse_version(torch.__version__)>=parse_version('1.1'):
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank,
-                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
+        if isinstance(model, DDP):
+            self.ddp_model = model
         else:
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank)
+            if parse_version(torch.__version__)>=parse_version('1.1'):
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank,
+                                     find_unused_parameters=kwargs.get('find_unused_parameters', False))
+            else:
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank)
         self.model = self.ddp_model.module
+        self._forward_func = self.model.forward
+        self.model.to(self.device)
+
+        optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
@@ -207,7 +211,7 @@ class DistTrainer:
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
-                dev_data, model, metrics,
+                dev_data, self.model, metrics,
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
@@ -343,6 +347,7 @@ class DistTrainer:
             avg_loss = 0
             data_iterator = self.data_iterator
             self.ddp_model.zero_grad()
+            self.batch_per_epoch = self.data_iterator.num_batches
             for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py
index 8fd6dfcd..cfce4de4 100644
--- a/fastNLP/io/data_bundle.py
+++ b/fastNLP/io/data_bundle.py
@@ -32,7 +32,7 @@ class DataBundle:
 
         :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
         :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict。建议不要将相同的DataSet对象重复传入,可能会在
-            使用Pipe处理数据的时候遇到问题。
+            使用Pipe处理数据的时候遇到问题,若多个数据集确需一致,请手动deepcopy后传入。
         """
         self.vocabs = vocabs or {}
         self.datasets = datasets or {}
diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py
index 7416382d..0ff32d83 100644
--- a/fastNLP/io/pipe/pipe.py
+++ b/fastNLP/io/pipe/pipe.py
@@ -27,15 +27,15 @@ class Pipe:
         对输入的DataBundle进行处理,然后返回该DataBundle。
 
         :param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象
-        :return:
+        :return: DataBundle
         """
         raise NotImplementedError
 
-    def process_from_file(self, paths) -> DataBundle:
+    def process_from_file(self, paths: str) -> DataBundle:
         r"""
         传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()`
 
-        :param paths:
+        :param str paths:
         :return: DataBundle
         """
         raise NotImplementedError
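
Usage note: the sketch below (not part of the patch itself) shows how two of the DistTrainer changes above could be exercised, namely handing DistTrainer a model that is already wrapped in DistributedDataParallel and reading the new batch_per_epoch attribute from a callback. MyModel and train_set are hypothetical placeholders, and the exact DistTrainer argument list should be checked against the fastNLP version in use.

# Sketch only: assumes a fastNLP build that contains this patch and a launch via
# torchrun / torch.distributed.launch so that LOCAL_RANK is set in the environment.
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from fastNLP.core.callback import Callback
from fastNLP.core.dist_trainer import DistTrainer

class BatchPerEpochCallback(Callback):
    def on_epoch_begin(self):
        # DistTrainer.train() now sets batch_per_epoch right before the epoch loop,
        # so it is visible to epoch-level callback hooks.
        print("batches per epoch:", self.trainer.batch_per_epoch)

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if not dist.is_initialized():
    # wrapping the model ourselves requires an initialized process group
    dist.init_process_group("nccl")

# With this patch, a model that is already a DDP instance is passed through as-is
# instead of being wrapped a second time inside DistTrainer.
model = MyModel().to(local_rank)            # MyModel: hypothetical user-defined nn.Module
ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)

trainer = DistTrainer(train_data=train_set,  # train_set: hypothetical fastNLP DataSet
                      model=ddp_model,
                      callbacks_all=[BatchPerEpochCallback()])
trainer.train()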