Browse Source

driver 在 dataloader 为空时进行警告

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
a16a0e94f5
3 changed files with 9 additions and 0 deletions
  1. +3
    -0
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +3
    -0
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  3. +3
    -0
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 3
- 0
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -71,6 +71,9 @@ class JittorDriver(Driver):
def check_dataloader_legality(self, dataloader): def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, (Dataset, JittorDataLoader)): if not isinstance(dataloader, (Dataset, JittorDataLoader)):
raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
"may cause some unexpected exceptions.", once=True)


@staticmethod @staticmethod
def _check_optimizer_legality(optimizers): def _check_optimizer_legality(optimizers):


+ 3
- 0
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -99,6 +99,9 @@ class PaddleDriver(Driver):
if dataloader.batch_size is None and dataloader.batch_sampler is None: if dataloader.batch_size is None and dataloader.batch_sampler is None:
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"
"is not None") "is not None")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
"may cause some unexpected exceptions.", once=True)


@staticmethod @staticmethod
def _check_optimizer_legality(optimizers): def _check_optimizer_legality(optimizers):


+ 3
- 0
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -94,6 +94,9 @@ class TorchDriver(Driver):
def check_dataloader_legality(self, dataloader): def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader): if not isinstance(dataloader, DataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
"may cause some unexpected exceptions.", once=True)


@staticmethod @staticmethod
def _check_optimizer_legality(optimizers): def _check_optimizer_legality(optimizers):


Loading…
Cancel
Save