|
|
@@ -15,6 +15,7 @@ from fastNLP.envs import ( |
|
|
|
FASTNLP_BACKEND_LAUNCH, |
|
|
|
FASTNLP_GLOBAL_SEED, |
|
|
|
) |
|
|
|
from fastNLP.core.samplers import ReproducibleBatchSampler |
|
|
|
from fastNLP.core.utils import auto_param_call, paddle_to |
|
|
|
from fastNLP.core.log import logger |
|
|
|
|
|
|
@@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False): |
|
|
|
"NOTE: your device does NOT support faster training with fp16, " |
|
|
|
"please switch to FP32 which is likely to be faster" |
|
|
|
) |
|
|
|
return auto_cast, GradScaler |
|
|
|
return auto_cast, GradScaler |
|
|
|
|
|
|
|
def find_free_ports(num): |
|
|
|
""" |
|
|
@@ -189,10 +190,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler |
|
|
|
non_default_params.add("dataset") |
|
|
|
|
|
|
|
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} |
|
|
|
reconstruct_args.update({ |
|
|
|
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, |
|
|
|
"persistent_workers": dataloader._persistent_workers, |
|
|
|
}) |
|
|
|
if isinstance(dataloader, DataLoader): |
|
|
|
reconstruct_args.update({ |
|
|
|
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, |
|
|
|
"persistent_workers": dataloader._persistent_workers, |
|
|
|
}) |
|
|
|
|
|
|
|
# POSITIONAL_OR_KEYWORD 代表一般的参数 |
|
|
|
# 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 |
|
|
@@ -210,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler |
|
|
|
required_args = sorted(required_args) |
|
|
|
dataloader_self_name = dataloader.__class__.__name__ |
|
|
|
raise Exception( |
|
|
|
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " |
|
|
|
"This would fail as some of the `__init__` arguments are not available as instance attributes. " |
|
|
|
f"The missing attributes are {required_args}. " |
|
|
|
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " |
|
|
|
f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " |
|
|
|
f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " |
|
|
|
f"`{dataloader_self_name}`'s attribute." |
|
|
|
) |
|
|
|
|
|
|
|
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; |
|
|
@@ -224,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler |
|
|
|
missing_kwargs = sorted(missing_kwargs) |
|
|
|
dataloader_self_name = dataloader.__class__.__name__ |
|
|
|
raise Exception( |
|
|
|
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " |
|
|
|
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " |
|
|
|
f"The missing arguments are {missing_kwargs}. " |
|
|
|
f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." |
|
|
|
) |
|
|
|
# 如果没有kwargs,则保证一下只传入需要的参数 |
|
|
|
if not isinstance(dataloader, DataLoader): |
|
|
|
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} |
|
|
|
|
|
|
|
return type(dataloader)(**reconstruct_args) |
|
|
|
|
|
|
@@ -235,6 +239,9 @@ def replace_sampler(dataloader, new_sampler): |
|
|
|
""" |
|
|
|
使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中 |
|
|
|
""" |
|
|
|
batch_sampler = getattr(dataloader, "batch_sampler") |
|
|
|
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): |
|
|
|
raise RuntimeError("It should not be running here, please report a bug to us.") |
|
|
|
new_batch_sampler = deepcopy(dataloader.batch_sampler) |
|
|
|
new_batch_sampler.sampler = new_sampler |
|
|
|
return replace_batch_sampler(dataloader, new_batch_sampler) |
|
|
|