|
@@ -40,6 +40,7 @@ class FitlogCallback(HasMonitorCallback): |
|
|
self.log_exception = log_exception |
|
|
self.log_exception = log_exception |
|
|
self.log_loss_every = log_loss_every |
|
|
self.log_loss_every = log_loss_every |
|
|
self.avg_loss = 0 |
|
|
self.avg_loss = 0 |
|
|
|
|
|
self.catch_exception = False |
|
|
|
|
|
|
|
|
def on_after_trainer_initialized(self, trainer, driver): |
|
|
def on_after_trainer_initialized(self, trainer, driver): |
|
|
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog |
|
|
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog |
|
@@ -72,9 +73,11 @@ class FitlogCallback(HasMonitorCallback): |
|
|
self.avg_loss = 0 |
|
|
self.avg_loss = 0 |
|
|
|
|
|
|
|
|
def on_train_end(self, trainer): |
|
|
def on_train_end(self, trainer): |
|
|
fitlog.finish() |
|
|
|
|
|
|
|
|
if not self.catch_exception: |
|
|
|
|
|
fitlog.finish() |
|
|
|
|
|
|
|
|
def on_exception(self, trainer, exception): |
|
|
def on_exception(self, trainer, exception): |
|
|
|
|
|
self.catch_exception = True |
|
|
fitlog.finish(status=1) |
|
|
fitlog.finish(status=1) |
|
|
if self.log_exception: |
|
|
if self.log_exception: |
|
|
fitlog.add_other(repr(exception), name='except_info') |
|
|
fitlog.add_other(repr(exception), name='except_info') |