Browse Source

paddle replace_batch_sampler和check_dataloader 跟进

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
64da46b613
4 changed files with 26 additions and 13 deletions
  1. +2
    -1
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +4
    -0
      fastNLP/core/drivers/jittor_driver/utils.py
  3. +2
    -1
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  4. +18
    -11
      fastNLP/core/drivers/paddle_driver/utils.py

+ 2
- 1
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 from fastNLP.core.drivers.driver import Driver
 from fastNLP.core.dataloaders import JittorDataLoader
+from fastNLP.core.dataloaders import OverfitDataLoader
 from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
 from fastNLP.core.log import logger
 from fastNLP.core.utils import apply_to_collection, nullcontext
@@ -69,7 +70,7 @@ class JittorDriver(Driver):
         self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, (Dataset, JittorDataLoader)):
+        if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)):
             raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`")
         if len(dataloader) == 0:
             logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "


+ 4
- 0
fastNLP/core/drivers/jittor_driver/utils.py View File

@@ -14,6 +14,7 @@ from fastNLP.envs import (
     FASTNLP_BACKEND_LAUNCH,
     FASTNLP_GLOBAL_SEED,
 )
+from fastNLP.core.samplers import ReproducibleBatchSampler
 from fastNLP.core.log import logger

 if _NEED_IMPORT_JITTOR:
@@ -63,6 +64,9 @@ def replace_batch_sampler(dataloader, batch_sampler):
                         "or report this bug to us.")

 def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler):
+    batch_sampler = getattr(dataloader, "sampler")
+    if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
+        raise RuntimeError("It should not be running here, please report a bug to us.")
     if isinstance(dataloader, JittorDataLoader):
         init_params = dict(inspect.signature(dataloader.__init__).parameters)
         reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()}


+ 2
- 1
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -19,6 +19,7 @@ from fastNLP.envs import (
     rank_zero_call,
 )
 from fastNLP.core.log import logger
+from fastNLP.core.dataloaders import OverfitDataLoader
 from fastNLP.core.samplers import (
     ReproducibleBatchSampler,
     ReproducibleSampler,
@@ -93,7 +94,7 @@ class PaddleDriver(Driver):
         self.grad_scaler.update()

     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
             raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
         if dataloader.batch_size is None and dataloader.batch_sampler is None:
             raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"


+ 18
- 11
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -15,6 +15,7 @@ from fastNLP.envs import (
     FASTNLP_BACKEND_LAUNCH,
     FASTNLP_GLOBAL_SEED,
 )
+from fastNLP.core.samplers import ReproducibleBatchSampler
 from fastNLP.core.utils import auto_param_call, paddle_to
 from fastNLP.core.log import logger


@@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False):
                 "NOTE: your device does NOT support faster training with fp16, "
                 "please switch to FP32 which is likely to be faster"
             )
-        return auto_cast, GradScaler
+    return auto_cast, GradScaler

 def find_free_ports(num):
     """
""" """
@@ -189,10 +190,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
         non_default_params.add("dataset")

     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({
-        "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
-        "persistent_workers": dataloader._persistent_workers,
-    })
+    if isinstance(dataloader, DataLoader):
+        reconstruct_args.update({
+            "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
+            "persistent_workers": dataloader._persistent_workers,
+        })

     # POSITIONAL_OR_KEYWORD 代表一般的参数
     # 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数
@@ -210,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
+            f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
+            f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its "
+            f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be "
+            f"`{dataloader_self_name}`'s attribute."
         )


 # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs;
@@ -224,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
         missing_kwargs = sorted(missing_kwargs)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-            f"The missing arguments are {missing_kwargs}. "
+            f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found."
         )
+    # 如果没有kwargs,则保证一下只传入需要的参数
+    if not isinstance(dataloader, DataLoader):
+        reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

     return type(dataloader)(**reconstruct_args)


@@ -235,6 +239,9 @@ def replace_sampler(dataloader, new_sampler):
     """
     使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中
     """
+    batch_sampler = getattr(dataloader, "batch_sampler")
+    if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
+        raise RuntimeError("It should not be running here, please report a bug to us.")
     new_batch_sampler = deepcopy(dataloader.batch_sampler)
     new_batch_sampler.sampler = new_sampler
     return replace_batch_sampler(dataloader, new_batch_sampler)


Loading…
Cancel
Save