diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index e486df8e..63ac6ec4 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -71,6 +71,9 @@ class JittorDriver(Driver): def check_dataloader_legality(self, dataloader): if not isinstance(dataloader, (Dataset, JittorDataLoader)): raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") + if len(dataloader) == 0: + logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " + "may cause some unexpected exceptions.", once=True) @staticmethod def _check_optimizer_legality(optimizers): diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index f809f9ec..a3fde3af 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -99,6 +99,9 @@ class PaddleDriver(Driver): if dataloader.batch_size is None and dataloader.batch_sampler is None: raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" "is not None") + if len(dataloader) == 0: + logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " + "may cause some unexpected exceptions.", once=True) @staticmethod def _check_optimizer_legality(optimizers): diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 9449782b..841e6614 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -94,6 +94,9 @@ class TorchDriver(Driver): def check_dataloader_legality(self, dataloader): if not isinstance(dataloader, DataLoader): raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") + if len(dataloader) == 0: + logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " + "may cause some unexpected exceptions.", once=True) @staticmethod def _check_optimizer_legality(optimizers):