| |
| from typing import Optional |
|
|
| import torch |
| from mmengine.hooks import Hook |
| from mmengine.runner import Runner |
|
|
| from mmdet.registry import HOOKS |
|
|
|
|
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook regularly checks, during training, whether the loss reported
    in the model outputs is finite (neither infinite nor NaN) and aborts
    training when it is not.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. The loss is read
                from ``outputs['loss']``. Defaults to None, in which case
                there is nothing to check and the hook returns immediately.

        Raises:
            AssertionError: If any element of ``outputs['loss']`` is
                infinite or NaN on a checked iteration.
        """
        if not self.every_n_train_iters(runner, self.interval):
            return
        if outputs is None:
            # No model outputs were passed in, so there is no loss to check.
            return

        loss = outputs['loss']
        # ``torch.isfinite`` is element-wise; reduce with ``.all()`` so a
        # non-scalar loss does not raise "ambiguous truth value".
        if not torch.isfinite(loss).all():
            message = 'loss become infinite or NaN!'
            # Log first so the failure is recorded in the run log. Then raise
            # explicitly instead of using ``assert``: an assert (and thus the
            # whole check) is stripped under ``python -O``. AssertionError is
            # kept so existing callers that catch it keep working.
            runner.logger.error(message)
            raise AssertionError(message)
|
|