diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index b22a6913..5bd35b7a 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -86,7 +86,7 @@ class PaddleDriver(Driver): # scaler的参数 self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) - self.grad_scaler = _grad_scaler() + self.grad_scaler = _grad_scaler(**self._paddle_kwargs.get("gradscaler_kwargs", {})) # 用来设置是否关闭 auto_param_call 中的参数匹配问题; self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)