|
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
-
- import mmcv
-
- from ..builder import PIPELINES
- from .compose import Compose
-
-
@PIPELINES.register_module()
class MultiScaleFlipAug:
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=[(1333, 400), (1333, 800)],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
            flip=[False, True, False, True]
            ...
        )

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple] | None): Image scales for resizing.
            Exactly one of ``img_scale`` and ``scale_factor`` must be set.
        scale_factor (float | list[float] | None): Scale factors for
            resizing. Exactly one of ``img_scale`` and ``scale_factor``
            must be set.
        flip (bool): Whether apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal", "vertical" and "diagonal". If
            flip_direction is a list, multiple flip augmentations will be
            applied. It has no effect when flip == False. Default:
            "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale=None,
                 scale_factor=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        # XOR: exactly one of the two scale specifications may be given.
        assert (img_scale is None) ^ (scale_factor is None), (
            'Must have but only one variable can be set')
        if img_scale is not None:
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
            self.scale_key = 'scale'
            assert mmcv.is_list_of(self.img_scale, tuple)
        else:
            # Scale factors are stored under the same attribute; only the
            # key written into the results dict differs.
            self.img_scale = scale_factor if isinstance(
                scale_factor, list) else [scale_factor]
            self.scale_key = 'scale_factor'
        # An empty list would yield no augmented sample, and __call__ would
        # then fail on ``aug_data[0]``; reject it here with a clear message.
        assert self.img_scale, (
            'img_scale / scale_factor must contain at least one entry')

        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip
                and not any(self._is_random_flip(t) for t in transforms)):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_random_flip(transform):
        """Tell whether ``transform`` denotes a ``RandomFlip`` step.

        Args:
            transform (dict | object): One entry of the ``transforms``
                argument.

        Returns:
            bool: True if the entry is a config whose ``'type'`` equals
            ``'RandomFlip'``, or an object whose class is named
            ``RandomFlip``.
        """
        try:
            return transform['type'] == 'RandomFlip'
        except (TypeError, KeyError):
            # NOTE(review): presumably Compose also accepts already-built
            # transform objects / callables - confirm against Compose.
            # Such entries are not subscriptable, so fall back to the class
            # name instead of crashing a check that only emits a warning.
            return type(transform).__name__ == 'RandomFlip'

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
            into a list.
        """

        aug_data = []
        # The un-flipped variant is always produced; flipped variants are
        # added once per configured direction when flipping is enabled.
        flip_args = [(False, None)]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]
        for scale in self.img_scale:
            for flip, direction in flip_args:
                # Shallow copy so the keys set below do not leak into the
                # caller's dict or into the other augmentations.
                _results = results.copy()
                _results[self.scale_key] = scale
                _results['flip'] = flip
                _results['flip_direction'] = direction
                data = self.transforms(_results)
                aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str
|