You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import os
  2. import time
  3. from tqdm import tqdm
  4. import numpy as np
  5. import torch
  6. from torchvision.transforms import ToPILImage
  7. from PIL import Image
  8. import cv2
  9. import torch.backends.cudnn as cudnn
  10. from sedna.common.log import LOGGER
  11. from dataloaders import make_data_loader
  12. from dataloaders.utils import Colorize
  13. from utils.metrics import Evaluator
  14. from models.rfnet import RFNet
  15. from models.resnet.resnet_single_scale_single_attention import *
  class Validator(object):
      """Run an RFNet model over a test loader and optionally save the results.

      ``__init__`` builds the test data loader, the evaluator and the network,
      and loads weights from ``args.weight_path`` when that file exists.
      ``validate`` runs inference and returns the per-batch predictions.
      """

      def __init__(self, args, data=None):
          """Set up data loading, the network and (optionally) its weights.

          Args:
              args: Options object. Attributes read here: ``num_class``,
                  ``workers``, ``cuda``, ``gpu_ids`` (only when ``cuda`` is
                  truthy) and ``weight_path``. ``validate`` additionally reads
                  ``depth``, ``merge``, ``merge_label_save_path``,
                  ``save_predicted_image``, ``color_label_save_path`` and
                  ``label_save_path``.
              data: Forwarded to ``make_data_loader`` as ``test_data``.
          """
          self.args = args
          self.num_class = args.num_class
          self.logger = LOGGER

          # Define Dataloader
          kwargs = {'num_workers': args.workers, 'pin_memory': False}
          _, _, self.test_loader = make_data_loader(
              args, test_data=data, **kwargs)

          # Define evaluator
          self.evaluator = Evaluator(self.num_class)

          # Define network
          self.resnet = resnet18(pretrained=False, efficient=False, use_bn=True)
          self.model = RFNet(
              self.resnet, num_classes=self.num_class, use_bn=True)

          if args.cuda:
              self.model = torch.nn.DataParallel(
                  self.model, device_ids=self.args.gpu_ids)
              # One move to the first listed GPU; the previous cuda() + to()
              # pair ended in the same placement.
              self.model = self.model.to(f'cuda:{self.args.gpu_ids[0]}')
              cudnn.benchmark = True  # accelerate speed

          # Load model weights
          weight_path = self.args.weight_path
          if weight_path is not None and os.path.exists(weight_path):
              # Without CUDA, remap stored tensors to the CPU so a checkpoint
              # written on a GPU machine can still be deserialized.
              map_location = None if args.cuda else 'cpu'
              self.new_state_dict = torch.load(
                  weight_path, map_location=map_location)
              self.model = load_my_state_dict(
                  self.model, self.new_state_dict['state_dict'])
              self.logger.info(
                  'Model loaded successfully from {}.'.format(weight_path))
          elif weight_path is not None:
              # Previously this case was silent and the model kept its
              # freshly initialised weights without any hint.
              self.logger.warning(
                  'Weight file {} not found; model weights were not '
                  'loaded.'.format(weight_path))

      def validate(self):
          """Run inference over ``self.test_loader``.

          The loader is expected to yield ``(sample, image_name)`` pairs where
          ``sample`` is a mapping with an ``'image'`` entry (and ``'depth'``
          when ``args.depth`` is truthy). When ``args.merge`` is truthy the
          blended visualisation of every image is written to disk; see
          ``_save_batch``.

          Returns:
              list: One NumPy array per batch holding the argmax of the model
              output over axis 1 (presumably the class axis -- TODO confirm
              against the model's output layout).
          """
          self.model.eval()
          self.evaluator.reset()
          tbar = tqdm(self.test_loader, desc='\r')
          # The constructor argument is loop-invariant, so build it once.
          colorize = Colorize(n=self.args.num_class)
          predictions = []

          for sample, image_name in tbar:
              image = sample['image']
              depth = sample['depth'] if self.args.depth else None

              if self.args.cuda:
                  image = image.cuda()
                  if depth is not None:
                      depth = depth.cuda()

              with torch.no_grad():
                  if self.args.depth:
                      output = self.model(image, depth)
                  else:
                      output = self.model(image)

              if self.args.cuda:
                  torch.cuda.synchronize()

              pred = np.argmax(output.data.cpu().numpy(), axis=1)
              predictions.append(pred)

              # Nothing is written unless merging is enabled, so skip the
              # colourisation work entirely in that case.
              if self.args.merge:
                  self._save_batch(image, output, image_name, colorize)

          return predictions

      def _save_batch(self, image, output, image_name, colorize):
          """Write the merged (and optionally raw) prediction images of a batch.

          NOTE: as in the original flow, ``args.save_predicted_image`` only
          takes effect when ``args.merge`` is enabled, because this helper is
          only reached in that case.

          Args:
              image: Batch of input images; ``image[i]`` is passed to
                  ``image_merge``.
              output: Model output for the batch.
              image_name: Sequence of source names for the batch entries.
              colorize: ``Colorize`` instance used to colour the label maps.
          """
          pre_labels = torch.max(output, 1)[1].detach().cpu().byte()
          pre_colors = colorize(pre_labels)

          for i in range(pre_colors.shape[0]):
              # Use the name of the i-th image. Indexing with 0 for every
              # entry made all images of a batch overwrite the same file.
              raw_name = image_name[i] if i < len(image_name) else None
              if raw_name:
                  img_name = os.path.basename(raw_name)
              else:
                  img_name = f"test_{time.time()}.png"

              merge_label_name = os.path.join(
                  self.args.merge_label_save_path, img_name)
              os.makedirs(os.path.dirname(merge_label_name), exist_ok=True)
              pre_color_image = ToPILImage()(pre_colors[i])
              image_merge(image[i], pre_color_image, merge_label_name)

              if not self.args.save_predicted_image:
                  continue

              color_label_name = os.path.join(
                  self.args.color_label_save_path, img_name)
              label_name = os.path.join(self.args.label_save_path, img_name)
              os.makedirs(os.path.dirname(color_label_name), exist_ok=True)
              os.makedirs(os.path.dirname(label_name), exist_ok=True)

              pre_color_image.save(color_label_name)
              pre_label_image = ToPILImage()(pre_labels[i])
              pre_label_image.save(label_name)
  def image_merge(image, label, save_name, alpha=0.6, out_size=(424, 240)):
      """Blend a colourised label map over its input image and save the result.

      Args:
          image: Tensor of one input image. It is detached, moved to the CPU
              and cast to uint8 before being converted to a PIL image.
              NOTE(review): if the tensor holds normalised floats, the uint8
              cast presumably discards most of the signal -- confirm against
              the data loader.
          label: PIL image of the colourised prediction; its size is the size
              at which the blend is computed.
          save_name: Destination path handed to ``Image.save``.
          alpha: Blend factor passed to ``Image.blend`` (0.0 keeps only the
              input image, 1.0 only the label). Defaults to the previously
              hard-coded 0.6.
          out_size: ``(width, height)`` of the saved image. Defaults to the
              previously hard-coded ``(424, 240)``; pass ``None`` to keep the
              blend at the label's size.
      """
      background = ToPILImage()(image.detach().cpu().byte())
      # Image.blend needs both operands to share size and mode.
      background = background.resize(label.size, Image.BILINEAR).convert('RGBA')
      overlay = label.convert('RGBA')

      merged = Image.blend(background, overlay, alpha)
      if out_size is not None:
          merged = merged.resize(out_size)
      merged.save(save_name)
  def paint_trapezoid(color):
      """Overlay four filled trapezoids on an image array.

      Two "big" trapezoids are placed proportionally to the image size and
      two "small" ones use fixed x-coordinates; all four are drawn on a black
      canvas which is then added to ``color`` with weight 0.3.

      Args:
          color: Array whose ``shape`` unpacks to ``(height, width, channels)``.

      Returns:
          The array ``0.3 * overlay + color``. The uint8 overlay is scaled by
          a Python float first, so the overlay term is floating point and the
          sum is not clipped to the 0-255 range.
      """
      height, width, channel = color.shape

      bottom = int(height)
      near_y = int(.8 * height + .5)   # row shared by the closest/future bands
      far_y = int(.66 * height + .5)   # upper edge of the "future" bands

      # cv2.fillPoly only accepts 32-bit signed integer points; NumPy's
      # default integer dtype is 64-bit on most platforms and is rejected.
      # big trapezoid
      big_closest = np.array([
          [0, bottom],
          [int(width), bottom],
          [int(0.882 * width + .5), near_y],
          [int(0.118 * width + .5), near_y],
      ], dtype=np.int32)
      big_future = np.array([
          [int(0.118 * width + .5), near_y],
          [int(0.882 * width + .5), near_y],
          [int(.765 * width + .5), far_y],
          [int(.235 * width + .5), far_y],
      ], dtype=np.int32)

      # small trapezoid
      # NOTE(review): these x-coordinates are absolute pixels, unlike the
      # proportional ones above; they presumably assume one specific input
      # width -- confirm before using other resolutions.
      small_closest = np.array([
          [488, bottom],
          [1560, bottom],
          [1391, near_y],
          [621, near_y],
      ], dtype=np.int32)
      small_future = np.array([
          [741, far_y],
          [1275, far_y],
          [1391, near_y],
          [621, near_y],
      ], dtype=np.int32)

      big_closest_color = [0, 191, 255]
      big_future_color = [255, 69, 0]
      small_closest_color = [0, 100, 100]
      small_future_color = [69, 69, 69]

      # Later fills paint over earlier ones, so the small trapezoids end up
      # on top of the big ones.
      img = np.zeros((height, width, channel), dtype=np.uint8)
      img = cv2.fillPoly(img, [big_closest], big_closest_color)
      img = cv2.fillPoly(img, [big_future], big_future_color)
      img = cv2.fillPoly(img, [small_closest], small_closest_color)
      img = cv2.fillPoly(img, [small_future], small_future_color)

      img_array = 0.3 * img + color
      return img_array
  def load_my_state_dict(model, state_dict):
      """Copy the entries of ``state_dict`` into ``model`` where names match.

      Each checkpoint entry is matched against the model's own state in this
      order: the exact name, the name with a ``module.`` prefix added, and the
      name with a leading ``module.`` prefix removed. The last two cover
      checkpoints saved with/without a wrapping module. Entries without any
      match are reported and skipped, so a checkpoint that does not line up
      completely with the model can still be loaded.

      Args:
          model: Module providing ``state_dict()``; updated in place.
          state_dict: Mapping from parameter/buffer name to tensor.

      Returns:
          The same ``model`` object.

      Raises:
          RuntimeError: Propagated from ``Tensor.copy_`` when a matched entry
              has an incompatible shape.
      """
      prefix = "module."
      own_state = model.state_dict()

      # The tensors in ``own_state`` share storage with the model, so copying
      # into them updates the model; no_grad keeps autograd out of the copy.
      with torch.no_grad():
          for name, param in state_dict.items():
              target = name
              if target not in own_state:
                  print("{} is not in own state".format(name))
                  if prefix + name in own_state:
                      target = prefix + name
                  elif (name.startswith(prefix)
                          and name[len(prefix):] in own_state):
                      target = name[len(prefix):]
                  else:
                      print("{} has no matching entry in the model; "
                            "skipped".format(name))
                      continue
              own_state[target].copy_(param)
      return model