From 08416a3a6c4e6b5afbd8cc71773416fc23ea2bc0 Mon Sep 17 00:00:00 2001
From: yhcc
Date: Sun, 5 Jun 2022 19:51:27 +0800
Subject: [PATCH] 1. Fix the timing of the lr_scheduler call; 2. Fix the
 initialization issue in replace_sampler
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/callbacks/lr_scheduler_callback.py |  2 +-
 fastNLP/core/drivers/paddle_driver/utils.py | 18 +++++-------------
 fastNLP/core/drivers/torch_driver/utils.py  | 13 ++++---------
 3 files changed, 10 insertions(+), 23 deletions(-)

diff --git a/fastNLP/core/callbacks/lr_scheduler_callback.py b/fastNLP/core/callbacks/lr_scheduler_callback.py
index 37d089bd..a71428ca 100644
--- a/fastNLP/core/callbacks/lr_scheduler_callback.py
+++ b/fastNLP/core/callbacks/lr_scheduler_callback.py
@@ -19,7 +19,7 @@ class LRSchedCallback(Callback):
         self.scheduler = scheduler
         self.step_on = 0 if step_on == 'batch' else 1
 
-    def on_before_optimizers_step(self, trainer, optimizers):
+    def on_after_optimizers_step(self, trainer, optimizers):
         if self.step_on == 0:
             self.scheduler.step()
 
diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
index b1815fbd..e53f4066 100644
--- a/fastNLP/core/drivers/paddle_driver/utils.py
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -178,19 +178,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
     # to look in; VAR_KEYWORD stands for **kwargs
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
     if has_variadic_kwargs:
-        init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
-        del init_params["self"]
-
-        # Because we may have just overwritten the user's customized dataloader arguments with DataLoader's defaults, we need to redo this;
-        # collect the parameters that appear both as instance attributes and in the signature and whose values are not the defaults
-        non_default_params = {name for name, p in init_params.items() if
-                              name in instance_attrs and p.default != instance_attrs[name]}
-        # add `dataset` as it might have been replaced with `*args`
-        non_default_params.add("dataset")
-
-        # collect the non-default parameters and their values
-        reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-        # the class member corresponding to persistent_workers carries an underscore, so add it here
+        for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
+            if key not in init_params and key != 'self':
+                init_params[key] = value
+
+    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params}
     reconstruct_args.update({
         "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
         "persistent_workers": dataloader._persistent_workers,
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index 2d13a8e8..a874bf3b 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -189,16 +189,11 @@ def replace_sampler(dataloader: "DataLoader", sampler):
     # to look in;
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
     if has_variadic_kwargs:
-        init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
-        del init_params["self"]
+        for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
+            if key not in init_params and key != 'self':
+                init_params[key] = value
 
-    # Because we may have just overwritten the user's customized dataloader arguments with DataLoader's defaults, we need to redo this;
-    non_default_params = {name for name, p in init_params.items() if
-                          name in instance_attrs and p.default != instance_attrs[name]}
-    # add `dataset` as it might have been replaced with `*args`
-    non_default_params.add("dataset")
-
-    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
+    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params}
     reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))
 
     required_args = {
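
The callback rename above matters because, since PyTorch 1.1.0, optimizer.step()
is expected to run before scheduler.step(); stepping the scheduler first skips
the first value of the schedule and triggers a UserWarning. A minimal sketch of
the ordering the on_after_optimizers_step hook now guarantees, written in plain
PyTorch rather than against fastNLP's Trainer:

    import torch
    from torch import nn, optim

    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

    for _ in range(3):                          # three dummy "batches"
        optimizer.zero_grad()
        model(torch.randn(8, 4)).sum().backward()
        optimizer.step()                        # update the weights first ...
        scheduler.step()                        # ... then advance the learning rate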
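
The replace_sampler/replace_batch_sampler change has two effects: DataLoader's
base signature no longer overwrites parameters that a user's subclass
re-declares, and the reconstruction keyword set is now every instance attribute
matching an __init__ parameter rather than only the non-default ones (the old
set could drop arguments a custom subclass needs, and on the paddle side left
reconstruct_args unbound when the subclass took no **kwargs). Below is a rough,
self-contained sketch of that merged-signature technique against torch's
DataLoader; MyLoader and my_arg are made-up names for illustration, not fastNLP
code:

    import inspect
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

    class MyLoader(DataLoader):
        """A user dataloader whose extra argument must survive reconstruction."""
        def __init__(self, dataset, my_arg=0, **kwargs):
            self.my_arg = my_arg
            super().__init__(dataset, **kwargs)

    # Merge DataLoader.__init__'s parameters into the subclass signature without
    # clobbering the entries the subclass already declares (the fixed behaviour).
    init_params = dict(inspect.signature(MyLoader.__init__).parameters)
    for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
        if key not in init_params and key != 'self':
            init_params[key] = value

    loader = MyLoader(dataset=list(range(10)), my_arg=3, batch_size=2)

    # Rebuild the loader around a new batch sampler, mirroring the patched logic:
    # keep every public attribute that matches an __init__ parameter, then
    # neutralize the arguments that conflict with an explicit batch_sampler.
    instance_attrs = {k: v for k, v in vars(loader).items() if not k.startswith('_')}
    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params}
    reconstruct_args.update({
        'batch_sampler': BatchSampler(SequentialSampler(loader.dataset),
                                      batch_size=5, drop_last=False),
        'sampler': None, 'shuffle': False, 'drop_last': False, 'batch_size': 1,
    })
    new_loader = MyLoader(**reconstruct_args)
    assert new_loader.my_arg == 3               # the custom argument survived

The hard-coded overrides are not optional: torch raises a ValueError when
batch_sampler is combined with shuffle, sampler, drop_last, or a batch_size
other than 1, which is exactly why the patch pins those same values.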