| @@ -189,11 +189,18 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
| # 中寻找; | # 中寻找; | ||||
| has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
| if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
| # 这里之所以这样写是因为用户自己定制的 Dataloader 中名字一样的参数所设置的默认值可能不同;因此不能直接使用 update 覆盖掉了; | |||||
| for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | ||||
| if key not in init_params and key != 'self': | if key not in init_params and key != 'self': | ||||
| init_params[key] = value | init_params[key] = value | ||||
| reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} | |||||
| # 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; | |||||
| non_default_params = {name for name, p in init_params.items() if | |||||
| name in instance_attrs and p.default != instance_attrs[name]} | |||||
| # add `dataset` as it might have been replaced with `*args` | |||||
| non_default_params.add("dataset") | |||||
| reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
| reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | ||||
| required_args = { | required_args = { | ||||