- # Modified from https://github.com/pytorch/vision
- import os
- import sys
- import tarfile
- import collections
- import torch.utils.data as data
- import shutil
- import numpy as np
- from .utils import colormap
- from torchvision.datasets import VisionDataset
- import torch
- from PIL import Image
- from torchvision.datasets.utils import download_url, check_integrity
-
- DATASET_YEAR_DICT = {
- '2012aug': {
- 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
- 'filename': 'VOCtrainval_11-May-2012.tar',
- 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
- 'base_dir': 'VOCdevkit/VOC2012'
- },
- '2012': {
- 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
- 'filename': 'VOCtrainval_11-May-2012.tar',
- 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
- 'base_dir': 'VOCdevkit/VOC2012'
- },
- '2011': {
- 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
- 'filename': 'VOCtrainval_25-May-2011.tar',
- 'md5': '6c3384ef61512963050cb5d687e5bf1e',
- 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
- },
- '2010': {
- 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
- 'filename': 'VOCtrainval_03-May-2010.tar',
- 'md5': 'da459979d0c395079b5c75ee67908abb',
- 'base_dir': 'VOCdevkit/VOC2010'
- },
- '2009': {
- 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
- 'filename': 'VOCtrainval_11-May-2009.tar',
- 'md5': '59065e4b188729180974ef6572f6a212',
- 'base_dir': 'VOCdevkit/VOC2009'
- },
- '2008': {
- 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
- 'filename': 'VOCtrainval_14-Jul-2008.tar',
- 'md5': '2629fa636546599198acfcfbfcf1904a',
- 'base_dir': 'VOCdevkit/VOC2008'
- },
- '2007': {
- 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
- 'filename': 'VOCtrainval_06-Nov-2007.tar',
- 'md5': 'c52e279531787c972589f7e41ab4ae64',
- 'base_dir': 'VOCdevkit/VOC2007'
- }
- }
-
- class VOCSegmentation(VisionDataset):
- """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
- Args:
- root (string): Root directory of the VOC Dataset.
- year (string, optional): The dataset year, supports years ``2007`` to ``2012``
- as well as ``2012aug`` (2012 images with the augmented SegmentationClassAug labels).
- image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
- download (bool, optional): If true, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
- transform (callable, optional): A function/transform that takes in a PIL image
- and returns a transformed version, e.g. ``transforms.RandomCrop``
- transforms (callable, optional): A function/transform that takes in the image and
- its segmentation mask together and returns transformed versions of both.
- """
- cmap = colormap()
- def __init__(self,
- root,
- year='2012',
- image_set='train',
- download=False,
- transform=None,
- target_transform=None,
- transforms=None,
- ):
- super(VOCSegmentation, self).__init__(root, transform=transform, target_transform=target_transform, transforms=transforms)
-
- is_aug = False
- if year == '2012aug':
- is_aug = True
- year = '2012'
-
- self.root = os.path.expanduser(root)
- self.year = year
- self.url = DATASET_YEAR_DICT[year]['url']
- self.filename = DATASET_YEAR_DICT[year]['filename']
- self.md5 = DATASET_YEAR_DICT[year]['md5']
-
- self.image_set = image_set
- base_dir = DATASET_YEAR_DICT[year]['base_dir']
- voc_root = os.path.join(self.root, base_dir)
- image_dir = os.path.join(voc_root, 'JPEGImages')
-
- if download:
- download_extract(self.url, self.root, self.filename, self.md5)
-
- if not os.path.isdir(voc_root):
- raise RuntimeError('Dataset not found or corrupted.' +
- ' You can use download=True to download it')
-
- if is_aug and image_set == 'train':
- mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
- assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
- split_f = os.path.join(self.root, 'train_aug.txt')
- else:
- mask_dir = os.path.join(voc_root, 'SegmentationClass')
- splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
- split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
-
- if not os.path.exists(split_f):
- raise ValueError(
- 'Wrong image_set entered! Please use image_set="train" '
- 'or image_set="trainval" or image_set="val"')
-
- with open(split_f, "r") as f:
- file_names = [x.strip() for x in f.readlines()]
-
- self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
- self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
- assert (len(self.images) == len(self.masks))
-
- def __getitem__(self, index):
- """
- Args:
- index (int): Index
- Returns:
- tuple: (image, target) where target is the segmentation mask.
- """
- img = Image.open(self.images[index]).convert('RGB')
- target = Image.open(self.masks[index])
- if self.transforms is not None:
- img, target = self.transforms(img, target)
- # squeeze a possible leading channel dim from tensor masks so the target is HxW
- if isinstance(target, torch.Tensor):
- target = target.squeeze(0)
- return img, target
-
- def __len__(self):
- return len(self.images)
-
- @classmethod
- def decode_fn(cls, mask):
- """decode semantic mask to RGB image"""
- return cls.cmap[mask]
-
- def download_extract(url, root, filename, md5):
- download_url(url, root, filename, md5)
- with tarfile.open(os.path.join(root, filename), "r") as tar:
- tar.extractall(path=root)
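-
- # Illustrative construction sketch (not part of the original code): VOCSegmentation
- # expects a *joint* transform that takes and returns the (image, mask) pair together.
- # The transform and the root path below are hypothetical placeholders; substitute
- # your own pipeline and data layout:
- #
- # import torchvision.transforms.functional as TF
- #
- # def joint_transform(img, mask):
- #     img = TF.to_tensor(img)
- #     mask = torch.from_numpy(np.array(mask, dtype='int64'))
- #     return img, mask
- #
- # train_set = VOCSegmentation(root='./datasets/data', year='2012aug',
- #                             image_set='train', transforms=joint_transform)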
-
- CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
- 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
-
- class VOCClassification(data.Dataset):
- def __init__(self,
- root,
- year='2010',
- split='train',
- download=False,
- transforms=None,
- target_transforms=None):
-
- voc_root = os.path.join(root, 'VOC{}'.format(year))
- if not os.path.isdir(voc_root):
- # automatic download is not implemented for this class; the `download`
- # argument is only accepted for interface symmetry with VOCSegmentation
- raise RuntimeError('Dataset not found or corrupted: {}.'.format(voc_root) +
- ' Please download and extract it manually.')
-
- self.transforms = transforms
- self.target_transforms = target_transforms
- image_dir = os.path.join(voc_root, 'JPEGImages')
- label_dir = os.path.join(voc_root, 'ImageSets/Main')
- self.labels_list = []
-
- fname = os.path.join(label_dir, '{}.txt'.format(split))
- with open(fname) as f:
- self.images = [os.path.join(image_dir, line.split()[0]+'.jpg') for line in f]
-
- for class_name in CLASSES:
- with open(os.path.join(label_dir, '{}_{}.txt'.format(class_name, split))) as f:
- labels = [int(line.split()[1]) for line in f]
- self.labels_list.append(labels)
-
- assert (len(self.images) == len(self.labels_list[0]))
-
-
- def __getitem__(self, index):
- """
- Args:
- index (int): Index
- Returns:
- tuple: (image, labels) where labels is the list of per-class labels read
- from the ImageSets/Main split files (one value per entry in CLASSES).
- """
- img = Image.open(self.images[index]).convert('RGB')
- labels = [class_labels[index] for class_labels in self.labels_list]
-
- if self.transforms is not None:
- img = self.transforms(img)
-
- if self.target_transforms is not None:
- labels = self.target_transforms(labels)
-
- return img, labels
-
- def __len__(self):
- return len(self.images)
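-
- # Minimal smoke test, an illustrative sketch only: it assumes VOC2012 is already
- # extracted under the hypothetical root ./datasets/data and that colormap() from
- # .utils returns a 256x3 uint8 array. Because of the relative import above, run it
- # as a module (e.g. `python -m datasets.voc` if this package is named `datasets`).
- if __name__ == '__main__':
- _root = './datasets/data'
- if os.path.isdir(os.path.join(_root, 'VOCdevkit/VOC2012')):
- seg = VOCSegmentation(root=_root, year='2012', image_set='val')
- img, mask = seg[0]  # PIL image and PIL label mask (no transforms given)
- print('val samples:', len(seg), 'first image size:', img.size)
- rgb = VOCSegmentation.decode_fn(np.array(mask))  # color-coded (H, W, 3) array
- print('decoded mask shape:', rgb.shape)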