import os

import albumentations as A
import cv2
import matplotlib.patches as patches
import numpy as np
import torch
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torchvision import transforms as torchtrans
from xml.etree import ElementTree as et

# defining the files directory and testing directory
train_image_dir = 'train/train/image'
train_xml_dir = 'train/train/xml'
# test_image_dir = 'test/test/image'
# test_xml_dir = 'test/test/xml'


class FruitImagesDataset(Dataset):
    def __init__(self, image_dir, xml_dir, width, height, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.xml_dir = xml_dir
        self.height = height
        self.width = width

        # sorting the filenames for consistency;
        # only .jpg images and .xml annotations are kept
        self.imgs = sorted(image for image in os.listdir(self.image_dir)
                           if image[-4:] == '.jpg')
        self.xmls = sorted(xml for xml in os.listdir(self.xml_dir)
                           if xml[-4:] == '.xml')

        # classes: index 0 is reserved for background,
        # so a placeholder entry occupies that slot
        self.classes = ['_background_', 'apple', 'banana', 'orange']

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        image_path = os.path.join(self.image_dir, img_name)

        # reading the image and converting it to the correct size and color
        img = cv2.imread(image_path)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        img_res = cv2.resize(img_rgb, (self.width, self.height),
                             interpolation=cv2.INTER_AREA)
        # dividing by 255 scales pixel values to [0, 1]
        img_res /= 255.0

        # the annotation file shares the image's base name
        annot_filename = img_name[:-4] + '.xml'
        annot_file_path = os.path.join(self.xml_dir, annot_filename)

        boxes = []
        labels = []
        tree = et.parse(annot_file_path)
        root = tree.getroot()

        # cv2 images store size as height x width
        wt = img.shape[1]
        ht = img.shape[0]

        # box coordinates are extracted from the xml files
        # and rescaled to the resized image
        for member in root.findall('object'):
            labels.append(self.classes.index(member.find('name').text))

            # bounding box
            xmin = int(member.find('bndbox').find('xmin').text)
            xmax = int(member.find('bndbox').find('xmax').text)
            ymin = int(member.find('bndbox').find('ymin').text)
            ymax = int(member.find('bndbox').find('ymax').text)

            xmin_corr = (xmin / wt) * self.width
            xmax_corr = (xmax / wt) * self.width
            ymin_corr = (ymin / ht) * self.height
            ymax_corr = (ymax / ht) * self.height

            boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])

        # convert boxes into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)

        # areas of the boxes: (ymax - ymin) * (xmax - xmin)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # suppose all instances are not crowd
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = area
        target["iscrowd"] = iscrowd
        # image_id
        image_id = torch.tensor([idx])
        target["image_id"] = image_id

        if self.transforms:
            sample = self.transforms(image=img_res,
                                     bboxes=target['boxes'],
                                     labels=labels)
            img_res = sample['image']
            target['boxes'] = torch.Tensor(sample['bboxes'])

        return img_res, target

    def __len__(self):
        return len(self.imgs)


# function to convert a torch tensor back to a PIL image
def torch_to_pil(img):
    return torchtrans.ToPILImage()(img).convert('RGB')


def plot_img_bbox(img, target):
    # plot the image and bboxes
    fig, a = plt.subplots(1, 1)
    fig.set_size_inches(5, 5)
    a.imshow(img)
    for box in target['boxes']:
        x, y, width, height = box[0], box[1], box[2] - box[0], box[3] - box[1]
        rect = patches.Rectangle((x, y), width, height,
                                 linewidth=2,
                                 edgecolor='r',
                                 facecolor='none')
        # draw the bounding box on top of the image
        a.add_patch(rect)
    plt.show()


def get_transform(train):
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            # ToTensorV2 converts the image to a pytorch tensor
            # without dividing by 255
            ToTensorV2(p=1.0)
        ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
    else:
        return A.Compose([
            ToTensorV2(p=1.0)
        ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})


dataset = FruitImagesDataset(train_image_dir, train_xml_dir, 480, 480,
                             transforms=get_transform(train=True))
print(len(dataset))

# getting the image and target for a test index. Feel free to change the index.
img, target = dataset[29]
print(img.shape, '\n', target)
plot_img_bbox(torch_to_pil(img), target)
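
# A minimal sketch of how this dataset would typically be batched for a
# torchvision detection model: each target dict holds a different number of
# boxes, so the default collate cannot stack targets into a single tensor and
# a tuple-of-lists collate function is used instead. The batch size and worker
# count below are illustrative assumptions, not values fixed by the code above.
from torch.utils.data import DataLoader


def detection_collate_fn(batch):
    # keep images and targets as parallel tuples instead of stacking them
    return tuple(zip(*batch))


data_loader = DataLoader(dataset, batch_size=4, shuffle=True,
                         num_workers=2, collate_fn=detection_collate_fn)

# one batch: a tuple of image tensors and a tuple of target dicts
images, targets = next(iter(data_loader))
print(len(images), targets[0]['boxes'].shape)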