|
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- import warnings
-
- from mmcv.cnn import VGG
- from mmcv.runner.hooks import HOOKS, Hook
-
- from mmdet.datasets.builder import PIPELINES
- from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
- from mmdet.models.dense_heads import GARPNHead, RPNHead
- from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
-
-
def _replace_image_to_tensor_inplace(steps):
    """Swap every ``ImageToTensor`` entry of ``steps`` for
    ``DefaultFormatBundle``, modifying ``steps`` in place.

    ``MultiScaleFlipAug`` entries are descended into through their
    ``transforms`` list. The caller is responsible for passing a list it
    owns (e.g. a deep copy), since entries are overwritten.

    Args:
        steps (list[dict]): Data pipeline configs to rewrite in place.
    """
    for idx, step in enumerate(steps):
        if step['type'] == 'MultiScaleFlipAug':
            assert 'transforms' in step
            _replace_image_to_tensor_inplace(step['transforms'])
        elif step['type'] == 'ImageToTensor':
            warnings.warn(
                '"ImageToTensor" pipeline is replaced by '
                '"DefaultFormatBundle" for batch inference. It is '
                'recommended to manually replace it in the test '
                'data pipeline in your config file.', UserWarning)
            steps[idx] = {'type': 'DefaultFormatBundle'}


def replace_ImageToTensor(pipelines):
    """Replace the ImageToTensor transform in a data pipeline to
    DefaultFormatBundle, which is normally useful in batch inference.

    The input is never modified: it is deep-copied exactly once and the
    replacement is then carried out on that copy, including inside the
    ``transforms`` of any ``MultiScaleFlipAug`` entry. A ``UserWarning`` is
    emitted for every ``ImageToTensor`` entry that gets replaced.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: The new pipeline list with all ImageToTensor replaced by
            DefaultFormatBundle.

    Raises:
        AssertionError: If a ``MultiScaleFlipAug`` entry has no
            ``transforms`` key.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    # One deep copy covers every nested ``transforms`` list, so the helper
    # can safely overwrite entries without copying again per nesting level.
    pipelines = copy.deepcopy(pipelines)
    _replace_image_to_tensor_inplace(pipelines)
    return pipelines
-
-
def get_loading_pipeline(pipeline):
    """Extract the image- and annotation-loading steps of a pipeline.

    Each config's ``type`` is looked up in the ``PIPELINES`` registry; a
    config is kept only when the registered class is ``LoadImageFromFile``
    or ``LoadAnnotations``. The kept configs are the same dict objects as
    in the input (no copy is made), in their original order.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The configs of the loading steps only.

    Raises:
        AssertionError: If the number of loading steps found is not
            exactly two.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines ==\
        ...        get_loading_pipeline(pipelines)
    """
    # TODO:use more elegant way to distinguish loading modules
    loader_classes = (LoadImageFromFile, LoadAnnotations)

    kept_cfgs = []
    for step_cfg in pipeline:
        registered_cls = PIPELINES.get(step_cfg['type'])
        if registered_cls is None:
            # Type name is not registered, so it cannot be a loader.
            continue
        if registered_cls in loader_classes:
            kept_cfgs.append(step_cfg)

    assert len(kept_cfgs) == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline.'
    return kept_cfgs
-
-
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Hook that checks, before every train and val epoch, that the
    `num_classes` of the model's modules agrees with the dataset's
    `CLASSES`."""

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        If `dataset.CLASSES` is None only a warning is logged. Otherwise
        every module of the model that exposes a `num_classes` attribute
        is compared against `len(dataset.CLASSES)`, except instances of
        `RPNHead`, `VGG`, `FusedSemanticHead` and `GARPNHead`.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.

        Raises:
            AssertionError: If `CLASSES` is a single string, or if a
                checked module's `num_classes` differs from the number
                of dataset classes.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        dataset_name = dataset.__class__.__name__

        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset_name} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string would make len() count characters instead of
        # classes, so reject it with a hint on how to write a 1-tuple.
        assert not isinstance(dataset.CLASSES, str), \
            (f'`CLASSES` in {dataset_name} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'CLASSES = ({dataset.CLASSES!r},)')

        num_dataset_classes = len(dataset.CLASSES)
        # NOTE(review): these module types are skipped presumably because
        # their `num_classes` does not refer to the dataset categories —
        # confirm against the respective head/backbone definitions.
        skipped_types = (RPNHead, VGG, FusedSemanticHead, GARPNHead)
        for _, module in model.named_modules():
            if not hasattr(module, 'num_classes'):
                continue
            if isinstance(module, skipped_types):
                continue
            assert module.num_classes == num_dataset_classes, \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `CLASSES` '
                 f'({num_dataset_classes}) in '
                 f'{dataset_name}')

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)
|