|
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os
- from collections import Sequence
- from pathlib import Path
-
- import mmcv
- from mmcv import Config, DictAction
-
- from mmdet.core.utils import mask2ndarray
- from mmdet.core.visualization import imshow_det_bboxes
- from mmdet.datasets.builder import build_dataset
-
-
def parse_args():
    """Define the command-line interface of the browser and parse it.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
            ``skip_type``, ``output_dir``, ``not_show``, ``show_interval``
            and ``cfg_options``.
    """
    # Adjacent literals are concatenated into a single help message.
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    cli = argparse.ArgumentParser(description='Browse a dataset')

    # Positional: the config file to load.
    cli.add_argument('config', help='train config file path')

    # Pipeline step types that are filtered out before browsing.
    cli.add_argument(
        '--skip-type',
        nargs='+',
        type=str,
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipeline')

    # Output / display controls.
    cli.add_argument(
        '--output-dir',
        type=str,
        default=None,
        help='If there is no display interface, you can save it')
    cli.add_argument('--not-show', action='store_true', default=False)
    cli.add_argument(
        '--show-interval',
        default=2,
        type=float,
        help='the interval of show (s)')

    # Key-value overrides, parsed by DictAction.
    cli.add_argument(
        '--cfg-options',
        action=DictAction,
        nargs='+',
        help=cfg_options_help)

    return cli.parse_args()
-
-
def retrieve_data_cfg(config_path, skip_type, cfg_options):
    """Load a config and strip the skipped steps from its train pipeline.

    Args:
        config_path (str): Path of the config file passed to
            ``Config.fromfile``.
        skip_type: Container of pipeline step ``type`` names; every step
            whose ``type`` is in it is removed from the pipeline.
        cfg_options (dict | None): Overrides merged into the loaded config.
            Nothing is merged when it is ``None``.

    Returns:
        The loaded config object. The pipeline(s) of the resolved train
        dataset config are filtered in place before it is returned.
    """
    # The ABC aliases in ``collections`` were removed in Python 3.10, so the
    # ``Sequence`` used for the isinstance check below must come from
    # ``collections.abc``. Imported locally so this function does not depend
    # on the removed alias.
    from collections.abc import Sequence

    def skip_pipeline_steps(config):
        # Rebuild the pipeline without the steps whose type is skipped.
        config['pipeline'] = [
            step for step in config.pipeline if step['type'] not in skip_type
        ]

    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)

    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # Descend through nested ``dataset`` entries, but stop at a
    # MultiImageMixDataset so that its own pipeline is the one filtered.
    train_data_cfg = cfg.data.train
    while ('dataset' in train_data_cfg
           and train_data_cfg['type'] != 'MultiImageMixDataset'):
        train_data_cfg = train_data_cfg['dataset']

    # The resolved entry may be one dataset config or a sequence of them.
    if isinstance(train_data_cfg, Sequence):
        for dataset_cfg in train_data_cfg:
            skip_pipeline_steps(dataset_cfg)
    else:
        skip_pipeline_steps(train_data_cfg)

    return cfg
-
-
def main():
    """Draw the ground-truth annotations of every training sample.

    The train dataset is built from the filtered config. Each item is passed
    to ``imshow_det_bboxes`` together with its boxes, labels and (when
    present) masks. An output file path is supplied only when
    ``--output-dir`` was given.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options)
    dataset = build_dataset(cfg.data.train)
    progress_bar = mmcv.ProgressBar(len(dataset))

    # Both settings are fixed for the whole run.
    save_results = args.output_dir is not None
    annotation_color = (255, 102, 61)

    for sample in dataset:
        if save_results:
            # Keep only the base name of the source image file.
            out_path = os.path.join(args.output_dir,
                                    Path(sample['filename']).name)
        else:
            out_path = None

        masks = sample.get('gt_masks', None)
        if masks is not None:
            masks = mask2ndarray(masks)

        imshow_det_bboxes(
            sample['img'],
            sample['gt_bboxes'],
            sample['gt_labels'],
            masks,
            class_names=dataset.CLASSES,
            show=not args.not_show,
            wait_time=args.show_interval,
            out_file=out_path,
            bbox_color=annotation_color,
            text_color=annotation_color)

        progress_bar.update()


if __name__ == '__main__':
    main()
|