|
- # Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/cityscapes.py
- import json
- import os
- from collections import namedtuple
-
- import torch
- import torch.utils.data as data
- from PIL import Image
- import numpy as np
- from torchvision.datasets import VisionDataset
-
class Cityscapes(VisionDataset):
    """Cityscapes <http://www.cityscapes-dataset.com/> semantic segmentation dataset.

    Expects the standard Cityscapes layout under ``root``::

        root/leftImg8bit/<split>/<city>/*_leftImg8bit.png
        root/<mode>/<split>/<city>/*_<mode>_<suffix>

    Args:
        root (string): Root directory of the dataset, containing the
            'leftImg8bit' directory and 'gtFine' or 'gtCoarse'.
        split (string, optional): Image split to use: 'train', 'test' or 'val'
            if mode='gtFine'; 'train', 'train_extra' or 'val' if mode='gtCoarse'.
        mode (string, optional): Annotation quality, 'gtFine' or 'gtCoarse'.
        target_type (string, optional): One of 'instance', 'semantic', 'color',
            'polygon' or 'depth'.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in
            the target and transforms it.
        transforms (callable, optional): Joint transform taking ``(image, target)``
            and returning both (used by ``__getitem__``).
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # train_id (0..18) -> RGB color; index 19 (appended black) is the
    # ignore/void slot used by decode_fn.
    _TRAIN_ID_TO_COLOR = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
    _TRAIN_ID_TO_COLOR.append([0, 0, 0])
    _TRAIN_ID_TO_COLOR = np.array(_TRAIN_ID_TO_COLOR)
    # raw label id -> train id. NOTE: 'license plate' has id -1, which indexes
    # the LAST entry of this array — that entry happens to be its own row
    # (train_id 255), so the wrap-around is correct by construction.
    _ID_TO_TRAIN_ID = np.array([c.train_id for c in classes])

    def __init__(self, root, split='train', mode='gtFine', target_type='semantic',
                 transform=None, target_transform=None, transforms=None):
        super(Cityscapes, self).__init__(root, transform=transform,
                                         target_transform=target_transform,
                                         transforms=transforms)
        self.root = os.path.expanduser(root)
        self.mode = mode
        self.target_type = target_type

        # Validate arguments before touching the filesystem so bad input fails
        # with a clear error instead of a confusing missing-path error.
        # gtCoarse ships 'train_extra' instead of 'test' (see class docstring).
        valid_splits = ('train', 'train_extra', 'val') if mode == 'gtCoarse' else ('train', 'test', 'val')
        if split not in valid_splits:
            raise ValueError('Invalid split for mode {!r}! Please use one of {}'.format(mode, list(valid_splits)))
        # Fail fast on an unknown target_type (raises ValueError); the suffix is
        # also needed below, and is the same for every file, so compute it once.
        suffix = self._get_target_suffix(self.mode, self.target_type)

        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split

        self.images = []
        self.targets = []

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

        # Sort directory listings so sample index -> file mapping is
        # deterministic across filesystems and runs.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)

            for file_name in sorted(os.listdir(img_dir)):
                self.images.append(os.path.join(img_dir, file_name))
                # '<city>_<seq>_<frame>_leftImg8bit.png' -> '<city>_<seq>_<frame>_<suffix>'
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], suffix)
                self.targets.append(os.path.join(target_dir, target_name))

    @classmethod
    def encode_target(cls, target):
        """Map raw Cityscapes label ids (0..33, -1) to train ids (0..18, 255 = ignore).

        Accepts a ``torch.Tensor`` (returns a tensor) or anything usable as a
        numpy integer index (returns a numpy array / scalar).
        """
        if isinstance(target, torch.Tensor):
            return torch.from_numpy(cls._ID_TO_TRAIN_ID[np.array(target)])
        else:
            return cls._ID_TO_TRAIN_ID[target]

    @classmethod
    def decode_fn(cls, target):
        """Map train ids to RGB colors for visualization.

        NOTE: mutates ``target`` in place — the ignore label 255 is rewritten
        to 19 (the appended black entry) before the color lookup. Pass a copy
        if the caller still needs the original labels.
        """
        target[target == 255] = 19
        return cls._TRAIN_ID_TO_COLOR[target]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where ``image`` is the RGB input and
            ``target`` is the annotation encoded to train ids.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        if self.transforms:
            # Joint transform so random ops (crop/flip) stay aligned between
            # image and target.
            image, target = self.transforms(image, target)
        target = self.encode_target(target)
        return image, target

    def __len__(self):
        return len(self.images)

    def _load_json(self, path):
        """Load and return the JSON object stored at ``path`` (polygon targets)."""
        with open(path, 'r') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode, target_type):
        """Return the annotation filename suffix for ``target_type``.

        Raises:
            ValueError: if ``target_type`` is not one of the known types
                (previously this fell through and returned None, silently
                producing bogus '..._None' target paths).
        """
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        elif target_type == 'polygon':
            return '{}_polygons.json'.format(mode)
        elif target_type == 'depth':
            return '{}_disparity.png'.format(mode)
        raise ValueError('Unknown target_type: {!r}'.format(target_type))
|