From 64da46b613547a5768e6b56ffe83ab11ac1caf60 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 17 Jun 2022 23:23:33 +0800 Subject: [PATCH] =?UTF-8?q?paddle=20replace=5Fbatch=5Fsampler=E5=92=8Cchec?= =?UTF-8?q?k=5Fdataloader=20=E8=B7=9F=E8=BF=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/jittor_driver/jittor_driver.py | 3 +- fastNLP/core/drivers/jittor_driver/utils.py | 4 +++ .../drivers/paddle_driver/paddle_driver.py | 3 +- fastNLP/core/drivers/paddle_driver/utils.py | 29 ++++++++++++------- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index c2e338bb..312f0d83 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from fastNLP.envs.imports import _NEED_IMPORT_JITTOR from fastNLP.core.drivers.driver import Driver from fastNLP.core.dataloaders import JittorDataLoader +from fastNLP.core.dataloaders import OverfitDataLoader from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.log import logger from fastNLP.core.utils import apply_to_collection, nullcontext @@ -69,7 +70,7 @@ class JittorDriver(Driver): self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) def check_dataloader_legality(self, dataloader): - if not isinstance(dataloader, (Dataset, JittorDataLoader)): + if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)): raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") if len(dataloader) == 0: logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " diff --git a/fastNLP/core/drivers/jittor_driver/utils.py b/fastNLP/core/drivers/jittor_driver/utils.py index c75526df..af840a09 100644 --- a/fastNLP/core/drivers/jittor_driver/utils.py +++ b/fastNLP/core/drivers/jittor_driver/utils.py @@ -14,6 +14,7 @@ from fastNLP.envs import ( FASTNLP_BACKEND_LAUNCH, FASTNLP_GLOBAL_SEED, ) +from fastNLP.core.samplers import ReproducibleBatchSampler from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: @@ -63,6 +64,9 @@ def replace_batch_sampler(dataloader, batch_sampler): "or report this bug to us.") def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): + batch_sampler = getattr(dataloader, "sampler") + if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): + raise RuntimeError("It should not be running here, please report a bug to us.") if isinstance(dataloader, JittorDataLoader): init_params = dict(inspect.signature(dataloader.__init__).parameters) reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 6ef0aaae..bfc26350 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -19,6 +19,7 @@ from fastNLP.envs import ( rank_zero_call, ) from fastNLP.core.log import logger +from fastNLP.core.dataloaders import OverfitDataLoader from fastNLP.core.samplers import ( ReproducibleBatchSampler, ReproducibleSampler, @@ -93,7 +94,7 @@ class PaddleDriver(Driver): self.grad_scaler.update() def check_dataloader_legality(self, dataloader): - if not isinstance(dataloader, DataLoader): + if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") if dataloader.batch_size is None and dataloader.batch_sampler is None: raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 1191b60c..be83e5fe 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -15,6 +15,7 @@ from fastNLP.envs import ( FASTNLP_BACKEND_LAUNCH, FASTNLP_GLOBAL_SEED, ) +from fastNLP.core.samplers import ReproducibleBatchSampler from fastNLP.core.utils import auto_param_call, paddle_to from fastNLP.core.log import logger @@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False): "NOTE: your device does NOT support faster training with fp16, " "please switch to FP32 which is likely to be faster" ) - return auto_cast, GradScaler + return auto_cast, GradScaler def find_free_ports(num): """ @@ -189,10 +190,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler non_default_params.add("dataset") reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} - reconstruct_args.update({ - "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, - "persistent_workers": dataloader._persistent_workers, - }) + if isinstance(dataloader, DataLoader): + reconstruct_args.update({ + "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, + "persistent_workers": dataloader._persistent_workers, + }) # POSITIONAL_OR_KEYWORD 代表一般的参数 # 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 @@ -210,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler required_args = sorted(required_args) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " - "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. " + f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " + f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " + f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " + f"`{dataloader_self_name}`'s attribute." ) # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; @@ -224,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler missing_kwargs = sorted(missing_kwargs) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " - "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. " + f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." ) + # 如果没有kwargs,则保证一下只传入需要的参数 + if not isinstance(dataloader, DataLoader): + reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} return type(dataloader)(**reconstruct_args) @@ -235,6 +239,9 @@ def replace_sampler(dataloader, new_sampler): """ 使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中 """ + batch_sampler = getattr(dataloader, "batch_sampler") + if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): + raise RuntimeError("It should not be running here, please report a bug to us.") new_batch_sampler = deepcopy(dataloader.batch_sampler) new_batch_sampler.sampler = new_sampler return replace_batch_sampler(dataloader, new_batch_sampler)