Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

pull/11/head
x54-729 1 year ago
parent
commit
196c864ff2
2 changed files with 8 additions and 2 deletions
  1. +4
    -1
      fastNLP/core/callbacks/__init__.py
  2. +4
    -1
      fastNLP/core/callbacks/fitlog_callback.py

+ 4
- 1
fastNLP/core/callbacks/__init__.py View File

@@ -24,7 +24,9 @@ __all__ = [


"FitlogCallback", "FitlogCallback",


"TimerCallback"
"TimerCallback",

"TopkSaver"
] ]




@@ -41,3 +43,4 @@ from .more_evaluate_callback import MoreEvaluateCallback
from .has_monitor_callback import ResultsMonitor, HasMonitorCallback from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
from .fitlog_callback import FitlogCallback from .fitlog_callback import FitlogCallback
from .timer_callback import TimerCallback from .timer_callback import TimerCallback
from .topk_saver import TopkSaver

+ 4
- 1
fastNLP/core/callbacks/fitlog_callback.py View File

@@ -40,6 +40,7 @@ class FitlogCallback(HasMonitorCallback):
self.log_exception = log_exception self.log_exception = log_exception
self.log_loss_every = log_loss_every self.log_loss_every = log_loss_every
self.avg_loss = 0 self.avg_loss = 0
self.catch_exception = False


def on_after_trainer_initialized(self, trainer, driver): def on_after_trainer_initialized(self, trainer, driver):
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog
@@ -72,9 +73,11 @@ class FitlogCallback(HasMonitorCallback):
self.avg_loss = 0 self.avg_loss = 0


def on_train_end(self, trainer): def on_train_end(self, trainer):
fitlog.finish()
if not self.catch_exception:
fitlog.finish()


def on_exception(self, trainer, exception): def on_exception(self, trainer, exception):
self.catch_exception = True
fitlog.finish(status=1) fitlog.finish(status=1)
if self.log_exception: if self.log_exception:
fitlog.add_other(repr(exception), name='except_info') fitlog.add_other(repr(exception), name='except_info')

Loading…
Cancel
Save