|
|
@@ -86,7 +86,7 @@ class PaddleDriver(Driver): |
|
|
|
|
|
|
|
# scaler的参数 |
|
|
|
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) |
|
|
|
self.grad_scaler = _grad_scaler() |
|
|
|
self.grad_scaler = _grad_scaler(**self._paddle_kwargs.get("gradscaler_kwargs", {})) |
|
|
|
|
|
|
|
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; |
|
|
|
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) |
|
|
|