|
|
@@ -178,19 +178,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler |
|
|
|
# 中寻找;VAR_KEYWORD 代表 **kwargs |
|
|
|
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) |
|
|
|
if has_variadic_kwargs: |
|
|
|
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) |
|
|
|
del init_params["self"] |
|
|
|
|
|
|
|
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; |
|
|
|
# 将同时在实例名和参数名中出现且不是默认值的参数收集起来 |
|
|
|
non_default_params = {name for name, p in init_params.items() if |
|
|
|
name in instance_attrs and p.default != instance_attrs[name]} |
|
|
|
# add `dataset` as it might have been replaced with `*args` |
|
|
|
non_default_params.add("dataset") |
|
|
|
|
|
|
|
# 收集不是默认值的参数和它的值 |
|
|
|
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} |
|
|
|
# persistent_workers 在类中的对应成员带有下划线,因此添加进来 |
|
|
|
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): |
|
|
|
if key not in init_params and key != 'self': |
|
|
|
init_params[key] = value |
|
|
|
|
|
|
|
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} |
|
|
|
reconstruct_args.update({ |
|
|
|
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, |
|
|
|
"persistent_workers": dataloader._persistent_workers, |
|
|
|