
image_to_xml.py

import os
import numpy as np
import cv2
import torch
import matplotlib.patches as patches
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from xml.etree import ElementTree as et
from torchvision import transforms as torchtrans

# defining the training image and annotation directories
train_image_dir = 'train/train/image'
train_xml_dir = 'train/train/xml'
# test_image_dir = 'test/test/image'
# test_xml_dir = 'test/test/xml'


class FruitImagesDataset(Dataset):
    def __init__(self, image_dir, xml_dir, width, height, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.xml_dir = xml_dir
        self.height = height
        self.width = width
        # sort the filenames so indices are consistent across runs;
        # only .jpg images and .xml annotations are picked up
        self.imgs = sorted(image for image in os.listdir(self.image_dir)
                           if image.endswith('.jpg'))
        self.xmls = sorted(xml for xml in os.listdir(self.xml_dir)
                           if xml.endswith('.xml'))
        # classes: index 0 is reserved for the background
        self.classes = ['background', 'apple', 'banana', 'orange']

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        image_path = os.path.join(self.image_dir, img_name)
        # read the image, convert it to RGB and resize it to the target size
        img = cv2.imread(image_path)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        img_res = cv2.resize(img_rgb, (self.width, self.height),
                             interpolation=cv2.INTER_AREA)
        # dividing by 255 scales pixel values to [0, 1]
        img_res /= 255.0
        # matching annotation file
        annot_filename = img_name[:-4] + '.xml'
        annot_file_path = os.path.join(self.xml_dir, annot_filename)
        boxes = []
        labels = []
        tree = et.parse(annot_file_path)
        root = tree.getroot()
        # cv2 stores images as height x width x channels
        wt = img.shape[1]
        ht = img.shape[0]
        # box coordinates in the xml files refer to the original image size,
        # so they are rescaled to the resized image
        for member in root.findall('object'):
            labels.append(self.classes.index(member.find('name').text))
            # bounding box
            xmin = int(member.find('bndbox').find('xmin').text)
            xmax = int(member.find('bndbox').find('xmax').text)
            ymin = int(member.find('bndbox').find('ymin').text)
            ymax = int(member.find('bndbox').find('ymax').text)
            xmin_corr = (xmin / wt) * self.width
            xmax_corr = (xmax / wt) * self.width
            ymin_corr = (ymin / ht) * self.height
            ymax_corr = (ymax / ht) * self.height
            boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])
        # convert the boxes into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # areas of the boxes
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # assume no instance is a crowd
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = area
        target["iscrowd"] = iscrowd
        # image_id
        image_id = torch.tensor([idx])
        target["image_id"] = image_id
        if self.transforms:
            sample = self.transforms(image=img_res,
                                     bboxes=target['boxes'],
                                     labels=labels)
            img_res = sample['image']
            target['boxes'] = torch.Tensor(sample['bboxes'])
        return img_res, target

    def __len__(self):
        return len(self.imgs)


# function to convert a torch tensor back to a PIL image
def torch_to_pil(img):
    return torchtrans.ToPILImage()(img).convert('RGB')


def plot_img_bbox(img, target):
    # plot the image and its bounding boxes
    fig, a = plt.subplots(1, 1)
    fig.set_size_inches(5, 5)
    a.imshow(img)
    for box in target['boxes']:
        x, y, width, height = box[0], box[1], box[2] - box[0], box[3] - box[1]
        rect = patches.Rectangle((x, y),
                                 width, height,
                                 linewidth=2,
                                 edgecolor='r',
                                 facecolor='none')
        # draw the bounding box on top of the image
        a.add_patch(rect)
    plt.show()


def get_transform(train):
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            # ToTensorV2 converts the image to a pytorch tensor without dividing by 255
            ToTensorV2(p=1.0)
        ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
    else:
        return A.Compose([
            ToTensorV2(p=1.0)
        ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})


dataset = FruitImagesDataset(train_image_dir, train_xml_dir, 480, 480,
                             transforms=get_transform(train=True))
print(len(dataset))
# getting the image and target for a test index; feel free to change the index
img, target = dataset[29]
print(img.shape, '\n', target)
plot_img_bbox(torch_to_pil(img), target)
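
The dataset yields one (image, target) pair at a time, and each image can contain a different number of boxes, so the default stacking collate of a DataLoader will not work for batching. Below is a minimal sketch of a suitable loader, following the torchvision detection convention of zipping the batch into tuples; the names collate_fn and data_loader and the batch size are illustrative, not part of the file above.

from torch.utils.data import DataLoader

def collate_fn(batch):
    # keep images and targets as tuples instead of stacking them,
    # since targets have a variable number of boxes per image
    return tuple(zip(*batch))

data_loader = DataLoader(dataset, batch_size=4, shuffle=True,
                         collate_fn=collate_fn)
images, targets = next(iter(data_loader))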

With the development of artificial intelligence and big data, every field has some demand for automation tools. During the current period of epidemic prevention and control, this project uses MindSpore to implement a YOLO model for object detection and semantic segmentation. It can perform mask-wearing detection and pedestrian social-distancing detection on both video and still images, enabling automated epidemic-control management of public places.
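
As a minimal sketch of what such a MindSpore implementation might start from, the block below builds the conv -> batch-norm -> leaky-ReLU unit that YOLO backbones are composed of. ConvBlock, its parameters, and the input shape are illustrative assumptions, not this project's actual model code.

import numpy as np
import mindspore as ms
from mindspore import nn

class ConvBlock(nn.Cell):
    # illustrative YOLO-style building block: conv -> batch norm -> leaky ReLU
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride,
                              pad_mode='same', has_bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(alpha=0.1)

    def construct(self, x):
        return self.act(self.bn(self.conv(x)))

block = ConvBlock(3, 32)
x = ms.Tensor(np.zeros((1, 3, 416, 416)), ms.float32)
print(block(x).shape)  # (1, 32, 416, 416)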