|
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- import xml.etree.ElementTree as ET
-
- import mmcv
- import numpy as np
- from PIL import Image
-
- from .builder import DATASETS
- from .custom import CustomDataset
-
-
- @DATASETS.register_module()
- class XMLDataset(CustomDataset):
- """XML dataset for detection.
-
- Args:
- min_size (int | float, optional): The minimum size of bounding
- boxes in the images. If the size of a bounding box is less than
- ``min_size``, it would be add to ignored field.
- img_subdir (str): Subdir where images are stored. Default: JPEGImages.
- ann_subdir (str): Subdir where annotations are. Default: Annotations.
- """
-
- def __init__(self,
- min_size=None,
- img_subdir='JPEGImages',
- ann_subdir='Annotations',
- **kwargs):
- assert self.CLASSES or kwargs.get(
- 'classes', None), 'CLASSES in `XMLDataset` can not be None.'
- self.img_subdir = img_subdir
- self.ann_subdir = ann_subdir
- super(XMLDataset, self).__init__(**kwargs)
- self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
- self.min_size = min_size
-
- def load_annotations(self, ann_file):
- """Load annotation from XML style ann_file.
-
- Args:
- ann_file (str): Path of XML file.
-
- Returns:
- list[dict]: Annotation info from XML file.
- """
-
- data_infos = []
- img_ids = mmcv.list_from_file(ann_file)
- for img_id in img_ids:
- filename = osp.join(self.img_subdir, f'{img_id}.jpg')
- xml_path = osp.join(self.img_prefix, self.ann_subdir,
- f'{img_id}.xml')
- tree = ET.parse(xml_path)
- root = tree.getroot()
- size = root.find('size')
- if size is not None:
- width = int(size.find('width').text)
- height = int(size.find('height').text)
- else:
- img_path = osp.join(self.img_prefix, filename)
- img = Image.open(img_path)
- width, height = img.size
- data_infos.append(
- dict(id=img_id, filename=filename, width=width, height=height))
-
- return data_infos
-
- def _filter_imgs(self, min_size=32):
- """Filter images too small or without annotation."""
- valid_inds = []
- for i, img_info in enumerate(self.data_infos):
- if min(img_info['width'], img_info['height']) < min_size:
- continue
- if self.filter_empty_gt:
- img_id = img_info['id']
- xml_path = osp.join(self.img_prefix, self.ann_subdir,
- f'{img_id}.xml')
- tree = ET.parse(xml_path)
- root = tree.getroot()
- for obj in root.findall('object'):
- name = obj.find('name').text
- if name in self.CLASSES:
- valid_inds.append(i)
- break
- else:
- valid_inds.append(i)
- return valid_inds
-
- def get_ann_info(self, idx):
- """Get annotation from XML file by index.
-
- Args:
- idx (int): Index of data.
-
- Returns:
- dict: Annotation info of specified index.
- """
-
- img_id = self.data_infos[idx]['id']
- xml_path = osp.join(self.img_prefix, self.ann_subdir, f'{img_id}.xml')
- tree = ET.parse(xml_path)
- root = tree.getroot()
- bboxes = []
- labels = []
- bboxes_ignore = []
- labels_ignore = []
- for obj in root.findall('object'):
- name = obj.find('name').text
- if name not in self.CLASSES:
- continue
- label = self.cat2label[name]
- difficult = obj.find('difficult')
- difficult = 0 if difficult is None else int(difficult.text)
- bnd_box = obj.find('bndbox')
- # TODO: check whether it is necessary to use int
- # Coordinates may be float type
- bbox = [
- int(float(bnd_box.find('xmin').text)),
- int(float(bnd_box.find('ymin').text)),
- int(float(bnd_box.find('xmax').text)),
- int(float(bnd_box.find('ymax').text))
- ]
- ignore = False
- if self.min_size:
- assert not self.test_mode
- w = bbox[2] - bbox[0]
- h = bbox[3] - bbox[1]
- if w < self.min_size or h < self.min_size:
- ignore = True
- if difficult or ignore:
- bboxes_ignore.append(bbox)
- labels_ignore.append(label)
- else:
- bboxes.append(bbox)
- labels.append(label)
- if not bboxes:
- bboxes = np.zeros((0, 4))
- labels = np.zeros((0, ))
- else:
- bboxes = np.array(bboxes, ndmin=2) - 1
- labels = np.array(labels)
- if not bboxes_ignore:
- bboxes_ignore = np.zeros((0, 4))
- labels_ignore = np.zeros((0, ))
- else:
- bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
- labels_ignore = np.array(labels_ignore)
- ann = dict(
- bboxes=bboxes.astype(np.float32),
- labels=labels.astype(np.int64),
- bboxes_ignore=bboxes_ignore.astype(np.float32),
- labels_ignore=labels_ignore.astype(np.int64))
- return ann
-
- def get_cat_ids(self, idx):
- """Get category ids in XML file by index.
-
- Args:
- idx (int): Index of data.
-
- Returns:
- list[int]: All categories in the image of specified index.
- """
-
- cat_ids = []
- img_id = self.data_infos[idx]['id']
- xml_path = osp.join(self.img_prefix, self.ann_subdir, f'{img_id}.xml')
- tree = ET.parse(xml_path)
- root = tree.getroot()
- for obj in root.findall('object'):
- name = obj.find('name').text
- if name not in self.CLASSES:
- continue
- label = self.cat2label[name]
- cat_ids.append(label)
-
- return cat_ids
|