|
|
@@ -464,7 +464,10 @@ class GradientClipCallback(Callback): |
|
|
|
self.clip_fun = nn.utils.clip_grad_value_ |
|
|
|
else: |
|
|
|
raise ValueError("Only supports `norm` or `value` right now.") |
|
|
|
self.parameters = list(parameters) |
|
|
|
if parameters is not None: |
|
|
|
self.parameters = list(parameters) |
|
|
|
else: |
|
|
|
self.parameters = None |
|
|
|
self.clip_value = clip_value |
|
|
|
|
|
|
|
def on_backward_end(self): |
|
|
|