|
"""Entry script for SPOS (Single Path One-Shot) supernet training on ImageNet."""

import argparse
import logging
import os
import random

# Restrict the process to 4 GPUs. This must be set before any CUDA context is
# created; it is therefore placed ahead of the torch import (the original code
# set it after importing torch, which only worked because importing torch does
# not by itself initialize CUDA).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import numpy as np
import torch
import torch.nn as nn
from pytorch.callbacks import LRSchedulerCallback, ModelCheckpoint
from algorithms.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer

from dataloader import get_imagenet_iter_dali
from network import ShuffleNetV2OneShot, load_and_parse_state_dict
from utils import CrossEntropyLabelSmooth, accuracy

logger = logging.getLogger("nni.spos.supernet")
-
if __name__ == "__main__":
    parser = argparse.ArgumentParser("SPOS Supernet Training")

    # Data is kept under /mnt/local because the home partition is small
    # (upstream default was "./data/imagenet").
    parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/imagenet")
    parser.add_argument("--load-checkpoint", action="store_true", default=False)
    parser.add_argument("--spos-preprocessing", action="store_true", default=False,
                        help="When true, image values will range from 0 to 255 and use BGR "
                             "(as in original repo).")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--batch-size", type=int, default=512)  # upstream default was 768
    parser.add_argument("--epochs", type=int, default=120)      # upstream default was 120
    parser.add_argument("--learning-rate", type=float, default=0.5)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=4E-5)
    # The original script declared both --label-smooth and --label-smoothing but
    # only ever read args.label_smoothing, so --label-smooth was silently ignored.
    # Accept both spellings as aliases of a single destination instead.
    parser.add_argument("--label-smoothing", "--label-smooth", dest="label_smoothing",
                        type=float, default=0.1)
    parser.add_argument("--log-frequency", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    # Seed every RNG in play so runs are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    # Deterministic cuDNN kernels; also disable benchmark autotuning, which
    # would otherwise select algorithms non-deterministically per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model = ShuffleNetV2OneShot()
    flops_func = model.get_candidate_flops
    if args.load_checkpoint:
        if not args.spos_preprocessing:
            logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.")
        model.load_state_dict(load_and_parse_state_dict())
    model.cuda()
    if torch.cuda.device_count() > 1:  # exclude last gpu, saving for data preprocessing on gpu
        model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1)))
    # Uniform single-path sampling constrained to architectures in the FLOPs band.
    mutator = SPOSSupernetTrainingMutator(model, flops_func=flops_func,
                                          flops_lb=290E6, flops_ub=360E6)
    criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    # Linear LR decay from the initial rate down to 0 over args.epochs steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.epochs) if step <= args.epochs else 0,
        last_epoch=-1)

    # DALI pipelines run on the last visible GPU -- the one DataParallel leaves
    # free above. Computed instead of hard-coding 3 so the script also works
    # with a different number of visible devices.
    device_id = torch.cuda.device_count() - 1
    train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers,
                                          spos_preprocessing=args.spos_preprocessing, device_id=device_id)
    valid_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers,
                                          spos_preprocessing=args.spos_preprocessing, device_id=device_id)
    trainer = SPOSSupernetTrainer(model, criterion, accuracy, optimizer,
                                  args.epochs, train_loader, valid_loader,
                                  mutator=mutator, batch_size=args.batch_size,
                                  log_frequency=args.log_frequency, workers=args.workers,
                                  callbacks=[LRSchedulerCallback(scheduler),
                                             ModelCheckpoint("./checkpoints")])
    trainer.train()
|