|
|
@@ -248,7 +248,10 @@ class GradientClipCallback(Callback): |
|
|
|
self.clip_value = clip_value |
|
|
|
|
|
|
|
def on_backward_end(self, model): |
|
|
|
self.clip_fun(model.parameters(), self.clip_value) |
|
|
|
if self.parameters is None: |
|
|
|
self.clip_fun(model.parameters(), self.clip_value) |
|
|
|
else: |
|
|
|
self.clip_fun(self.parameters, self.clip_value) |
|
|
|
|
|
|
|
|
|
|
|
class CallbackException(BaseException): |
|
|
@@ -306,7 +309,6 @@ class LRScheduler(Callback): |
|
|
|
|
|
|
|
def on_epoch_begin(self, cur_epoch, total_epoch): |
|
|
|
self.scheduler.step() |
|
|
|
print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) |
|
|
|
|
|
|
|
|
|
|
|
class ControlC(Callback): |
|
|
|