From a16a0e94f534e26c554bfb5ced63dde773de135f Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 1 Jun 2022 16:19:00 +0000 Subject: [PATCH] =?UTF-8?q?driver=20=E5=9C=A8=20dataloader=20=E4=B8=BA?= =?UTF-8?q?=E7=A9=BA=E6=97=B6=E8=BF=9B=E8=A1=8C=E8=AD=A6=E5=91=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/jittor_driver/jittor_driver.py | 3 +++ fastNLP/core/drivers/paddle_driver/paddle_driver.py | 3 +++ fastNLP/core/drivers/torch_driver/torch_driver.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index e486df8e..63ac6ec4 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -71,6 +71,9 @@ class JittorDriver(Driver): def check_dataloader_legality(self, dataloader): if not isinstance(dataloader, (Dataset, JittorDataLoader)): raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") + if len(dataloader) == 0: + logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " + "may cause some unexpected exceptions.", once=True) @staticmethod def _check_optimizer_legality(optimizers): diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index f809f9ec..a3fde3af 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -99,6 +99,9 @@ class PaddleDriver(Driver): if dataloader.batch_size is None and dataloader.batch_sampler is None: raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" "is not None") + if len(dataloader) == 0: + logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " + "may cause some unexpected exceptions.", once=True) @staticmethod def _check_optimizer_legality(optimizers): diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 9449782b..841e6614 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -94,6 +94,9 @@ class TorchDriver(Driver): def check_dataloader_legality(self, dataloader): if not isinstance(dataloader, DataLoader): raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") + if len(dataloader) == 0: + logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " + "may cause some unexpected exceptions.", once=True) @staticmethod def _check_optimizer_legality(optimizers):