
train.py 9.6 kB

import os
import numpy as np
from tqdm import tqdm
import torch
from sedna.common.log import LOGGER
from models.rfnet import RFNet
from models.resnet.resnet_single_scale_single_attention import *
from models.replicate import patch_replication_callback
from dataloaders import make_data_loader
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator
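# Note: models.*, dataloaders, and utils.* appear to be local modules shipped
# with this RFNet example rather than pip packages; the wildcard import above
# supplies the resnet18 constructor used below.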


class Trainer(object):
    def __init__(self, args, train_data=None, valid_data=None):
        self.args = args
        self.logger = LOGGER
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # denormalize for depth image
        self.mean_depth = torch.as_tensor(0.12176, dtype=torch.float32)
        self.std_depth = torch.as_tensor(0.09752, dtype=torch.float32)
        self.nclass = args.num_class
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        self.train_loader, self.val_loader, self.test_loader = make_data_loader(
            args, train_data=train_data, valid_data=valid_data, **kwargs)
        # Define network
        resnet = resnet18(pretrained=True, efficient=False, use_bn=True)
        model = RFNet(resnet, num_classes=self.nclass, use_bn=True)
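        # Two parameter groups: the newly initialized RFNet layers train at
        # the full learning rate, while the pretrained ResNet-18 backbone is
        # fine-tuned at 0.1x that rate with weight decay.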
        train_params = [{'params': model.random_init_params(),
                         'lr': args.lr},
                        {'params': model.fine_tune_params(),
                         'lr': 0.1 * args.lr,
                         'weight_decay': args.weight_decay}]
        # Define Optimizer
        optimizer = torch.optim.Adam(train_params, lr=args.lr,
                                     weight_decay=args.weight_decay)
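        # Per-group 'lr' and 'weight_decay' entries override the Adam defaults
        # passed here, so the top-level values only act as a fallback for
        # groups that omit them (e.g. weight_decay for the first group).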
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                args.class_weight_path, 'classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(
                    args.class_weight_path, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        # Define loss function
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(
            mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs, len(
                self.train_loader))
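        # The scheduler is given the number of iterations per epoch because it
        # is stepped once per batch in training(), not once per epoch.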
        if args.cuda:
            self.model = torch.nn.DataParallel(
                self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
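        # patch_replication_callback above comes from the synchronized
        # BatchNorm helpers; it keeps BN statistics consistent across
        # DataParallel replicas on multi-GPU runs.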
        # Resuming checkpoint
        self.best_pred = 0.0
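        # Resuming restores weights, optimizer state, epoch counter, and best
        # mIoU; with args.ft (fine-tune) the optimizer state and epoch counter
        # are deliberately discarded so training starts fresh on new data.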
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(
                        args.resume))
            self.logger.info(f"Training: load model from {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.logger.info("learning rate: {}".format(
            self.optimizer.state_dict()['param_groups'][0]['lr']))
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            if self.args.depth:
                image, depth, target = sample['image'], sample['depth'], sample['label']
            else:
                image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                if self.args.depth:
                    depth = depth.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if self.args.depth:
                output = self.model(image, depth)
            else:
                output = self.model(image)
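            # Map labels outside the class range to 255, the Cityscapes-style
            # ignore index the segmentation loss is expected to skip.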
            target[target > self.nclass - 1] = 255
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar(
                'train/total_loss_iter',
                loss.item(),
                i + num_img_tr * epoch)
            # Show 10 * 3 inference results each epoch; max(1, ...) guards
            # against loaders shorter than ten batches, where the original
            # modulus would be zero
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                if self.args.depth:
                    self.summary.visualize_image(
                        self.writer, "cityscapes", image, target, output, global_step)
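                    # Undo the depth normalization and rescale to 8-bit so the
                    # raw depth map renders correctly in TensorBoard.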
                    depth_display = depth[0].cpu().unsqueeze(0)
                    depth_display = depth_display.mul_(
                        self.std_depth).add_(self.mean_depth)
                    depth_display = depth_display.numpy()
                    depth_display = depth_display * 255
                    depth_display = depth_display.astype(np.uint8)
                    self.writer.add_image('Depth', depth_display, global_step)
                else:
                    self.summary.visualize_image(
                        self.writer, "cityscapes", image, target, output, global_step)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        self.logger.info(
            '[Epoch: %d, numImages: %5d]' %
            (epoch, i * self.args.batch_size + image.data.shape[0]))
        self.logger.info('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, (sample, img_path) in enumerate(tbar):
            if self.args.depth:
                image, depth, target = sample['image'], sample['depth'], sample['label']
            else:
                image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                if self.args.depth:
                    depth = depth.cuda()
            with torch.no_grad():
                if self.args.depth:
                    output = self.model(image, depth)
                else:
                    output = self.model(image)
            target[target > self.nclass - 1] = 255
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
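        # The evaluator accumulates a per-class confusion matrix over the whole
        # validation set; every metric below is derived from that matrix.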
        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        self.logger.info('Validation:')
        self.logger.info(
            '[Epoch: %d, numImages: %5d]' %
            (epoch, i * self.args.batch_size + image.data.shape[0]))
        self.logger.info(
            "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU))
        self.logger.info('Loss: %.3f' % test_loss)
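        # mIoU drives model selection: a checkpoint is written only when it
        # improves on the best value seen so far, and the value is returned so
        # the surrounding training loop can track the best model.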
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
        return new_pred