# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
sys.path.append('../..')  # make the top-level `pytorch` package importable

import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn
import datasets
from model import CNN
from utils import accuracy
from dartstrainer import DartsTrainer
from pytorch.utils import *
from pytorch.callbacks import BestArchitectureCheckpoint, LRSchedulerCallback

logger = logging.getLogger(__name__)

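# Entry point: parse the command-line arguments, build the DARTS supernet, and
# run the architecture search; the checkpoint callback is expected to write the
# best selected architecture to --best_selected_space_path.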
if __name__ == "__main__":
    parser = ArgumentParser("DARTS train")
    parser.add_argument("--data_dir", type=str,
                        default='../data/', help="directory of the CIFAR-10 data")
    parser.add_argument("--result_path", type=str,
                        default='.0/result.json', help="path of the training result file")
    parser.add_argument("--log_path", type=str,
                        default='.0/log', help="path of the info log")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search space of DARTS")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--log_frequency", default=10, type=int)
    parser.add_argument("--epochs", default=5, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument('--model_lr', type=float, default=0.025, help='learning rate for training model weights')
    parser.add_argument('--arch_lr', type=float, default=3e-4, help='learning rate for training architecture')
    parser.add_argument("--unrolled", default=False, action="store_true")
    parser.add_argument("--visualization", default=False, action="store_true")
    parser.add_argument("--class_num", default=10, type=int, help="number of classes (10 for CIFAR-10)")
    args = parser.parse_args()

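    # Prepare the output locations, initialise the info-level logger, and seed
    # the RNGs with the trial id so that each trial is reproducible.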
    mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path)
    init_logger(args.log_path, "info")
    logger.info(args)
    set_seed(args.trial_id)

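    # CIFAR-10 train/validation splits and the DARTS supernet: a cell-based CNN
    # for 32x32 RGB inputs with `args.layers` cells and `args.channels` initial channels.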
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", root=args.data_dir)
    model = CNN(32, 3, args.channels, args.class_num, args.layers)
    criterion = nn.CrossEntropyLoss()

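    # SGD with momentum and weight decay updates the model weights; the cosine
    # schedule anneals the learning rate towards eta_min over `args.epochs` epochs
    # (stepped by the LRSchedulerCallback registered below).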
    optim = torch.optim.SGD(model.parameters(), args.model_lr, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

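    # The trainer performs the DARTS bi-level optimisation: the model weights are
    # updated with the SGD optimizer above, the architecture parameters at arch_lr;
    # --unrolled switches to the second-order approximation from the DARTS paper.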
    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           optimizer=optim,
                           num_epochs=args.epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           search_space_path=args.search_space_path,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           result_path=args.result_path,
                           unrolled=args.unrolled,
                           arch_lr=args.arch_lr,
                           callbacks=[LRSchedulerCallback(lr_scheduler),
                                      BestArchitectureCheckpoint(args.best_selected_space_path, args.epochs)])

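    # Optionally enable search visualization, then run the search and record the
    # wall-clock time both in the log and in the result file.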
    if args.visualization:
        trainer.enable_visualization()
    t1 = time.time()
    trainer.train()
    # res_json = trainer.result
    cost_time = time.time() - t1
    # The backend filters the terminal output for: {"type": "Cost_time", "result": {"value": "* s"}}
    logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})
    with open(args.result_path, "a") as file:
        file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}))

    # res_json["Cost_time"] = str(cost_time) + ' s'
    # dump_global_result(args.result_path, res_json)