From 69a6cbbf091d995c69332e96212dc29d0f7a444b Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sun, 6 Dec 2020 00:42:42 +0800
Subject: [PATCH] bug fix for Trainer fp16

---
 fastNLP/core/callback.py | 5 ++++-
 fastNLP/core/trainer.py  | 1 +
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 91c888df..e04f278e 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -482,7 +482,10 @@ class GradientClipCallback(Callback):
         if self.step%self.update_every==0:
             if self.trainer.fp16:
                 self.grad_scaler.unscale_(self.optimizer)
-            self.clip_fun(self.parameters, self.clip_value)
+            if self.parameters is not None:
+                self.clip_fun(self.parameters, self.clip_value)
+            else:
+                self.clip_fun(self.model.parameters(), self.clip_value)
 
 
 class EarlyStopCallback(Callback):
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index d9731217..a2b9e8dd 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -526,6 +526,7 @@ class Trainer(object):
 
         # check fp16相关的设置
         self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
+        self.grad_scaler = _grad_scaler()
         if self.fp16:
             _can_use_fp16(device=device, model=model, func=self._forward_func)
         grad_scaler = kwargs.get('grad_scaler', None)
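
Note: the pattern this patch relies on (unscale gradients before clipping, and fall back to model.parameters() when no explicit parameter list is given) follows the standard torch.cuda.amp recipe. The snippet below is a minimal illustrative sketch of that recipe in plain PyTorch, not the fastNLP API; the model, optimizer, and tensors are placeholders and a CUDA device is assumed.

# Illustrative amp + gradient-clipping sketch (plain PyTorch, not fastNLP).
import torch
from torch import nn

model = nn.Linear(10, 2).cuda()          # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 10, device='cuda')    # placeholder batch
y = torch.randint(0, 2, (4,), device='cuda')

with torch.cuda.amp.autocast():
    loss = nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                       # undo loss scaling first...
nn.utils.clip_grad_norm_(model.parameters(), 5)  # ...so clipping sees true gradient
                                                 # magnitudes; the patch uses
                                                 # model.parameters() as the fallback
scaler.step(optimizer)
scaler.update()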