# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import sys

sys.path.append('../..')
import os
import json
import time
import timm
import torch
import numpy as np
import torch.nn as nn

from argparse import ArgumentParser
# from torch.utils.tensorboard import SummaryWriter

# import timm packages
from timm.optim import create_optimizer
from timm.models import resume_checkpoint
from timm.scheduler import create_scheduler
from timm.data import Dataset, create_loader
from timm.utils import CheckpointSaver, ModelEma, update_summary
from timm.loss import LabelSmoothingCrossEntropy

# import apex as distributed package
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel import convert_syncbn_model

    HAS_APEX = True
except ImportError as e:
    print(e)
    from torch.nn.parallel import DistributedDataParallel as DDP

    HAS_APEX = False

# import models and training functions
from pytorch.utils import mkdirs, save_best_checkpoint, str2bool
from lib.core.test import validate
from lib.core.retrain import train_epoch
from lib.models.structures.childnet import gen_childnet
from lib.utils.util import get_logger, get_model_flops_params
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def parse_args():
    """See lib.utils.config"""
    parser = ArgumentParser()

    # path
    parser.add_argument("--best_checkpoint_dir", type=str, default='./output/best_checkpoint/')
    parser.add_argument("--checkpoint_dir", type=str, default='./output/checkpoints/')
    parser.add_argument("--data_dir", type=str, default='./data')
    parser.add_argument("--experiment_dir", type=str, default='./')
    parser.add_argument("--model_name", type=str, default='retrainer')
    parser.add_argument("--log_path", type=str, default='output/log')
    parser.add_argument("--result_path", type=str, default='output/result.json')
    parser.add_argument("--best_selected_space_path", type=str,
                        default='output/selected_space.json')

    # int
    parser.add_argument("--acc_gap", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--cooldown_epochs", type=int, default=10)
    parser.add_argument("--decay_epochs", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--flops_minimum", type=int, default=0)
    parser.add_argument("--flops_maximum", type=int, default=200)
    parser.add_argument("--image_size", type=int, default=224)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=50)
    parser.add_argument("--meta_sta_epoch", type=int, default=20)
    parser.add_argument("--num_classes", type=int, default=1000)
    parser.add_argument("--num_gpu", type=int, default=1)
- parser.add_argument("--parience_epochs", type=int, default=10)
- parser.add_argument("--pool_size", type=int, default=10)
- parser.add_argument("--recovery_interval", type=int, default=10)
- parser.add_argument("--trial_id", type=int, default=42)
- parser.add_argument("--selection", type=int, default=-1)
- parser.add_argument("--slice_num", type=int, default=4)
- parser.add_argument("--tta", type=int, default=0)
- parser.add_argument("--update_iter", type=int, default=1300)
- parser.add_argument("--val_batch_mul", type=int, default=4)
- parser.add_argument("--warmup_epochs", type=int, default=3)
- parser.add_argument("--workers", type=int, default=4)
-
- # float
- parser.add_argument("--color_jitter", type=float, default=0.4)
- parser.add_argument("--decay_rate", type=float, default=0.1)
- parser.add_argument("--dropout_rate", type=float, default=0.0)
- parser.add_argument("--ema_decay", type=float, default=0.998)
- parser.add_argument("--lr", type=float, default=1e-2)
- parser.add_argument("--meta_lr", type=float, default=1e-4)
- parser.add_argument("--re_prob", type=float, default=0.2)
- parser.add_argument("--opt_eps", type=float, default=1e-2)
- parser.add_argument("--momentum", type=float, default=0.9)
- parser.add_argument("--min_lr", type=float, default=1e-5)
- parser.add_argument("--smoothing", type=float, default=0.1)
- parser.add_argument("--weight_decay", type=float, default=1e-4)
- parser.add_argument("--warmup_lr", type=float, default=1e-4)
-
- # bool
- parser.add_argument("--auto_resume", type=str2bool, default='False')
- parser.add_argument("--dil_conv", type=str2bool, default='False')
- parser.add_argument("--ema_cpu", type=str2bool, default='False')
- parser.add_argument("--pin_mem", type=str2bool, default='True')
- parser.add_argument("--resunit", type=str2bool, default='False')
- parser.add_argument("--save_images", type=str2bool, default='False')
- parser.add_argument("--sync_bn", type=str2bool, default='False')
- parser.add_argument("--use_ema", type=str2bool, default='False')
- parser.add_argument("--verbose", type=str2bool, default='False')
-
- # str
- parser.add_argument("--aa", type=str, default='rand-m9-mstd0.5')
- parser.add_argument("--eval_metrics", type=str, default='prec1')
- # gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"]
- parser.add_argument("--gp", type=str, default='avg')
- parser.add_argument("--interpolation", type=str, default='bilinear')
- parser.add_argument("--opt", type=str, default='sgd')
- parser.add_argument("--pick_method", type=str, default='meta')
- parser.add_argument("--re_mode", type=str, default='pixel')
- parser.add_argument("--sched", type=str, default='sgd')
-
- args = parser.parse_args()
- args.sync_bn = False
- args.verbose = False
- args.data_dir = args.data_dir + "/imagenet"
- return args
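
# Example invocation (hypothetical; expects an ImageNet-style folder layout
# with train/ and val/ splits under <data_dir>/imagenet):
#   python retrain.py --data_dir ./data --selection 14 --batch_size 32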


def main():
    args = parse_args()

    mkdirs(args.checkpoint_dir + "/",
           args.experiment_dir,
           args.best_selected_space_path,
           args.result_path)
    with open(args.result_path, "w") as ss_file:
        ss_file.write('')

    if len(args.checkpoint_dir) > 1:
        mkdirs(args.best_checkpoint_dir + "/")

    args.checkpoint_dir = os.path.join(
        args.checkpoint_dir,
        "{}_{}".format(args.model_name, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
    )
    if not os.path.exists(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)

    # resolve logging
    if args.local_rank == 0:
        logger = get_logger(args.log_path)
        writer = None  # SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection

    if args.selection == -1:
        if os.path.exists(args.best_selected_space_path):
            with open(args.best_selected_space_path, "r") as f:
                arch_list = json.load(f)['selected_space']
        else:
            args.selection = 14
            if logger is not None:
                logger.warning(
                    "args.best_selected_space_path does not exist. Falling back to selection 14.")

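    # Each arch_list entry below describes one stage of the childnet: the length
    # of a stage's sub-list sets how many times that stage's choice block is
    # repeated, and the numbers index operators in the searched space
    # (interpretation based on how arch_def is assembled from choice_block_pool
    # further down).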
    if args.selection == 481:
        arch_list = [[0],
                     [3, 4, 3, 1],
                     [3, 2, 3, 0],
                     [3, 3, 3, 1],
                     [3, 3, 3, 3],
                     [3, 3, 3, 3],
                     [0]]
        args.image_size = 224
    elif args.selection == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        args.image_size = 96
    elif args.selection == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        args.image_size = 64
    elif args.selection == 112:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        args.image_size = 160
    elif args.selection == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        args.image_size = 224
    elif args.selection == 604:
        arch_list = [[0],
                     [3, 3, 2, 3, 3],
                     [3, 2, 3, 2, 3],
                     [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3],
                     [3, 3, 2, 3, 3, 3],
                     [0]]
        args.image_size = 224
    elif args.selection == -1:
        args.image_size = 224
    else:
        raise ValueError("Model Retrain Selection is not Supported!")

    print(arch_list)
    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']

    # TODO: this param from NNI is different from microsoft/Cream.
    choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
                         'ir_r1_k5_s2_e4_c40_se0.25',
                         'ir_r1_k3_s2_e6_c80_se0.25',
                         'ir_r1_k3_s1_e6_c96_se0.25',
                         'ir_r1_k5_s2_e6_c192_se0.25']
    arch_def = [[stem[0]]] + [[choice_block_pool[idx]
                               for _ in range(len(arch_list[idx + 1]))]
                              for idx in range(len(choice_block_pool))] + [[stem[1]]]
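    # For illustration: with selection 14 (arch_list = [[0], [3], [3, 3], [3, 3],
    # [3], [3], [0]]) the comprehension above expands to
    #   [['ds_r1_k3_s1_e1_c16_se0.25'],
    #    ['ir_r1_k3_s2_e4_c24_se0.25'],
    #    ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25'],
    #    ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25'],
    #    ['ir_r1_k3_s1_e6_c96_se0.25'],
    #    ['ir_r1_k5_s2_e6_c192_se0.25'],
    #    ['cn_r1_k1_s1_c320_se0.25']]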

    # generate childnet
    model = gen_childnet(
        arch_list,
        arch_def,
        num_classes=args.num_classes,
        drop_rate=args.dropout_rate,
        global_pool=args.gp)

    # initialize distributed parameters
    distributed = args.num_gpu > 1
    torch.cuda.set_device(args.local_rank)
    if args.local_rank == 0:
        logger.info(
            'Training on Process {} with {} GPUs.'.format(
                args.local_rank, args.num_gpu))

    # fix random seeds
    torch.manual_seed(args.trial_id)
    torch.cuda.manual_seed_all(args.trial_id)
    np.random.seed(args.trial_id)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
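    # Deterministic cuDNN kernels make runs repeatable at some throughput cost;
    # flip the two flags above if reproducibility matters less than speed.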

    # get parameters and FLOPs of model
    if args.local_rank == 0:
        macs, params = get_model_flops_params(model, input_size=(
            1, 3, args.image_size, args.image_size))
        logger.info(
            '[Model-{}] Flops: {} Params: {}'.format(args.selection, macs, params))

    # create optimizer
    model = model.cuda()
    optimizer = create_optimizer(args, model)

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.auto_resume:
        # timm >= 0.3 changed resume_checkpoint() to return only the resume
        # epoch; this crude check reads the minor-version digit of a '0.x.y'
        # version string.
        if int(timm.__version__[2]) >= 3:
            resume_epoch = resume_checkpoint(model, args.experiment_dir, optimizer)
        else:
            resume_state, resume_epoch = resume_checkpoint(model, args.experiment_dir)
            optimizer.load_state_dict(resume_state['optimizer'])
            del resume_state

    model_ema = None
    if args.use_ema:
        model_ema = ModelEma(
            model,
            decay=args.ema_decay,
            device='cpu' if args.ema_cpu else '',
            resume=args.experiment_dir if args.auto_resume else None)
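    # ModelEma keeps a shadow copy of the weights updated each step as
    # ema = decay * ema + (1 - decay) * w; the smoothed copy is evaluated
    # separately after each training epoch below.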

    # initialize training parameters
    eval_metric = args.eval_metrics
    best_metric, best_epoch, saver = None, None, None
    if args.local_rank == 0:
        decreasing = (eval_metric == 'loss')
        if int(timm.__version__[2]) >= 3:
            saver = CheckpointSaver(model, optimizer,
                                    checkpoint_dir=args.checkpoint_dir,
                                    recovery_dir=args.checkpoint_dir,
                                    model_ema=model_ema,
                                    decreasing=decreasing,
                                    max_history=2)
        else:
            saver = CheckpointSaver(
                checkpoint_dir=args.checkpoint_dir,
                recovery_dir=args.checkpoint_dir,
                decreasing=decreasing,
                max_history=2)

    if distributed:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

        if args.sync_bn:
            try:
                if HAS_APEX:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logger.info('Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. '
                        'Install Apex or Torch >= 1.1 with exception {}'.format(e))
        if HAS_APEX:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True)

    # imagenet train dataset
    train_dir = os.path.join(args.data_dir, 'train')
    if not os.path.exists(train_dir):
        # exit on every rank, not only rank 0, so distributed runs do not hang
        if args.local_rank == 0:
            logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(
        dataset_train,
        input_size=(3, args.image_size, args.image_size),
        batch_size=args.batch_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=0,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=args.workers,
        distributed=distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
        interpolation='random',
        re_mode=args.re_mode,
        re_prob=args.re_prob
    )

    # imagenet validation dataset
    eval_dir = os.path.join(args.data_dir, 'val')
    if not os.path.exists(eval_dir):
        if args.local_rank == 0:
            logger.error('Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, args.image_size, args.image_size),
        batch_size=args.val_batch_mul * args.batch_size,
        is_training=False,
        interpolation=args.interpolation,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=args.workers,
        distributed=distributed,
        pin_memory=args.pin_mem
    )

    # whether to use label smoothing
    if args.smoothing > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
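    # With smoothing=0.1, the training target becomes 0.9 on the true class plus
    # 0.1 spread uniformly over all classes, penalizing over-confident
    # predictions; validation keeps plain cross entropy so losses stay comparable.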

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    # create_scheduler may return None for a --sched value it does not
    # recognize, so guard before stepping
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        best_record, best_ep = 0, 0
        for epoch in range(start_epoch, num_epochs):
            if distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=args.checkpoint_dir,
                model_ema=model_ema,
                logger=logger,
                writer=writer,
                local_rank=args.local_rank)

            eval_metrics = validate(
                epoch,
                model,
                loader_eval,
                validate_loss_fn,
                args,
                logger=logger,
                writer=writer,
                local_rank=args.local_rank,
                result_path=args.result_path
            )

            if model_ema is not None and not args.ema_cpu:
                ema_eval_metrics = validate(
                    epoch,
                    model_ema.ema,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    log_suffix='_EMA',
                    logger=logger,
                    writer=writer,
                    local_rank=args.local_rank
                )
                eval_metrics = ema_eval_metrics
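                # when EMA is enabled, its (typically smoother) metrics drive
                # the LR scheduling and checkpoint selection below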

            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch, train_metrics, eval_metrics, os.path.join(
                args.checkpoint_dir, 'summary.csv'), write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]

                if int(timm.__version__[2]) >= 3:
                    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
                else:
                    best_metric, best_epoch = saver.save_checkpoint(
                        model, optimizer, args,
                        epoch=epoch, metric=save_metric)

            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch

            if args.local_rank == 0:
                logger.info(
                    '*** Best metric: {0} (epoch {1})'.format(best_record, best_ep))

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        logger.info(
            '*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
        save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch)


if __name__ == '__main__':
    main()