import os
import random

import numpy as np
import oneflow.utils.data as data
from flowvision.datasets.folder import default_loader


def make_dataset(dir, image_ids, targets):
    """Pair every image id with its target and resolve the image path.

    Args:
        dir (string): Dataset root; ``~`` is expanded. The name shadows the
            builtin but is kept so keyword callers keep working.
        image_ids (sequence): Image identifiers (file names without ``.jpg``).
        targets (sequence): One target per image id, in the same order.

    Returns:
        list: ``(image_path, target)`` tuples, in input order.
    """
    assert len(image_ids) == len(targets)
    image_dir = os.path.join(
        os.path.expanduser(dir), 'fgvc-aircraft-2013b', 'data', 'images')
    return [
        (os.path.join(image_dir, '%s.jpg' % image_id), target)
        for image_id, target in zip(image_ids, targets)
    ]


def find_classes(classes_file):
    """Read an annotation file and index its class names.

    Each non-blank line is ``<image_id> <class name>``; the class name may
    itself contain spaces.

    Args:
        classes_file (string): Path to the annotation file.

    Returns:
        tuple: ``(image_ids, targets, classes, class_to_idx)`` where
        ``targets`` holds integer class indices, ``classes`` is the sorted
        array of unique class names and ``class_to_idx`` maps name to index.
    """
    image_ids = []
    class_names = []
    with open(classes_file, 'r') as f:
        for line in f:
            # Drop the line terminator so it does not end up inside the class
            # name; otherwise a final line without a newline would form a
            # separate class from identical names on earlier lines.
            entry = line.rstrip('\r\n')
            if not entry.strip():
                continue  # blank lines carry no sample
            image_id, _, class_name = entry.partition(' ')
            image_ids.append(image_id)
            class_names.append(class_name)

    # np.unique returns the names sorted, which fixes the index assignment.
    classes = np.unique(class_names)
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    targets = [class_to_idx[name] for name in class_names]

    return (image_ids, targets, classes, class_to_idx)


class FGVCAircraft(data.Dataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    Args:
        root (string): Root directory path to dataset. The extracted
            ``fgvc-aircraft-2013b`` folder is expected directly below it.
        class_type (string, optional): The level of FGVC-Aircraft fine-grain
            classification to label data with (i.e., ``variant``, ``family``,
            or ``manufacturer``).
        split (string, optional): One of ``train``, ``val``, ``trainval`` or
            ``test``.
        s (float, optional): Per-class sampling fraction. Currently unused by
            ``__init__``; see :meth:`sample_by_class`.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
            E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        download (bool, optional): Accepted for interface compatibility; this
            class does not download anything, the data must already be under
            ``root``.

    Raises:
        ValueError: If ``split`` or ``class_type`` is not a supported value.
    """
    # Archive location; not used by any code in this class.
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')

    def __init__(self, root, class_type='variant', split='train', s=0.5, transform=None,
                 target_transform=None, loader=default_loader, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        self.root = root
        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)

        # Kept on the instance so sample_by_class() has data to work on.
        # Subsampling with `s` is not applied here; the full split is used.
        self.image_ids = image_ids
        self.targets = targets

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.samples = make_dataset(self.root, image_ids, targets)
        self.classes = classes
        self.class_to_idx = class_to_idx

        with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f:
            self.object_categories = [line.strip('\n') for line in f]

        print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__()))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __repr__(self):
        """Return a multi-line summary of the dataset and its transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        # The attributes are `transform` / `target_transform` (singular), as
        # assigned in __init__.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        """Return True if the image folder and the annotation file exist."""
        # Same layout as make_dataset() and self.classes_file.
        image_dir = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images')
        return os.path.exists(image_dir) and os.path.exists(self.classes_file)

    def sample_by_class(self, s):
        """Randomly keep a fraction of the images of every class.

        Uses the module-level ``random`` state, so results depend on its seed.

        Args:
            s (float): Fraction of each class to keep; the per-class count is
                ``int(n * s)``.

        Returns:
            tuple: ``(images, labels)`` lists of equal length, grouped by
            class in order of first appearance.
        """
        images_by_class = {}
        for class_name, image in zip(self.targets, self.image_ids):
            images_by_class.setdefault(class_name, []).append(image)

        images, labels = [], []
        for class_name, class_images in images_by_class.items():
            # Shuffles the per-class copy only; self.image_ids is untouched.
            random.shuffle(class_images)
            selected = class_images[:int(len(class_images) * s)]
            images += selected
            labels += [class_name] * len(selected)
        return images, labels