
bug fix for Trainer fp16

tags/v1.0.0alpha
yh_cc 3 years ago
commit 69a6cbbf09
2 changed files with 5 additions and 1 deletion
  1. fastNLP/core/callback.py  +4 -1
  2. fastNLP/core/trainer.py  +1 -0

fastNLP/core/callback.py  +4 -1

@@ -482,7 +482,10 @@ class GradientClipCallback(Callback):
 if self.step%self.update_every==0:
     if self.trainer.fp16:
         self.grad_scaler.unscale_(self.optimizer)
-    self.clip_fun(self.parameters, self.clip_value)
+    if self.parameters is not None:
+        self.clip_fun(self.parameters, self.clip_value)
+    else:
+        self.clip_fun(self.model.parameters(), self.clip_value)


 class EarlyStopCallback(Callback):
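
The guard matters when GradientClipCallback is built without an explicit set of parameters: self.parameters is then None, and the old code handed None straight to clip_fun. Below is a minimal standalone sketch of the fixed logic in plain PyTorch; the names parameters, clip_fun and clip_value mirror the callback's attributes, and the snippet is an illustration of the pattern rather than the callback class itself.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
parameters = None          # as if the callback was constructed without explicit parameters
clip_fun = nn.utils.clip_grad_norm_
clip_value = 1.0

loss = model(torch.randn(8, 4)).sum()
loss.backward()

# The fixed behaviour: fall back to the model's parameters when none were given,
# instead of calling clip_fun(None, clip_value).
if parameters is not None:
    clip_fun(parameters, clip_value)
else:
    clip_fun(model.parameters(), clip_value)

With fp16 enabled, the scaler is unscaled first (as in the hunk above) so that clipping sees the true gradient magnitudes rather than the scaled ones.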


fastNLP/core/trainer.py  +1 -0

@@ -526,6 +526,7 @@ class Trainer(object):


 # check fp16-related settings
 self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
+self.grad_scaler = _grad_scaler()
 if self.fp16:
     _can_use_fp16(device=device, model=model, func=self._forward_func)
     grad_scaler = kwargs.get('grad_scaler', None)
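
The added line gives the Trainer a grad_scaler instance as soon as the fp16 environment is built, even when fp16 is off or no custom grad_scaler is passed via kwargs, so the self.grad_scaler.unscale_ call in the callback hunk above always has an object to work with. A rough sketch of the pattern follows, assuming a simplified stand-in for fastNLP's _build_fp16_env; the DummyScaler class here is hypothetical.

import torch
from contextlib import nullcontext

def _build_fp16_env(dummy=False):
    # Return (autocast context factory, scaler class); use no-op dummies when fp16 is disabled.
    if dummy:
        class DummyScaler:
            def scale(self, loss): return loss
            def unscale_(self, optimizer): pass
            def step(self, optimizer): optimizer.step()
            def update(self): pass
        return nullcontext, DummyScaler
    return torch.cuda.amp.autocast, torch.cuda.amp.GradScaler

fp16 = False
auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
grad_scaler = _grad_scaler()   # always instantiate a default scaler up front

The fp16 branch can then replace this default with a user-supplied scaler taken from kwargs, as the context lines above suggest.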

