From 40b8016e98ea521e545e45606424039901dfb5ba Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 28 Jun 2022 23:29:42 +0800 Subject: [PATCH] =?UTF-8?q?PaddleDriver=20=E5=8F=AF=E4=BB=A5=E4=BC=A0?= =?UTF-8?q?=E5=85=A5Gradscaler=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/paddle_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index b22a6913..5bd35b7a 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -86,7 +86,7 @@ class PaddleDriver(Driver): # scaler的参数 self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) - self.grad_scaler = _grad_scaler() + self.grad_scaler = _grad_scaler(**self._paddle_kwargs.get("gradscaler_kwargs", {})) # 用来设置是否关闭 auto_param_call 中的参数匹配问题; self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)