|
- import os
- import numpy as np
- from PIL import Image
- from scipy.io import loadmat
-
- from oneflow.utils import data
- import random
-
-
class StanfordDogs(data.Dataset):
    """Stanford Dogs image-classification dataset.

    File names and labels are read from ``<root>/lists/<split>_list.mat``
    and images are loaded from ``<root>/Images/<file name>``.

    Args:
        root: Dataset directory; ``~`` is expanded and the path made absolute.
        split: Prefix of the ``*_list.mat`` file to load (e.g. ``'train'``).
        s: Per-class sampling fraction understood by ``sample_by_class``.
            It is accepted for backward compatibility but is currently not
            applied by the constructor.
        download: When true, call ``download()`` before reading the lists.
        transform: Optional callable applied to each loaded RGB image.
        target_transform: Optional callable applied to each label.
    """

    urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar",
            "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar",
            "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"}

    def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()

        mat_file = loadmat(self._list_file_path())
        file_entries = mat_file['file_list']
        label_entries = mat_file['labels']
        size = len(file_entries)
        self.files = [str(file_entries[i][0][0]) for i in range(size)]

        # Per-class subsampling (``sample_by_class``) is deliberately not
        # applied here; ``s`` is kept in the signature so callers keep working.

        # Labels are shifted down by one -- presumably from 1-based ids in the
        # .mat file to 0-based class indices (TODO confirm against the data).
        self.labels = np.array(
            [label_entries[i][0] - 1 for i in range(size)])

        categories = sorted(os.listdir(os.path.join(self.root, 'Images')))
        # Drop the first 10 characters of every directory name -- presumably
        # an id prefix such as "n02085620-" (TODO confirm).
        self.object_categories = [c[10:] for c in categories]
        print('Stanford Dogs, Split: %s, Size: %d' %
              (self.split, len(self)))

    def _list_file_path(self):
        """Return the path of the ``<split>_list.mat`` file to load.

        ``<root>/lists/`` is preferred.  ``<root>/`` is used only when the
        file is missing from the former and present in the latter, which
        covers archives that unpack without a ``lists`` folder.
        """
        file_name = self.split + '_list.mat'
        nested = os.path.join(self.root, 'lists', file_name)
        flat = os.path.join(self.root, file_name)
        if not os.path.isfile(nested) and os.path.isfile(flat):
            return flat
        # When neither exists, return the original location so the resulting
        # error from ``loadmat`` is unchanged.
        return nested

    def download(self):
        """Download and unpack every archive listed in ``urls`` into ``root``.

        An archive that already exists in ``root`` is not downloaded again,
        but it is always re-extracted.

        NOTE(review): the on-disk layout produced by the archives has not been
        verified here; confirm that ``Images/`` ends up directly under
        ``root`` after extraction.
        """
        # Local imports: only this method needs them.
        import tarfile
        import urllib.request

        os.makedirs(self.root, exist_ok=True)
        # Use the safe "data" extraction filter when this Python provides it,
        # because the archives are fetched over plain HTTP.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
        for archive_name, url in self.urls.items():
            archive_path = os.path.join(self.root, archive_name)
            if not os.path.isfile(archive_path):
                print('Downloading %s...' % url)
                urllib.request.urlretrieve(url, archive_path)
            print('Extracting %s...' % archive_name)
            with tarfile.open(archive_path, 'r') as tar:
                tar.extractall(path=self.root, **extract_kwargs)

    def __len__(self):
        """Return the number of files in the loaded split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is opened from ``<root>/Images`` and converted to RGB;
        ``transform`` / ``target_transform`` are applied when set.
        """
        img = Image.open(os.path.join(self.root, 'Images',
                                      self.files[idx])).convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl

    def sample_by_class(self, s):
        """Return a random fraction ``s`` of ``self.files`` from every class.

        Files are grouped by the part of their path before the first ``/``.
        Each group is shuffled with the global ``random`` module and its
        first ``int(len(group) * s)`` entries are kept.

        Only file names are returned.  ``self.labels`` is not touched, so
        assigning the result to ``self.files`` without rebuilding the labels
        would leave the two out of step.
        """
        files_by_class = {}
        for path in self.files:
            files_by_class.setdefault(path.split('/')[0], []).append(path)

        sampled = []
        # Dicts preserve insertion order, so classes are visited in the order
        # they first appear in ``self.files``.
        for members in files_by_class.values():
            random.shuffle(members)
            sampled += members[:int(len(members) * s)]
        return sampled
|