|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- from mmcv.runner.hooks import HOOKS, Hook
-
-
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook regularly checks whether the training loss is finite
    (contains no infinite or NaN values) and stops training when it is not.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner):
        """Verify that ``runner.outputs['loss']`` is finite.

        The check only runs on iterations selected by
        ``self.every_n_iters(runner, self.interval)``.

        Args:
            runner: The runner whose ``outputs`` dict holds the current
                ``'loss'`` tensor and whose ``logger`` reports the failure.

        Raises:
            AssertionError: If any element of the loss is infinite or NaN.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        loss = runner.outputs['loss']
        # ``.all()`` reduces the element-wise result so a non-scalar loss
        # tensor does not hit an ambiguous truth-value error; for a scalar
        # loss it is equivalent to the plain check.
        if not torch.isfinite(loss).all():
            msg = 'loss become infinite or NaN!'
            runner.logger.error(msg)
            # Raise explicitly instead of using ``assert``: an assert would be
            # stripped under ``python -O``, and the old form used the ``None``
            # returned by ``logger.info`` as its message. AssertionError is
            # kept so existing callers that catch it keep working.
            raise AssertionError(msg)
|