|
- # Modified from https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py
- import os
- import pandas as pd
- from torchvision.datasets.folder import default_loader
- from .utils import download_url
- from torch.utils.data import Dataset
- import shutil
-
-
class CUB200(Dataset):
    """CUB-200-2011 image classification dataset backed by files on disk.

    The extracted archive is expected at ``<root>/CUB_200_2011`` and must
    contain ``images.txt``, ``image_class_labels.txt``,
    ``train_test_split.txt`` and an ``images`` directory holding one
    sub-directory per category.

    Args:
        root: Directory that contains (or will receive) the archive and the
            extracted ``CUB_200_2011`` folder. ``~`` is expanded and the path
            is made absolute.
        split: ``'train'`` selects rows flagged ``is_training_img == 1``;
            any other value selects rows flagged ``0``.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each label.
        loader: Callable that takes an image path and returns the image.
        download: When true, fetch (if the archive is missing) and extract
            the archive into ``root`` before reading the metadata.
    """

    base_folder = 'images'
    # Name of the top-level directory expected under ``root``.
    dataset_folder = 'CUB_200_2011'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    # NOTE(review): this checksum is not referenced by any method of this
    # class; whether ``download_url`` could verify it is not visible here.
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, split='train', transform=None, target_transform=None, loader=default_loader, download=False):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.transform = transform
        self.target_transform = target_transform
        # Honour the caller-supplied loader (previously the argument was
        # ignored and ``default_loader`` was always used).
        self.loader = loader
        self.split = split

        if download:
            self.download()
        self._load_metadata()

        # Category names come from the sorted directory listing with the
        # first four characters dropped (presumably an "NNN." index prefix).
        categories = sorted(os.listdir(
            os.path.join(self.root, self.dataset_folder, self.base_folder)))
        self.object_categories = [c[4:] for c in categories]
        print('CUB200, Split: %s, Size: %d' % (self.split, len(self)))

    def _load_metadata(self):
        """Read the three metadata files and keep only the rows of this split.

        The result is stored in ``self.data``, a DataFrame with the columns
        ``img_id``, ``filepath``, ``target`` and ``is_training_img``.
        """
        meta_dir = os.path.join(self.root, self.dataset_folder)
        images = pd.read_csv(os.path.join(meta_dir, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(meta_dir, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'],
                                         encoding='latin-1', engine='python')
        train_test_split = pd.read_csv(os.path.join(meta_dir, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'],
                                       encoding='latin-1', engine='python')

        # Join the three tables on the shared image id.
        data = images.merge(image_class_labels, on='img_id')
        data = data.merge(train_test_split, on='img_id')

        # Anything other than 'train' falls through to the non-training rows.
        wanted_flag = 1 if self.split == 'train' else 0
        self.data = data[data.is_training_img == wanted_flag]

    def download(self):
        """Fetch the archive into ``root`` if it is absent, then extract it.

        Extraction runs on every call, even when the archive was already
        present on disk.
        """
        import tarfile

        os.makedirs(self.root, exist_ok=True)
        archive_path = os.path.join(self.root, self.filename)
        if not os.path.isfile(archive_path):
            download_url(self.url, self.root, self.filename)
        print("Extracting %s..." % self.filename)
        # NOTE(review): ``extractall`` trusts the member paths inside the
        # archive; the archive comes from a plain-HTTP URL and is not
        # checksum-verified here. Consider an extraction filter where the
        # supported Python versions allow it.
        with tarfile.open(archive_path, "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is produced by ``self.loader`` and the label is the stored
        ``target`` minus one (presumably converting 1-based class ids to
        0-based). Both pass through their transforms when those are set.
        """
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.dataset_folder,
                            self.base_folder, sample.filepath)
        lbl = sample.target - 1
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl
|