Browse Source

1.修复lr_schedulder的调用时机问题;2.修复replace_sampler的初始化问题

tags/v1.0.0alpha
yhcc 3 years ago
parent
commit
08416a3a6c
3 changed files with 10 additions and 23 deletions
  1. +1
    -1
      fastNLP/core/callbacks/lr_scheduler_callback.py
  2. +5
    -13
      fastNLP/core/drivers/paddle_driver/utils.py
  3. +4
    -9
      fastNLP/core/drivers/torch_driver/utils.py

+ 1
- 1
fastNLP/core/callbacks/lr_scheduler_callback.py View File

@@ -19,7 +19,7 @@ class LRSchedCallback(Callback):
self.scheduler = scheduler
self.step_on = 0 if step_on == 'batch' else 1

def on_before_optimizers_step(self, trainer, optimizers):
def on_after_optimizers_step(self, trainer, optimizers):
if self.step_on == 0:
self.scheduler.step()



+ 5
- 13
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -178,19 +178,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
# 中寻找;VAR_KEYWORD 代表 **kwargs
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if has_variadic_kwargs:
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
del init_params["self"]

# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍;
# 将同时在实例名和参数名中出现且不是默认值的参数收集起来
non_default_params = {name for name, p in init_params.items() if
name in instance_attrs and p.default != instance_attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_default_params.add("dataset")

# 收集不是默认值的参数和它的值
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
# persistent_workers 在类中的对应成员带有下划线,因此添加进来
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
if key not in init_params and key != 'self':
init_params[key] = value

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params}
reconstruct_args.update({
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
"persistent_workers": dataloader._persistent_workers,


+ 4
- 9
fastNLP/core/drivers/torch_driver/utils.py View File

@@ -189,16 +189,11 @@ def replace_sampler(dataloader: "DataLoader", sampler):
# 中寻找;
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if has_variadic_kwargs:
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
del init_params["self"]
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
if key not in init_params and key != 'self':
init_params[key] = value

# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍;
non_default_params = {name for name, p in init_params.items() if
name in instance_attrs and p.default != instance_attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_default_params.add("dataset")

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params}
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))

required_args = {


Loading…
Cancel
Save