Browse Source

paddle replace_batch_sampler和check_dataloader 跟进

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
64da46b613
4 changed files with 26 additions and 13 deletions
  1. +2
    -1
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +4
    -0
      fastNLP/core/drivers/jittor_driver/utils.py
  3. +2
    -1
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  4. +18
    -11
      fastNLP/core/drivers/paddle_driver/utils.py

+ 2
- 1
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.dataloaders import OverfitDataLoader
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection, nullcontext
@@ -69,7 +70,7 @@ class JittorDriver(Driver):
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, (Dataset, JittorDataLoader)):
if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)):
raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "


+ 4
- 0
fastNLP/core/drivers/jittor_driver/utils.py View File

@@ -14,6 +14,7 @@ from fastNLP.envs import (
FASTNLP_BACKEND_LAUNCH,
FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.log import logger

if _NEED_IMPORT_JITTOR:
@@ -63,6 +64,9 @@ def replace_batch_sampler(dataloader, batch_sampler):
"or report this bug to us.")

def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler):
batch_sampler = getattr(dataloader, "sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
raise RuntimeError("It should not be running here, please report a bug to us.")
if isinstance(dataloader, JittorDataLoader):
init_params = dict(inspect.signature(dataloader.__init__).parameters)
reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()}


+ 2
- 1
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -19,6 +19,7 @@ from fastNLP.envs import (
rank_zero_call,
)
from fastNLP.core.log import logger
from fastNLP.core.dataloaders import OverfitDataLoader
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleSampler,
@@ -93,7 +94,7 @@ class PaddleDriver(Driver):
self.grad_scaler.update()

def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
if dataloader.batch_size is None and dataloader.batch_sampler is None:
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"


+ 18
- 11
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -15,6 +15,7 @@ from fastNLP.envs import (
FASTNLP_BACKEND_LAUNCH,
FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.utils import auto_param_call, paddle_to
from fastNLP.core.log import logger

@@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False):
"NOTE: your device does NOT support faster training with fp16, "
"please switch to FP32 which is likely to be faster"
)
return auto_cast, GradScaler
return auto_cast, GradScaler

def find_free_ports(num):
"""
@@ -189,10 +190,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
non_default_params.add("dataset")

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
"persistent_workers": dataloader._persistent_workers,
})
if isinstance(dataloader, DataLoader):
reconstruct_args.update({
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
"persistent_workers": dataloader._persistent_workers,
})

# POSITIONAL_OR_KEYWORD 代表一般的参数
# 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数
@@ -210,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
required_args = sorted(required_args)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its "
f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be "
f"`{dataloader_self_name}`'s attribute."
)

# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs;
@@ -224,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
missing_kwargs = sorted(missing_kwargs)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found."
)
# 如果没有kwargs,则保证一下只传入需要的参数
if not isinstance(dataloader, DataLoader):
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

return type(dataloader)(**reconstruct_args)

@@ -235,6 +239,9 @@ def replace_sampler(dataloader, new_sampler):
"""
使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中
"""
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
raise RuntimeError("It should not be running here, please report a bug to us.")
new_batch_sampler = deepcopy(dataloader.batch_sampler)
new_batch_sampler.sampler = new_sampler
return replace_batch_sampler(dataloader, new_batch_sampler)


Loading…
Cancel
Save