|
@@ -399,10 +399,11 @@ class GradientClipCallback(Callback): |
|
|
self.clip_value = clip_value |
|
|
self.clip_value = clip_value |
|
|
|
|
|
|
|
|
def on_backward_end(self): |
|
|
def on_backward_end(self): |
|
|
if self.parameters is None: |
|
|
|
|
|
self.clip_fun(self.model.parameters(), self.clip_value) |
|
|
|
|
|
else: |
|
|
|
|
|
self.clip_fun(self.parameters, self.clip_value) |
|
|
|
|
|
|
|
|
if self.step%self.update_every==0: |
|
|
|
|
|
if self.parameters is None: |
|
|
|
|
|
self.clip_fun(self.model.parameters(), self.clip_value) |
|
|
|
|
|
else: |
|
|
|
|
|
self.clip_fun(self.parameters, self.clip_value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EarlyStopCallback(Callback): |
|
|
class EarlyStopCallback(Callback): |
|
|