You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

voc.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # Modified from https://github.com/pytorch/vision
  2. import os
  3. import sys
  4. import tarfile
  5. import collections
  6. import torch.utils.data as data
  7. import shutil
  8. import numpy as np
  9. from .utils import colormap
  10. from torchvision.datasets import VisionDataset
  11. import torch
  12. from PIL import Image
  13. from torchvision.datasets.utils import download_url, check_integrity
  14. DATASET_YEAR_DICT = {
  15. '2012aug': {
  16. 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
  17. 'filename': 'VOCtrainval_11-May-2012.tar',
  18. 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
  19. 'base_dir': 'VOCdevkit/VOC2012'
  20. },
  21. '2012': {
  22. 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
  23. 'filename': 'VOCtrainval_11-May-2012.tar',
  24. 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
  25. 'base_dir': 'VOCdevkit/VOC2012'
  26. },
  27. '2011': {
  28. 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
  29. 'filename': 'VOCtrainval_25-May-2011.tar',
  30. 'md5': '6c3384ef61512963050cb5d687e5bf1e',
  31. 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
  32. },
  33. '2010': {
  34. 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
  35. 'filename': 'VOCtrainval_03-May-2010.tar',
  36. 'md5': 'da459979d0c395079b5c75ee67908abb',
  37. 'base_dir': 'VOCdevkit/VOC2010'
  38. },
  39. '2009': {
  40. 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
  41. 'filename': 'VOCtrainval_11-May-2009.tar',
  42. 'md5': '59065e4b188729180974ef6572f6a212',
  43. 'base_dir': 'VOCdevkit/VOC2009'
  44. },
  45. '2008': {
  46. 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
  47. 'filename': 'VOCtrainval_11-May-2012.tar',
  48. 'md5': '2629fa636546599198acfcfbfcf1904a',
  49. 'base_dir': 'VOCdevkit/VOC2008'
  50. },
  51. '2007': {
  52. 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
  53. 'filename': 'VOCtrainval_06-Nov-2007.tar',
  54. 'md5': 'c52e279531787c972589f7e41ab4ae64',
  55. 'base_dir': 'VOCdevkit/VOC2007'
  56. }
  57. }
  58. class VOCSegmentation(VisionDataset):
  59. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
  60. Args:
  61. root (string): Root directory of the VOC Dataset.
  62. year (string, optional): The dataset year, supports years 2007 to 2012.
  63. image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
  64. download (bool, optional): If true, downloads the dataset from the internet and
  65. puts it in root directory. If dataset is already downloaded, it is not
  66. downloaded again.
  67. transform (callable, optional): A function/transform that takes in an PIL image
  68. and returns a transformed version. E.g, ``transforms.RandomCrop``
  69. """
  70. cmap = colormap()
  71. def __init__(self,
  72. root,
  73. year='2012',
  74. image_set='train',
  75. download=False,
  76. transform=None,
  77. target_transform=None,
  78. transforms=None,
  79. ):
  80. super( VOCSegmentation, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )
  81. is_aug=False
  82. if year=='2012aug':
  83. is_aug = True
  84. year = '2012'
  85. self.root = os.path.expanduser(root)
  86. self.year = year
  87. self.url = DATASET_YEAR_DICT[year]['url']
  88. self.filename = DATASET_YEAR_DICT[year]['filename']
  89. self.md5 = DATASET_YEAR_DICT[year]['md5']
  90. self.image_set = image_set
  91. base_dir = DATASET_YEAR_DICT[year]['base_dir']
  92. voc_root = os.path.join(self.root, base_dir)
  93. image_dir = os.path.join(voc_root, 'JPEGImages')
  94. if download:
  95. download_extract(self.url, self.root, self.filename, self.md5)
  96. if not os.path.isdir(voc_root):
  97. raise RuntimeError('Dataset not found or corrupted.' +
  98. ' You can use download=True to download it')
  99. if is_aug and image_set=='train':
  100. mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
  101. assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
  102. split_f = os.path.join( self.root, 'train_aug.txt')
  103. else:
  104. mask_dir = os.path.join(voc_root, 'SegmentationClass')
  105. splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
  106. split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
  107. if not os.path.exists(split_f):
  108. raise ValueError(
  109. 'Wrong image_set entered! Please use image_set="train" '
  110. 'or image_set="trainval" or image_set="val"')
  111. with open(os.path.join(split_f), "r") as f:
  112. file_names = [x.strip() for x in f.readlines()]
  113. self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
  114. self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
  115. assert (len(self.images) == len(self.masks))
  116. def __getitem__(self, index):
  117. """
  118. Args:
  119. index (int): Index
  120. Returns:
  121. tuple: (image, target) where target is the image segmentation.
  122. """
  123. img = Image.open(self.images[index]).convert('RGB')
  124. target = Image.open(self.masks[index])
  125. if self.transforms is not None:
  126. img, target = self.transforms(img, target)
  127. return img, target.squeeze(0)
  128. def __len__(self):
  129. return len(self.images)
  130. @classmethod
  131. def decode_fn(cls, mask):
  132. """decode semantic mask to RGB image"""
  133. return cls.cmap[mask]
  134. def download_extract(url, root, filename, md5):
  135. download_url(url, root, filename, md5)
  136. with tarfile.open(os.path.join(root, filename), "r") as tar:
  137. tar.extractall(path=root)
  138. CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
  139. 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
  140. class VOCClassification(data.Dataset):
  141. def __init__(self,
  142. root,
  143. year='2010',
  144. split='train',
  145. download=False,
  146. transforms=None,
  147. target_transforms=None):
  148. voc_root = os.path.join(root, 'VOC{}'.format(year))
  149. if not os.path.isdir(voc_root):
  150. raise RuntimeError('Dataset not found or corrupted.' +
  151. ' You can use download=True to download it')
  152. self.transforms = transforms
  153. self.target_transforms = target_transforms
  154. image_dir = os.path.join(voc_root, 'JPEGImages')
  155. label_dir = os.path.join(voc_root, 'ImageSets/Main')
  156. self.labels_list = []
  157. fname = os.path.join(label_dir, '{}.txt'.format(split))
  158. with open(fname) as f:
  159. self.images = [os.path.join(image_dir, line.split()[0]+'.jpg') for line in f]
  160. for clas in CLASSES:
  161. labels = []
  162. with open(os.path.join(label_dir, '{}_{}.txt'.format(clas, split))) as f:
  163. labels = [int(line.split()[1]) for line in f]
  164. self.labels_list.append(labels)
  165. assert (len(self.images) == len(self.labels_list[0]))
  166. def __getitem__(self, index):
  167. """
  168. Args:
  169. index (int): Index
  170. Returns:
  171. tuple: (image, target) where target is the image segmentation.
  172. """
  173. img = Image.open(self.images[index]).convert('RGB')
  174. labels = [labels[index] for labels in self.labels_list]
  175. if self.transforms is not None:
  176. img = self.transforms(img)
  177. if self.target_transforms is not None:
  178. labels = self.target_transforms(labels)
  179. return img, labels
  180. def __len__(self):
  181. return len(self.images)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已为产学研等各领域近千家单位及个人提供AI应用赋能。