|
- # import os
- # debugger_path = os.path.abspath("./")
- # os.chdir(debugger_path)
-
- import json
- import argparse
- from evolution_tuner import EvolutionWithFlops
-
- from evaluator import Evaluator
- import sys
- sys.path.append("../")
- sys.path.append("../../")
-
-
def load_search_space(path="./auto_gen_search_space.json"):
    """Load the search-space description from a JSON file.

    Args:
        path: Location of the JSON file to read. Defaults to
            ``./auto_gen_search_space.json`` relative to the current
            working directory.

    Returns:
        The decoded JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; be explicit so decoding does not depend on the
    # platform's locale default.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
-
-
def trial(args, trial_id, search_space):
    """Set up the evolutionary tuner for one trial and give it the search space.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``max_epoches``, ``num_select``, ``num_population``, ``m_prob``,
            ``num_crossover`` and ``num_mutation``.
        trial_id: Forwarded to the tuner as its ``epoch`` argument.
        search_space: Search-space description handed to the tuner's
            ``update_search_space`` method.

    Returns:
        None. The tuner object is not returned to the caller.
    """
    # Note the spelling difference: the CLI attribute is ``max_epoches`` while
    # the tuner keyword is ``max_epochs``.
    tuner_options = {
        "max_epochs": args.max_epoches,
        "num_select": args.num_select,
        "num_population": args.num_population,
        "m_prob": args.m_prob,
        "num_crossover": args.num_crossover,
        "num_mutation": args.num_mutation,
        "epoch": trial_id,
    }
    tuner = EvolutionWithFlops(**tuner_options)

    # NOTE(review): presumably the actual search is driven from inside
    # update_search_space (or elsewhere in EvolutionWithFlops) -- confirm.
    tuner.update_search_space(search_space)
-
-
def _str_to_bool(value):
    """Convert a command-line token such as ``"false"`` or ``"1"`` to a bool.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def build_arg_parser():
    """Build the command-line parser for the evolution search script.

    Returns:
        An ``argparse.ArgumentParser`` with the search, data and evaluation
        options registered. Option names and defaults are unchanged.
    """
    parser = argparse.ArgumentParser("search the net by evolution")

    # Evolution-search options. The "original note" values are the numbers
    # that were recorded next to each default (presumably the full-scale
    # settings, with the defaults being quick-run values -- confirm).
    parser.add_argument("--search_space_path", type=str, default="./auto_gen_search_space.json")
    parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar")
    parser.add_argument("--num_select", type=int, default=2)  # original note: 10
    parser.add_argument("--num_population", type=int, default=4)  # original note: 50
    parser.add_argument("--workers", type=int, default=1)  # number of threads
    parser.add_argument("--num_crossover", type=int, default=2)  # original note: 25
    parser.add_argument("--num_mutation", type=int, default=2)  # original note: 25
    parser.add_argument("--max_epoches", type=int, default=3)  # original note: 20
    parser.add_argument("--trial_id", type=int, default=1)
    parser.add_argument("--m_prob", type=float, default=0.1)

    # Data and evaluation options.
    parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet")  # original note: ./data/imagenet
    # Without a converter argparse stores the raw string, so
    # "--spos-preprocessing False" used to yield the truthy string "False".
    # nargs="?" with const=True also lets the bare flag mean True.
    parser.add_argument("--spos-preprocessing", type=_str_to_bool, nargs="?",
                        const=True, default=True,
                        help="When true, image values will range from 0 to 255 and use BGR "
                             "(as in original repo).")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-batch-size", type=int, default=128)
    parser.add_argument("--train-iters", type=int, default=200)
    parser.add_argument("--test-batch-size", type=int, default=512)  # 512 in NNI, 200 in the official repo
    parser.add_argument("--log-frequency", type=int, default=10)
    parser.add_argument("--architecture", type=str, default="./architecture_final.json", help="load the file to retrain or eval")
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()

    search_space = load_search_space(path=args.search_space_path)

    # A single trial is run per invocation, identified by --trial_id.
    trial(args, args.trial_id, search_space)
|