# Modified from https://github.com/VainF/nyuv2-python-toolkit
import os

import numpy as np
from PIL import Image
from scipy.io import loadmat
from torchvision.datasets import VisionDataset

from .utils import colormap

class NYUv2(VisionDataset):
    """NYUv2 dataset.

    See https://github.com/VainF/nyuv2-python-toolkit for more details.

    Args:
        root (string): Root directory path.
        split (string, optional): 'train' for the training set, 'test' for the test set. Default: 'train'.
        target_type (string, optional): Type of target to use: ``semantic``, ``depth`` or ``normal``. Default: 'semantic'.
        num_classes (int, optional): Number of semantic classes, must be 40 or 13. Default: 13.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        transforms (callable, optional): A function/transform that takes in both the image and its target and returns transformed versions of both.
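
    Example (a minimal sketch; assumes ``root`` points to a local copy prepared in the
    layout this class expects, i.e. ``splits.mat`` plus ``image/``, ``seg13/`` or ``seg40/``,
    ``depth/`` and ``normal/`` folders, each split into ``train``/``test``)::

        dataset = NYUv2(root='path/to/NYUv2', split='train',
                        target_type='semantic', num_classes=13)
        image, target = dataset[0]  # both are PIL Images when no transforms are given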
- """
- cmap = colormap()
    def __init__(self,
                 root,
                 split='train',
                 target_type='semantic',
                 num_classes=13,
                 transforms=None,
                 transform=None,
                 target_transform=None):
        super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
        assert split in ('train', 'test'), "split must be 'train' or 'test'"

        self.root = root
        self.split = split
        self.target_type = target_type
        self.num_classes = num_classes

        # 0-based indices of the official train/test split stored in splits.mat
        # (kept from the original toolkit; not used for the file lookups below).
        split_mat = loadmat(os.path.join(self.root, 'splits.mat'))
        idxs = split_mat[self.split + 'Ndxs'].reshape(-1) - 1

        images_dir = os.path.join(self.root, 'image', self.split)
        img_names = sorted(os.listdir(images_dir))
        self.images = [os.path.join(images_dir, name) for name in img_names]

        self._is_depth = False
        if self.target_type == 'semantic':
            semantic_dir = os.path.join(self.root, 'seg%d' % self.num_classes, self.split)
            self.labels = [os.path.join(semantic_dir, name) for name in img_names]
            self.targets = self.labels
        elif self.target_type == 'depth':
            depth_dir = os.path.join(self.root, 'depth', self.split)
            self.depths = [os.path.join(depth_dir, name) for name in img_names]
            self.targets = self.depths
            self._is_depth = True
        elif self.target_type == 'normal':
            normal_dir = os.path.join(self.root, 'normal', self.split)
            self.normals = [os.path.join(normal_dir, name) for name in img_names]
            self.targets = self.normals
        else:
            raise ValueError("target_type must be 'semantic', 'depth' or 'normal', got %r" % self.target_type)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        target = Image.open(self.targets[idx])
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask: np.ndarray):
        """Decode a semantic label mask into an RGB color image."""
        mask = mask.astype('uint8') + 1  # uint8 wrap-around maps the 255 ignore label to 0
        return cls.cmap[mask]
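

# ---------------------------------------------------------------------------
# Minimal usage sketch: builds the dataset with a toy joint transform and
# colorizes one semantic mask.  The root path './data/NYUv2' is a hypothetical
# placeholder for a locally prepared copy, and because this module uses a
# relative import it has to be run as a module inside its package
# (e.g. python -m <package>.nyuv2).  Since __getitem__ calls
# self.transforms(image, target), a joint transform has to accept and return
# the (image, target) pair together.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    def to_arrays(image, target):
        # Toy joint transform: image -> float32 HxWx3 in [0, 1], mask -> int64 HxW.
        return np.asarray(image, dtype='float32') / 255.0, np.asarray(target, dtype='int64')

    dataset = NYUv2(root='./data/NYUv2',  # hypothetical local path
                    split='train',
                    target_type='semantic',
                    num_classes=13,
                    transforms=to_arrays)
    image, target = dataset[0]
    color_mask = NYUv2.decode_fn(target.astype(np.uint8))  # map labels through the class colormap
    print(len(dataset), image.shape, target.shape, color_mask.shape)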