
Change how GradientClipCallback stores `parameters`, to prevent an error under torch 1.5

tags/v0.5.5
yh_cc committed 4 years ago
commit 9b45e9c0b5
1 changed file with 4 additions and 1 deletion:

fastNLP/core/callback.py (+4, -1)

@@ -464,7 +464,10 @@ class GradientClipCallback(Callback):
             self.clip_fun = nn.utils.clip_grad_value_
         else:
             raise ValueError("Only supports `norm` or `value` right now.")
-        self.parameters = list(parameters)
+        if parameters is not None:
+            self.parameters = list(parameters)
+        else:
+            self.parameters = None
         self.clip_value = clip_value
 
     def on_backward_end(self):
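For context, a minimal self-contained sketch of the logic this commit lands. `GradientClipSketch` is a hypothetical stand-in: the real class is GradientClipCallback in fastNLP/core/callback.py, and it gets the model from the trainer rather than as an argument. Materializing `parameters` into a list at construction time keeps a `model.parameters()` generator from being exhausted after the first clipping step (per the commit message, an exhausted, i.e. empty, sequence can make the clip functions raise under torch 1.5). But `list(None)` raises TypeError, so the default `parameters=None` case must stay None and be resolved at clipping time, which is exactly what the added guard does.

from torch import nn

class GradientClipSketch:
    """Sketch of GradientClipCallback's parameter handling after this commit."""

    def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
        if clip_type == 'norm':
            self.clip_fun = nn.utils.clip_grad_norm_
        elif clip_type == 'value':
            self.clip_fun = nn.utils.clip_grad_value_
        else:
            raise ValueError("Only supports `norm` or `value` right now.")
        # The fix: list(None) raises TypeError, so only materialize an
        # explicit iterable; None means "clip the model's parameters later".
        if parameters is not None:
            self.parameters = list(parameters)  # pin down generators once, up front
        else:
            self.parameters = None
        self.clip_value = clip_value

    def on_backward_end(self, model):
        # Called after each backward pass: clip either the stored list or a
        # fresh model.parameters() generator.
        if self.parameters is None:
            self.clip_fun(model.parameters(), self.clip_value)
        else:
            self.clip_fun(self.parameters, self.clip_value)

Usage follows the two paths the guard distinguishes: the default `GradientClipCallback()` defers to the trainer's model on every step, while `GradientClipCallback(parameters=model.parameters())` stores a concrete list once so that repeated clipping keeps working across steps.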

