- """
- # -*- coding: utf-8 -*-
- -----------------------------------------------------------------------------------
- # Author: Nguyen Mau Dung
- # DoC: 2020.08.17
- # email: nguyenmaudung93.kstn@gmail.com
- -----------------------------------------------------------------------------------
- # Description: This script for training
-
- """

import time
import numpy as np
import sys
import random
import os
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

import torch
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.utils.data.distributed
from tqdm import tqdm

src_dir = os.path.dirname(os.path.realpath(__file__))
# while not src_dir.endswith("sfa"):
#     src_dir = os.path.dirname(src_dir)
if src_dir not in sys.path:
    sys.path.append(src_dir)

from data_process.kitti_dataloader import create_train_dataloader, create_val_dataloader
from models.model_utils import create_model, make_data_parallel, get_num_parameters
from utils.train_utils import create_optimizer, create_lr_scheduler, get_saved_state, save_checkpoint
from utils.torch_utils import reduce_tensor, to_python_float
from utils.misc import AverageMeter, ProgressMeter
from utils.logger import Logger
from config.train_config import parse_train_configs
from losses.losses import Compute_Loss


def main():
    configs = parse_train_configs()

    # Re-produce results
    if configs.seed is not None:
        random.seed(configs.seed)
        np.random.seed(configs.seed)
        torch.manual_seed(configs.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if configs.gpu_idx is not None:
        print('You have chosen a specific GPU. This will completely disable data parallelism.')

    if configs.dist_url == "env://" and configs.world_size == -1:
        configs.world_size = int(os.environ["WORLD_SIZE"])

    configs.distributed = configs.world_size > 1 or configs.multiprocessing_distributed

    if configs.multiprocessing_distributed:
        configs.world_size = configs.ngpus_per_node * configs.world_size
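        # Spawn one training process per GPU on this node; each child runs main_worker()
        # with its local GPU index passed as the first argument.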
        mp.spawn(main_worker, nprocs=configs.ngpus_per_node, args=(configs,))
    else:
        main_worker(configs.gpu_idx, configs)


def main_worker(gpu_idx, configs):
    configs.gpu_idx = gpu_idx
    # configs.device = torch.device('cpu' if configs.gpu_idx is None else 'cuda:{}'.format(configs.gpu_idx))

    if configs.distributed:
        if configs.dist_url == "env://" and configs.rank == -1:
            configs.rank = int(os.environ["RANK"])
        if configs.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            configs.rank = configs.rank * configs.ngpus_per_node + gpu_idx

        dist.init_process_group(backend=configs.dist_backend, init_method=configs.dist_url,
                                world_size=configs.world_size, rank=configs.rank)
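        # 'subdivisions' controls gradient accumulation: gradients from this many
        # mini-batches are summed before each optimizer step, so the effective batch
        # size works out to roughly 64 samples per node.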
        configs.subdivisions = int(64 / configs.batch_size / configs.ngpus_per_node)
    else:
        configs.subdivisions = int(64 / configs.batch_size)

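    # Only one process per node (the one with local rank 0) acts as the master:
    # it owns the logger, the TensorBoard writer, and checkpoint saving.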
    configs.is_master_node = (not configs.distributed) or (
            configs.distributed and (configs.rank % configs.ngpus_per_node == 0))

    if configs.is_master_node:
        logger = Logger(configs.logs_dir, configs.saved_fn)
        logger.info('>>> Created a new logger')
        logger.info('>>> configs: {}'.format(configs))
        tb_writer = SummaryWriter(log_dir=os.path.join(configs.logs_dir, 'tensorboard'))
    else:
        logger = None
        tb_writer = None

    # model
    model = create_model(configs)

    # load weight from a checkpoint
    if configs.pretrained_path is not None:
        # assert os.path.isfile(configs.pretrained_path), "=> no checkpoint found at '{}'".format(configs.pretrained_path)
        if os.path.isfile(configs.pretrained_path):
            model_path = configs.pretrained_path
        else:
            # If a directory was given, fall back to the last file listed in it
            model_path = os.path.join(configs.pretrained_path, os.listdir(configs.pretrained_path)[-1])
        model.load_state_dict(torch.load(model_path, map_location=configs.device))
        if logger is not None:
            logger.info('loaded pretrained model at {}'.format(configs.pretrained_path))

    # resume weights of model from a checkpoint
    if configs.resume_path is not None:
        assert os.path.isfile(configs.resume_path), "=> no checkpoint found at '{}'".format(configs.resume_path)
        model.load_state_dict(torch.load(configs.resume_path, map_location='cpu'))
        if logger is not None:
            logger.info('resume training model from checkpoint {}'.format(configs.resume_path))

    # Data Parallel
    model = make_data_parallel(model, configs)

    # Make sure to create optimizer after moving the model to cuda
    optimizer = create_optimizer(configs, model)
    lr_scheduler = create_lr_scheduler(optimizer, configs)
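    # multi_step/cosin/one_cycle schedulers are stepped once per epoch at the end of the
    # epoch loop; any other scheduler type is stepped after every optimizer update
    # inside train_one_epoch().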
    configs.step_lr_in_epoch = False if configs.lr_type in ['multi_step', 'cosin', 'one_cycle'] else True

    # resume optimizer, lr_scheduler from a checkpoint
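    # The optimizer and scheduler states are loaded from a companion checkpoint whose
    # filename swaps the 'Model_' prefix for 'Utils_'.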
    if configs.resume_path is not None:
        utils_path = configs.resume_path.replace('Model_', 'Utils_')
        assert os.path.isfile(utils_path), "=> no checkpoint found at '{}'".format(utils_path)
        utils_state_dict = torch.load(utils_path, map_location='cuda:{}'.format(configs.gpu_idx))
        optimizer.load_state_dict(utils_state_dict['optimizer'])
        lr_scheduler.load_state_dict(utils_state_dict['lr_scheduler'])
        configs.start_epoch = utils_state_dict['epoch'] + 1

    if configs.is_master_node:
        num_parameters = get_num_parameters(model)
        logger.info('number of trainable parameters of the model: {}'.format(num_parameters))

    if logger is not None:
        logger.info(">>> Loading dataset & getting dataloader...")
    # Create dataloader
    train_dataloader, train_sampler = create_train_dataloader(configs)
    if logger is not None:
        logger.info('number of batches in training set: {}'.format(len(train_dataloader)))

    if configs.evaluate:
        val_dataloader = create_val_dataloader(configs)
        val_loss = validate(val_dataloader, model, configs)
        print('val_loss: {:.4e}'.format(val_loss))
        return

    for epoch in range(configs.start_epoch, configs.num_epochs + 1):
        if logger is not None:
            logger.info('{}'.format('*-' * 40))
            logger.info('{} {}/{} {}'.format('=' * 35, epoch, configs.num_epochs, '=' * 35))
            logger.info('{}'.format('*-' * 40))
            logger.info('>>> Epoch: [{}/{}]'.format(epoch, configs.num_epochs))

        if configs.distributed:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        train_one_epoch(train_dataloader, model, optimizer, lr_scheduler, epoch, configs, logger, tb_writer)
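        # Run validation every 'checkpoint_freq' epochs unless configs.no_val is set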
        if (not configs.no_val) and (epoch % configs.checkpoint_freq == 0):
            val_dataloader = create_val_dataloader(configs)
            print('number of batches in val_dataloader: {}'.format(len(val_dataloader)))
            val_loss = validate(val_dataloader, model, configs)
            print('val_loss: {:.4e}'.format(val_loss))
            if tb_writer is not None:
                tb_writer.add_scalar('Val_loss', val_loss, epoch)

        # Save checkpoint
        if configs.is_master_node and ((epoch % configs.checkpoint_freq) == 0):
            model_state_dict, utils_state_dict = get_saved_state(model, optimizer, lr_scheduler, epoch, configs)
            save_checkpoint(configs.checkpoints_dir, configs.saved_fn, model_state_dict, utils_state_dict, epoch)

        if not configs.step_lr_in_epoch:
            lr_scheduler.step()
            if tb_writer is not None:
                tb_writer.add_scalar('LR', lr_scheduler.get_lr()[0], epoch)

    if tb_writer is not None:
        tb_writer.close()
    if configs.distributed:
        cleanup()


def cleanup():
    dist.destroy_process_group()


def train_one_epoch(train_dataloader, model, optimizer, lr_scheduler, epoch, configs, logger, tb_writer):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    progress = ProgressMeter(len(train_dataloader), [batch_time, data_time, losses],
                             prefix="Train - Epoch: [{}/{}]".format(epoch, configs.num_epochs))

    criterion = Compute_Loss(device=configs.device)
    num_iters_per_epoch = len(train_dataloader)
    # switch to train mode
    model.train()
    start_time = time.time()
    for batch_idx, batch_data in enumerate(tqdm(train_dataloader)):
        data_time.update(time.time() - start_time)
        imgs, targets = batch_data
        batch_size = imgs.size(0)
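        # 1-based global iteration counter across epochs; used for gradient accumulation,
        # per-iteration LR stepping and TensorBoard logging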
        global_step = num_iters_per_epoch * (epoch - 1) + batch_idx + 1
        for k in targets.keys():
            targets[k] = targets[k].to(configs.device, non_blocking=True)
        imgs = imgs.to(configs.device, non_blocking=True).float()
        outputs = model(imgs)
        total_loss, loss_stats = criterion(outputs, targets)
        # For torch.nn.DataParallel case
        if (not configs.distributed) and (configs.gpu_idx is None):
            total_loss = torch.mean(total_loss)

        # compute gradient and perform backpropagation
        total_loss.backward()
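        # Gradients accumulate over 'subdivisions' mini-batches; the optimizer only steps
        # (and gradients are only zeroed) once the accumulation window is complete.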
        if global_step % configs.subdivisions == 0:
            optimizer.step()
            # zero the parameter gradients
            optimizer.zero_grad()
            # Adjust learning rate
            if configs.step_lr_in_epoch:
                lr_scheduler.step()
                if tb_writer is not None:
                    tb_writer.add_scalar('LR', lr_scheduler.get_lr()[0], global_step)

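        # In distributed mode, reduce the loss across all processes so every worker logs
        # a value that reflects the whole multi-GPU batch rather than its own shard.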
        if configs.distributed:
            reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
        else:
            reduced_loss = total_loss.data
        losses.update(to_python_float(reduced_loss), batch_size)
        # measure elapsed time
        # torch.cuda.synchronize()
        batch_time.update(time.time() - start_time)

        if tb_writer is not None:
            if (global_step % configs.tensorboard_freq) == 0:
                loss_stats['avg_loss'] = losses.avg
                tb_writer.add_scalars('Train', loss_stats, global_step)
        # Log message
        if logger is not None:
            if (global_step % configs.print_freq) == 0:
                logger.info(progress.get_message(batch_idx))

        start_time = time.time()


def validate(val_dataloader, model, configs):
    losses = AverageMeter('Loss', ':.4e')
    criterion = Compute_Loss(device=configs.device)
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(val_dataloader)):
            imgs, targets = batch_data
            batch_size = imgs.size(0)
            for k in targets.keys():
                targets[k] = targets[k].to(configs.device, non_blocking=True)
            imgs = imgs.to(configs.device, non_blocking=True).float()
            outputs = model(imgs)
            total_loss, loss_stats = criterion(outputs, targets)
            # For torch.nn.DataParallel case
            if (not configs.distributed) and (configs.gpu_idx is None):
                total_loss = torch.mean(total_loss)

            if configs.distributed:
                reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
            else:
                reduced_loss = total_loss.data
            losses.update(to_python_float(reduced_loss), batch_size)

    return losses.avg


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        try:
            cleanup()
            sys.exit(0)
        except SystemExit:
            os._exit(0)