@@ -19,7 +19,7 @@ class LRSchedCallback(Callback): | |||||
self.scheduler = scheduler | self.scheduler = scheduler | ||||
self.step_on = 0 if step_on == 'batch' else 1 | self.step_on = 0 if step_on == 'batch' else 1 | ||||
def on_before_optimizers_step(self, trainer, optimizers): | |||||
def on_after_optimizers_step(self, trainer, optimizers): | |||||
if self.step_on == 0: | if self.step_on == 0: | ||||
self.scheduler.step() | self.scheduler.step() | ||||
@@ -178,19 +178,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler | |||||
# 中寻找;VAR_KEYWORD 代表 **kwargs | # 中寻找;VAR_KEYWORD 代表 **kwargs | ||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||||
del init_params["self"] | |||||
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | |||||
# 将同时在实例名和参数名中出现且不是默认值的参数收集起来 | |||||
non_default_params = {name for name, p in init_params.items() if | |||||
name in instance_attrs and p.default != instance_attrs[name]} | |||||
# add `dataset` as it might have been replaced with `*args` | |||||
non_default_params.add("dataset") | |||||
# 收集不是默认值的参数和它的值 | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
# persistent_workers 在类中的对应成员带有下划线,因此添加进来 | |||||
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
if key not in init_params and key != 'self': | |||||
init_params[key] = value | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} | |||||
reconstruct_args.update({ | reconstruct_args.update({ | ||||
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | ||||
"persistent_workers": dataloader._persistent_workers, | "persistent_workers": dataloader._persistent_workers, | ||||
@@ -189,16 +189,11 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
# 中寻找; | # 中寻找; | ||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||||
del init_params["self"] | |||||
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
if key not in init_params and key != 'self': | |||||
init_params[key] = value | |||||
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | |||||
non_default_params = {name for name, p in init_params.items() if | |||||
name in instance_attrs and p.default != instance_attrs[name]} | |||||
# add `dataset` as it might have been replaced with `*args` | |||||
non_default_params.add("dataset") | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} | |||||
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | ||||
required_args = { | required_args = { | ||||