Browse Source

PaddleDriver 可以传入Gradscaler参数

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
40b8016e98
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      fastNLP/core/drivers/paddle_driver/paddle_driver.py

+ 1
- 1
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -86,7 +86,7 @@ class PaddleDriver(Driver):

# scaler的参数
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
self.grad_scaler = _grad_scaler()
self.grad_scaler = _grad_scaler(**self._paddle_kwargs.get("gradscaler_kwargs", {}))

# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)


Loading…
Cancel
Save