import os

import numpy as np
from tqdm import tqdm
import torch
from sedna.common.log import LOGGER

from models.rfnet import RFNet
from models.resnet.resnet_single_scale_single_attention import *  # provides resnet18
from models.replicate import patch_replication_callback
from dataloaders import make_data_loader
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator


class Trainer(object):
    def __init__(self, args, train_data=None, valid_data=None):
        self.args = args
        self.logger = LOGGER

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # Denormalization constants for the depth image
        self.mean_depth = torch.as_tensor(0.12176, dtype=torch.float32)
        self.std_depth = torch.as_tensor(0.09752, dtype=torch.float32)
        self.nclass = args.num_class
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        self.train_loader, self.val_loader, self.test_loader = make_data_loader(
            args, train_data=train_data, valid_data=valid_data, **kwargs)

        # Define network
        resnet = resnet18(pretrained=True, efficient=False, use_bn=True)
        model = RFNet(resnet, num_classes=self.nclass, use_bn=True)
        # Randomly initialized layers train at the base learning rate;
        # the pretrained backbone is fine-tuned at a 10x smaller rate.
        train_params = [{'params': model.random_init_params(),
                         'lr': args.lr},
                        {'params': model.fine_tune_params(),
                         'lr': 0.1 * args.lr,
                         'weight_decay': args.weight_decay}]
        # Define Optimizer (per-group settings above override these defaults)
        optimizer = torch.optim.Adam(train_params, lr=args.lr,
                                     weight_decay=args.weight_decay)
        # Define Criterion
        # Whether to use class-balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                args.class_weight_path, 'classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(
                    args.class_weight_path, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        # Define loss function
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(
            mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs, len(
                self.train_loader))

        if args.cuda:
            self.model = torch.nn.DataParallel(
                self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(args.resume))
            self.logger.info(f"Training: load model from {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.logger.info("learning rate: {}".format(
            self.optimizer.state_dict()['param_groups'][0]['lr']))
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            if self.args.depth:
                image, depth, target = sample['image'], sample['depth'], sample['label']
            else:
                image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                if self.args.depth:
                    depth = depth.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if self.args.depth:
                output = self.model(image, depth)
            else:
                output = self.model(image)

            # Map labels outside the valid class range to the ignore index (255)
            target[target > self.nclass - 1] = 255
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar(
                'train/total_loss_iter',
                loss.item(),
                i + num_img_tr * epoch)
            # Show 10 * 3 inference results each epoch
            # (max(..., 1) guards against loaders with fewer than 10 batches)
            if i % max(num_img_tr // 10, 1) == 0:
                global_step = i + num_img_tr * epoch
                if self.args.depth:
                    self.summary.visualize_image(
                        self.writer, "cityscapes", image, target, output, global_step)

                    depth_display = depth[0].cpu().unsqueeze(0)
                    depth_display = depth_display.mul_(
                        self.std_depth).add_(self.mean_depth)
                    depth_display = depth_display.numpy()
                    depth_display = depth_display * 255
                    depth_display = depth_display.astype(np.uint8)
                    self.writer.add_image('Depth', depth_display, global_step)

                else:
                    self.summary.visualize_image(
                        self.writer, "cityscapes", image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        self.logger.info(
            '[Epoch: %d, numImages: %5d]' %
            (epoch, i * self.args.batch_size + image.data.shape[0]))
        self.logger.info('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, (sample, img_path) in enumerate(tbar):
            if self.args.depth:
                image, depth, target = sample['image'], sample['depth'], sample['label']
            else:
                image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                if self.args.depth:
                    depth = depth.cuda()
            with torch.no_grad():
                if self.args.depth:
                    output = self.model(image, depth)
                else:
                    output = self.model(image)
            # Map labels outside the valid class range to the ignore index (255)
            target[target > self.nclass - 1] = 255
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)

        self.logger.info('Validation:')
        self.logger.info(
            '[Epoch: %d, numImages: %5d]' %
            (epoch, i * self.args.batch_size + image.data.shape[0]))
        self.logger.info(
            "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU))
        self.logger.info('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        return new_pred
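

# Usage sketch: a minimal, hypothetical driver loop assuming an argparse-style
# `args` carrying the fields read by Trainer. The values below are illustrative
# placeholders only, and `args` must also provide whatever additional fields
# Saver, make_data_loader and TensorboardSummary expect in this project.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        num_class=19, workers=4, lr=1e-4, weight_decay=2.5e-5,
        use_balanced_weights=False, class_weight_path='./weights',
        loss_type='ce', cuda=torch.cuda.is_available(), gpu_ids=[0],
        resume=None, ft=False, start_epoch=0, lr_scheduler='poly',
        epochs=50, depth=False, batch_size=4)

    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        trainer.validation(epoch)  # returns the epoch's mIoU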